From e24fcfdc5cdec213ac185990b912947bafc84c3e Mon Sep 17 00:00:00 2001 From: Sean Bowe Date: Thu, 28 Jan 2016 20:37:54 -0700 Subject: [PATCH] Added primitive circuit abstraction, tests for sha3. --- src/bit.rs | 87 +++++++++++++++++++---- src/circuit.rs | 148 ++++++++++++++++++++++++++++++++++++++++ src/keccak.rs | 54 +++++++++++++-- src/main.rs | 33 +-------- src/variable.rs | 18 +++-- tinysnark/src/r1cs.rs | 10 +++ tinysnark/tinysnark.cpp | 12 ++++ 7 files changed, 309 insertions(+), 53 deletions(-) create mode 100644 src/circuit.rs diff --git a/src/bit.rs b/src/bit.rs index 49e93d0..dfe19af 100644 --- a/src/bit.rs +++ b/src/bit.rs @@ -5,6 +5,7 @@ use std::cell::RefCell; use super::variable::*; use self::Bit::*; use self::Op::*; +use super::circuit::*; macro_rules! mirror { ($a:pat, $b:pat) => (($a, $b) | ($b, $a)) @@ -206,6 +207,59 @@ pub enum Bit { Bin(BinaryOp, bool) } +struct BitEquality { + a: Bit, + b: Var +} + +impl Constrainable for BitEquality { + type Result = Var; + + fn synthesize(&self, enforce: &Bit) -> Var { + // TODO: currently only support unconditional enforcement + match enforce { + &Bit::Constant(true) => {}, + _ => unimplemented!() + } + + match self.a { + Bin(ref binop, inverted) => { + // TODO: figure this out later + assert!(binop.resolved.borrow().is_none()); + + let mut op = binop.op; + + if inverted { + op = op.not(); + } + + gadget(&[&binop.a, &binop.b, &self.b], 0, move |vals| { + let a = vals.get_input(0); + let b = vals.get_input(1); + + unsafe { vals.set_input(2, op.val(a, b)) }; + }, |i, o, cs| { + cs.push(binaryop_constraint(i[0], i[1], i[2], op)); + + vec![i[2]] + }).remove(0) + }, + _ => unimplemented!() + } + } +} + +impl Equals for Bit { + type Result = BitEquality; + + fn must_equal(&self, other: &Var) -> BitEquality { + BitEquality { + a: self.clone(), + b: other.clone() + } + } +} + fn binaryop_constraint(a: &Var, b: &Var, c: &Var, op: Op) -> Constraint { match op { // a * b = c @@ -286,6 +340,24 @@ fn resolve(a: &Var, b: &Var, op: Op) -> Var { }).remove(0) } +impl ConstraintWalker for Bit { + fn walk(&self, counter: &mut usize, constraints: &mut Vec, witness_map: &mut WitnessMap) + { + match *self { + Constant(_) => {}, + Not(ref v) => { + v.walk(counter, constraints, witness_map); + }, + Is(ref v) => { + v.walk(counter, constraints, witness_map); + }, + Bin(ref bin, _) => { + bin.walk(counter, constraints, witness_map); + } + } + } +} + impl Bit { pub fn val(&self, map: &[FieldT]) -> bool { match *self { @@ -304,21 +376,6 @@ impl Bit { } } - pub fn walk(&self, counter: &mut usize, constraints: &mut Vec, witness_map: &mut WitnessMap) { - match *self { - Constant(_) => {}, - Not(ref v) => { - v.walk(counter, constraints, witness_map); - }, - Is(ref v) => { - v.walk(counter, constraints, witness_map); - }, - Bin(ref bin, _) => { - bin.walk(counter, constraints, witness_map); - } - } - } - pub fn new(v: &Var) -> Bit { Is(gadget(&[v], 0, |_| {}, |i, o, cs| { // boolean constraint: diff --git a/src/circuit.rs b/src/circuit.rs new file mode 100644 index 0000000..c2f504e --- /dev/null +++ b/src/circuit.rs @@ -0,0 +1,148 @@ +use tinysnark::{Proof, Keypair, FieldT, LinearTerm, ConstraintSystem}; +use super::variable::{Var,Constraint,WitnessMap,witness_field_elements}; +use super::bit::Bit; + +pub trait ConstraintWalker: 'static { + fn walk(&self, + counter: &mut usize, + constraints: &mut Vec, + witness_map: &mut WitnessMap); +} + +impl ConstraintWalker for Vec { + fn walk(&self, + counter: &mut usize, + constraints: &mut Vec, + witness_map: &mut WitnessMap) + { + for i in self { + i.walk(counter, constraints, witness_map); + } + } +} + +pub trait Constrainable { + type Result: ConstraintWalker; + + fn synthesize(&self, enforce: &Bit) -> Self::Result; +} + +impl Constrainable for Vec { + type Result = Vec; + + fn synthesize(&self, enforce: &Bit) -> Vec { + self.iter().map(|a| a.synthesize(enforce)).collect() + } +} + +pub trait Equals { + type Result: Constrainable; + + fn must_equal(&self, other: &Rhs) -> Self::Result; +} + +impl Equals<[Rhs]> for [Lhs] where Lhs: Equals { + type Result = Vec; + + fn must_equal(&self, other: &[Rhs]) -> Vec { + assert_eq!(self.len(), other.len()); + + self.iter().zip(other.iter()).map(|(a, b)| a.must_equal(b)).collect() + } +} + +pub struct Circuit { + public_inputs: usize, + private_inputs: usize, + aux_inputs: usize, + keypair: Keypair, + witness_map: WitnessMap +} + +impl Circuit { + pub fn verify(&self, proof: &Proof, public: &[FieldT]) -> bool + { + proof.verify(&self.keypair, public) + } + + pub fn prove(&self, public: &[FieldT], private: &[FieldT]) -> Result + { + assert_eq!(public.len(), self.public_inputs); + assert_eq!(private.len(), self.private_inputs); + + let mut vars = Vec::new(); + vars.push(FieldT::one()); + vars.extend_from_slice(public); + vars.extend_from_slice(private); + + for i in 0..self.aux_inputs { + vars.push(FieldT::zero()); + } + + witness_field_elements(&mut vars, &self.witness_map); + + let primary = &vars[1..public.len()+1]; + let aux = &vars[1+public.len()..]; + + if !self.keypair.is_satisfied(primary, aux) { + return Err(()) + } + + Ok(Proof::new(&self.keypair, primary, aux)) + } +} + +pub struct CircuitBuilder { + public_inputs: usize, + private_inputs: usize, + constraints: Vec> +} + +impl CircuitBuilder { + pub fn new(num_public: usize, num_private: usize) -> (Vec, Vec, CircuitBuilder) { + ( + (0..num_public).map(|x| Var::new(1+x)).collect(), + (0..num_private).map(|x| Var::new(1+num_public+x)).collect(), + CircuitBuilder { + public_inputs: num_public, + private_inputs: num_private, + constraints: Vec::new() + }, + ) + } + + pub fn constrain(&mut self, constraint: C) { + self.constraints.push(Box::new(constraint.synthesize(&Bit::constant(true)))); + } + + pub fn finalize(self) -> Circuit { + let mut counter = 1 + self.public_inputs + self.private_inputs; + let mut constraints = vec![]; + let mut witness_map = WitnessMap::new(); + + for c in self.constraints.into_iter() { + c.walk(&mut counter, &mut constraints, &mut witness_map); + } + + let mut cs = ConstraintSystem::new(self.public_inputs, (counter - 1) - self.public_inputs); + + for Constraint(a, b, c) in constraints { + let a: Vec<_> = a.into_iter().map(|x| LinearTerm { coeff: x.0, index: x.1.index() }).collect(); + let b: Vec<_> = b.into_iter().map(|x| LinearTerm { coeff: x.0, index: x.1.index() }).collect(); + let c: Vec<_> = c.into_iter().map(|x| LinearTerm { coeff: x.0, index: x.1.index() }).collect(); + + cs.add_constraint(&a, &b, &c); + } + + let kp = Keypair::new(&cs); + + Circuit { + public_inputs: self.public_inputs, + private_inputs: self.private_inputs, + aux_inputs: ((counter - 1) - self.public_inputs) - self.private_inputs, + keypair: kp, + witness_map: witness_map + } + } +} + diff --git a/src/keccak.rs b/src/keccak.rs index 9439060..87b2603 100644 --- a/src/keccak.rs +++ b/src/keccak.rs @@ -191,13 +191,13 @@ fn keccakf(st: &mut [Byte], rounds: usize) } } -pub fn sha3_256(message: &[Byte]) -> Vec { +pub fn sha3_256(message: &[Byte]) -> Vec { // As defined by FIPS202 keccak(1088, 512, message, 0x06, 32, 24) } fn keccak(rate: usize, capacity: usize, mut input: &[Byte], delimited_suffix: u8, mut mdlen: usize, num_rounds: usize) - -> Vec + -> Vec { use std::cmp::min; @@ -249,11 +249,15 @@ fn keccak(rate: usize, capacity: usize, mut input: &[Byte], delimited_suffix: u8 } } - output + output.into_iter().flat_map(|byte| byte.bits.into_iter()).collect() } #[test] fn test_sha3_256() { + use super::circuit::{CircuitBuilder,Equals}; + use super::variable::Var; + use tinysnark::{self,FieldT}; + let test_vector: Vec<(Vec, [u8; 32])> = vec![ (vec![0xff], [0x44,0x4b,0x89,0xec,0xce,0x39,0x5a,0xec,0x5d,0xc9,0x8f,0x19,0xde,0xfd,0x3a,0x23,0xbc,0xa0,0x82,0x2f,0xc7,0x22,0x26,0xf5,0x8c,0xa4,0x6a,0x17,0xee,0xec,0xa4,0x42] @@ -289,7 +293,11 @@ fn test_sha3_256() { for (i, &(ref message, ref expected)) in test_vector.iter().enumerate() { let message: Vec = message.iter().map(|a| Byte::new(*a)).collect(); - let result: Vec = sha3_256(&message).into_iter().map(|a| a.unwrap_constant()).collect(); + let result: Vec = sha3_256(&message) + .chunks(8) + .map(|a| Byte::from(a)) + .map(|a| a.unwrap_constant()) + .collect(); if &*result != expected { print!("Got: "); @@ -306,6 +314,44 @@ fn test_sha3_256() { println!("--- HASH {} SUCCESS ---", i+1); } } + + tinysnark::init(); + + for (i, &(ref message, ref expected)) in test_vector.iter().enumerate() { + fn into_bytes(a: &[Var]) -> Vec { + let a: Vec<_> = a.into_iter().map(|a| Bit::new(a)).collect(); + + a.chunks(8).map(|a| Byte::from(a)).collect() + } + + fn into_fieldt(a: &[u8], vars: &mut [FieldT]) { + let mut counter = 0; + + for byte in a { + for bit in (0..8).map(|i| byte & (1 << i) != 0).rev() { + if bit { vars[counter] = FieldT::one() } else { vars[counter] = FieldT::zero() } + counter += 1; + } + } + } + + let (public, private, mut circuit) = CircuitBuilder::new(expected.len() * 8, message.len() * 8); + + let private = into_bytes(&private); + + circuit.constrain(sha3_256(&private).must_equal(&public)); + + let circuit = circuit.finalize(); + + let mut input: Vec = (0..message.len() * 8).map(|_| FieldT::zero()).collect(); + let mut output: Vec = (0..expected.len() * 8).map(|_| FieldT::zero()).collect(); + + into_fieldt(message, &mut input); + into_fieldt(expected, &mut output); + + let proof = circuit.prove(&output, &input).unwrap(); + assert!(circuit.verify(&proof, &output)); + } } #[derive(Clone)] diff --git a/src/main.rs b/src/main.rs index 7390fab..84773b6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,42 +5,15 @@ extern crate rand; use tinysnark::{Proof, Keypair, FieldT, LinearTerm, ConstraintSystem}; use variable::*; +use circuit::*; use keccak::*; use bit::*; mod variable; mod keccak; mod bit; +mod circuit; fn main() { - tinysnark::init(); - let inbytes = 64; - //for inbits in 0..1024 { - let inbits = inbytes * 8; - let input: Vec = (0..inbits).map(|i| Bit::new(&Var::new(i+1))).collect(); - let input: Vec = input.chunks(8).map(|c| Byte::from(c)).collect(); - - let output = sha3_256(&input); - - let mut counter = 1 + (8*input.len()); - let mut constraints = vec![]; - let mut witness_map = WitnessMap::new(); - - for o in output.iter().flat_map(|e| e.bits().into_iter()) { - o.walk(&mut counter, &mut constraints, &mut witness_map); - } - - let mut vars: Vec = (0..counter).map(|_| FieldT::zero()).collect(); - vars[0] = FieldT::one(); - - witness_field_elements(&mut vars, &witness_map); - - for b in output.iter().flat_map(|e| e.bits()) { - print!("{}", if b.val(&vars) { 1 } else { 0 }); - } - println!(""); - - println!("{}: {} constraints", inbits, constraints.len()); - //} -} +} \ No newline at end of file diff --git a/src/variable.rs b/src/variable.rs index b83c478..cdfeff4 100644 --- a/src/variable.rs +++ b/src/variable.rs @@ -2,6 +2,7 @@ use tinysnark::FieldT; use std::cell::Cell; use std::rc::Rc; use std::collections::BTreeMap; +use super::circuit::ConstraintWalker; pub type WitnessMap = BTreeMap, Vec, Rc)>>; @@ -21,6 +22,14 @@ impl<'a> VariableView<'a> { pub fn get_input(&self, index: usize) -> FieldT { self.vars[self.inputs[index]] } + + /// Sets the value of an input variable. This is unsafe + /// because theoretically this should not be necessary, + /// and could cause soundness problems, but I've temporarily + /// done this to make testing easier. + pub fn set_input(&mut self, index: usize, to: FieldT) { + self.vars[self.inputs[index]] = to; + } } use std::collections::Bound::Unbounded; @@ -102,9 +111,8 @@ impl Var { } } - // make this not public or unsafe too - pub fn index(&self) -> Rc> { - self.index.clone() + pub fn index(&self) -> usize { + self.index.get() } pub fn val(&self, map: &[FieldT]) -> FieldT { @@ -119,8 +127,10 @@ impl Var { Some(ref g) => g.group } } +} - pub fn walk(&self, counter: &mut usize, constraints: &mut Vec, witness_map: &mut WitnessMap) { +impl ConstraintWalker for Var { + fn walk(&self, counter: &mut usize, constraints: &mut Vec, witness_map: &mut WitnessMap) { match self.gadget { None => {}, Some(ref g) => g.walk(counter, constraints, witness_map) diff --git a/tinysnark/src/r1cs.rs b/tinysnark/src/r1cs.rs index fffa77e..9916003 100644 --- a/tinysnark/src/r1cs.rs +++ b/tinysnark/src/r1cs.rs @@ -88,6 +88,15 @@ impl Keypair { aux_size: constraint_system.aux_size } } + + pub fn is_satisfied(&self, primary: &[FieldT], aux: &[FieldT]) -> bool { + assert_eq!(primary.len(), self.primary_size); + assert_eq!(aux.len(), self.aux_size); + + unsafe { + tinysnark_keypair_satisfies_test(self.kp, primary.get_unchecked(0), aux.get_unchecked(0)) + } + } } impl Drop for Keypair { @@ -99,6 +108,7 @@ impl Drop for Keypair { extern "C" { fn tinysnark_gen_keypair(cs: *mut R1ConstraintSystem) -> *mut R1CSKeypair; fn tinysnark_drop_keypair(cs: *mut R1CSKeypair); + fn tinysnark_keypair_satisfies_test(kp: *mut R1CSKeypair, primary: *const FieldT, aux: *const FieldT) -> bool; } #[repr(C)] diff --git a/tinysnark/tinysnark.cpp b/tinysnark/tinysnark.cpp index 81a62c9..0a01465 100644 --- a/tinysnark/tinysnark.cpp +++ b/tinysnark/tinysnark.cpp @@ -78,6 +78,18 @@ extern "C" void tinysnark_drop_r1cs(void * ics) { delete cs; } +extern "C" bool tinysnark_keypair_satisfies_test(void * kp, FieldT* primary, FieldT* aux) +{ + r1cs_ppzksnark_keypair* keypair = static_cast*>(kp); + + r1cs_constraint_system* cs = &keypair->pk.constraint_system; + + r1cs_primary_input primary_input(primary, primary+(cs->primary_input_size)); + r1cs_auxiliary_input aux_input(aux, aux+(cs->auxiliary_input_size)); + + return cs->is_valid() && cs->is_satisfied(primary_input, aux_input); +} + extern "C" bool tinysnark_satisfy_test(void * ics, FieldT* primary, FieldT* aux) { r1cs_constraint_system* cs = static_cast*>(ics);