diff --git a/src/circuit/boolean.rs b/src/circuit/boolean.rs index 239d404..f06fc99 100644 --- a/src/circuit/boolean.rs +++ b/src/circuit/boolean.rs @@ -375,35 +375,44 @@ impl Boolean { where E: Engine, CS: ConstraintSystem { - let c = Self::xor(&mut cs, a, b)?; - - match c { - Boolean::Constant(false) => { - Ok(()) + match (a, b) { + (&Boolean::Constant(a), &Boolean::Constant(b)) => { + if a == b { + Ok(()) + } else { + Err(SynthesisError::Unsatisfiable) + } }, - Boolean::Constant(true) => { - Err(SynthesisError::Unsatisfiable) - }, - Boolean::Is(ref res) => { + (&Boolean::Constant(true), a) | (a, &Boolean::Constant(true)) => { cs.enforce( - || "enforce equals zero", + || "enforce equal to one", |lc| lc, |lc| lc, - |lc| lc + res.get_variable() + |lc| lc + CS::one() - &a.lc(CS::one(), E::Fr::one()) ); Ok(()) }, - Boolean::Not(ref res) => { + (&Boolean::Constant(false), a) | (a, &Boolean::Constant(false)) => { cs.enforce( - || "enforce equals one", + || "enforce equal to zero", |lc| lc, |lc| lc, - |lc| lc + CS::one() - res.get_variable() + |_| a.lc(CS::one(), E::Fr::one()) ); Ok(()) }, + (a, b) => { + cs.enforce( + || "enforce equal", + |lc| lc, + |lc| lc, + |_| a.lc(CS::one(), E::Fr::one()) - &b.lc(CS::one(), E::Fr::one()) + ); + + Ok(()) + } } } @@ -636,24 +645,88 @@ mod test { for b_bool in [false, true].iter().cloned() { for a_neg in [false, true].iter().cloned() { for b_neg in [false, true].iter().cloned() { - let mut cs = TestConstraintSystem::::new(); + { + let mut cs = TestConstraintSystem::::new(); - let mut a = Boolean::from(AllocatedBit::alloc(cs.namespace(|| "a"), Some(a_bool)).unwrap()); - let mut b = Boolean::from(AllocatedBit::alloc(cs.namespace(|| "b"), Some(b_bool)).unwrap()); + let mut a = Boolean::from(AllocatedBit::alloc(cs.namespace(|| "a"), Some(a_bool)).unwrap()); + let mut b = Boolean::from(AllocatedBit::alloc(cs.namespace(|| "b"), Some(b_bool)).unwrap()); - if a_neg { - a = a.not(); + if a_neg { + a = a.not(); + } + if b_neg { + b = b.not(); + } + + Boolean::enforce_equal(&mut cs, &a, &b).unwrap(); + + assert_eq!( + cs.is_satisfied(), + (a_bool ^ a_neg) == (b_bool ^ b_neg) + ); } - if b_neg { - b = b.not(); + { + let mut cs = TestConstraintSystem::::new(); + + let mut a = Boolean::Constant(a_bool); + let mut b = Boolean::from(AllocatedBit::alloc(cs.namespace(|| "b"), Some(b_bool)).unwrap()); + + if a_neg { + a = a.not(); + } + if b_neg { + b = b.not(); + } + + Boolean::enforce_equal(&mut cs, &a, &b).unwrap(); + + assert_eq!( + cs.is_satisfied(), + (a_bool ^ a_neg) == (b_bool ^ b_neg) + ); } + { + let mut cs = TestConstraintSystem::::new(); - Boolean::enforce_equal(&mut cs, &a, &b).unwrap(); + let mut a = Boolean::from(AllocatedBit::alloc(cs.namespace(|| "a"), Some(a_bool)).unwrap()); + let mut b = Boolean::Constant(b_bool); - assert_eq!( - cs.is_satisfied(), - (a_bool ^ a_neg) == (b_bool ^ b_neg) - ); + if a_neg { + a = a.not(); + } + if b_neg { + b = b.not(); + } + + Boolean::enforce_equal(&mut cs, &a, &b).unwrap(); + + assert_eq!( + cs.is_satisfied(), + (a_bool ^ a_neg) == (b_bool ^ b_neg) + ); + } + { + let mut cs = TestConstraintSystem::::new(); + + let mut a = Boolean::Constant(a_bool); + let mut b = Boolean::Constant(b_bool); + + if a_neg { + a = a.not(); + } + if b_neg { + b = b.not(); + } + + let result = Boolean::enforce_equal(&mut cs, &a, &b); + + if (a_bool ^ a_neg) == (b_bool ^ b_neg) { + assert!(result.is_ok()); + assert!(cs.is_satisfied()); + } else { + assert!(result.is_err()); + } + } } } }