From 232f0a50b8c95a32fc2dc557179b5cd7f257cf45 Mon Sep 17 00:00:00 2001 From: Jack Grigg Date: Sat, 28 Mar 2020 12:02:32 +1300 Subject: [PATCH] ff: Rework BitIterator to work with both u8 and u64 limb sizes This enables BitIterator to be used with both the byte encoding and limb representation of scalars. --- bellman/src/gadgets/boolean.rs | 4 +- bellman/src/gadgets/num.rs | 8 ++-- ff/src/lib.rs | 46 ++++++++++++++++++++--- pairing/src/bls12_381/ec.rs | 27 +++++++++---- pairing/src/bls12_381/mod.rs | 4 +- zcash_primitives/src/jubjub/edwards.rs | 2 +- zcash_primitives/src/jubjub/fs.rs | 7 +--- zcash_primitives/src/jubjub/montgomery.rs | 2 +- zcash_primitives/src/sapling.rs | 4 +- zcash_proofs/src/circuit/ecc.rs | 4 +- zcash_proofs/src/circuit/sapling.rs | 8 ++-- 11 files changed, 80 insertions(+), 36 deletions(-) diff --git a/bellman/src/gadgets/boolean.rs b/bellman/src/gadgets/boolean.rs index d3c882d..f117681 100644 --- a/bellman/src/gadgets/boolean.rs +++ b/bellman/src/gadgets/boolean.rs @@ -313,12 +313,12 @@ pub fn field_into_allocated_bits_le, F: // Deconstruct in big-endian bit order let values = match value { Some(ref value) => { - let mut field_char = BitIterator::new(F::char()); + let mut field_char = BitIterator::::new(F::char()); let mut tmp = Vec::with_capacity(F::NUM_BITS as usize); let mut found_one = false; - for b in BitIterator::new(value.into_repr()) { + for b in BitIterator::::new(value.into_repr()) { // Skip leading bits found_one |= field_char.next().unwrap(); if !found_one { diff --git a/bellman/src/gadgets/num.rs b/bellman/src/gadgets/num.rs index e460d20..f8ce6d3 100644 --- a/bellman/src/gadgets/num.rs +++ b/bellman/src/gadgets/num.rs @@ -103,7 +103,9 @@ impl AllocatedNum { // We want to ensure that the bit representation of a is // less than or equal to r - 1. - let mut a = self.value.map(|e| BitIterator::new(e.into_repr())); + let mut a = self + .value + .map(|e| BitIterator::::new(e.into_repr())); let mut b = E::Fr::char(); b.sub_noborrow(&1.into()); @@ -115,7 +117,7 @@ impl AllocatedNum { let mut found_one = false; let mut i = 0; - for b in BitIterator::new(b) { + for b in BitIterator::::new(b) { let a_bit = a.as_mut().map(|e| e.next().unwrap()); // Skip over unset bits at the beginning @@ -558,7 +560,7 @@ mod test { assert!(cs.is_satisfied()); - for (b, a) in BitIterator::new(r.into_repr()) + for (b, a) in BitIterator::::new(r.into_repr()) .skip(1) .zip(bits.iter().rev()) { diff --git a/ff/src/lib.rs b/ff/src/lib.rs index 8bb8ffa..e91210f 100644 --- a/ff/src/lib.rs +++ b/ff/src/lib.rs @@ -13,6 +13,7 @@ extern crate std; pub use ff_derive::*; use core::fmt; +use core::marker::PhantomData; use core::ops::{Add, AddAssign, BitAnd, Mul, MulAssign, Neg, Shr, Sub, SubAssign}; use rand_core::RngCore; #[cfg(feature = "std")] @@ -338,20 +339,25 @@ pub trait ScalarEngine: Sized + 'static + Clone { } #[derive(Debug)] -pub struct BitIterator { +pub struct BitIterator> { t: E, n: usize, + _limb: PhantomData, } -impl> BitIterator { +impl> BitIterator { pub fn new(t: E) -> Self { let n = t.as_ref().len() * 64; - BitIterator { t, n } + BitIterator { + t, + n, + _limb: PhantomData::default(), + } } } -impl> Iterator for BitIterator { +impl> Iterator for BitIterator { type Item = bool; fn next(&mut self) -> Option { @@ -367,9 +373,37 @@ impl> Iterator for BitIterator { } } +impl> BitIterator { + pub fn new(t: E) -> Self { + let n = t.as_ref().len() * 8; + + BitIterator { + t, + n, + _limb: PhantomData::default(), + } + } +} + +impl> Iterator for BitIterator { + type Item = bool; + + fn next(&mut self) -> Option { + if self.n == 0 { + None + } else { + self.n -= 1; + let part = self.n / 8; + let bit = self.n - (8 * part); + + Some(self.t.as_ref()[part] & (1 << bit) > 0) + } + } +} + #[test] fn test_bit_iterator() { - let mut a = BitIterator::new([0xa953_d79b_83f6_ab59, 0x6dea_2059_e200_bd39]); + let mut a = BitIterator::::new([0xa953_d79b_83f6_ab59, 0x6dea_2059_e200_bd39]); let expected = "01101101111010100010000001011001111000100000000010111101001110011010100101010011110101111001101110000011111101101010101101011001"; for e in expected.chars() { @@ -380,7 +414,7 @@ fn test_bit_iterator() { let expected = "1010010101111110101010000101101011101000011101110101001000011001100100100011011010001011011011010001011011101100110100111011010010110001000011110100110001100110011101101000101100011100100100100100001010011101010111110011101011000011101000111011011101011001"; - let mut a = BitIterator::new([ + let mut a = BitIterator::::new([ 0x429d_5f3a_c3a3_b759, 0xb10f_4c66_768b_1c92, 0x9236_8b6d_16ec_d3b4, diff --git a/pairing/src/bls12_381/ec.rs b/pairing/src/bls12_381/ec.rs index 2dae1ea..42bd91e 100644 --- a/pairing/src/bls12_381/ec.rs +++ b/pairing/src/bls12_381/ec.rs @@ -81,7 +81,18 @@ macro_rules! curve_impl { } impl $affine { - fn mul_bits>(&self, bits: BitIterator) -> $projective { + fn mul_bits_u64>(&self, bits: BitIterator) -> $projective { + let mut res = $projective::zero(); + for i in bits { + res.double(); + if i { + res.add_assign(self) + } + } + res + } + + fn mul_bits_u8>(&self, bits: BitIterator) -> $projective { let mut res = $projective::zero(); for i in bits { res.double(); @@ -172,8 +183,8 @@ macro_rules! curve_impl { } fn mul::Repr>>(&self, by: S) -> $projective { - let bits = BitIterator::new(by.into()); - self.mul_bits(bits) + let bits = BitIterator::::new(by.into()); + self.mul_bits_u64(bits) } fn into_projective(&self) -> $projective { @@ -655,7 +666,7 @@ macro_rules! curve_impl { let mut found_one = false; - for i in BitIterator::new(other.into()) { + for i in BitIterator::::new(other.into()) { if found_one { res.double(); } else { @@ -992,8 +1003,8 @@ pub mod g1 { impl G1Affine { fn scale_by_cofactor(&self) -> G1 { // G1 cofactor = (x - 1)^2 / 3 = 76329603384216526031706109802092473003 - let cofactor = BitIterator::new([0x8c00aaab0000aaab, 0x396c8c005555e156]); - self.mul_bits(cofactor) + let cofactor = BitIterator::::new([0x8c00aaab0000aaab, 0x396c8c005555e156]); + self.mul_bits_u64(cofactor) } fn get_generator() -> Self { @@ -1714,7 +1725,7 @@ pub mod g2 { fn scale_by_cofactor(&self) -> G2 { // G2 cofactor = (x^8 - 4 x^7 + 5 x^6) - (4 x^4 + 6 x^3 - 4 x^2 - 4 x + 13) // 9 // 0x5d543a95414e7f1091d50792876a202cd91de4547085abaa68a205b2e5a7ddfa628f1cb4d9e82ef21537e293a6691ae1616ec6e786f0c70cf1c38e31c7238e5 - let cofactor = BitIterator::new([ + let cofactor = BitIterator::::new([ 0xcf1c38e31c7238e5, 0x1616ec6e786f0c70, 0x21537e293a6691ae, @@ -1724,7 +1735,7 @@ pub mod g2 { 0x91d50792876a202, 0x5d543a95414e7f1, ]); - self.mul_bits(cofactor) + self.mul_bits_u64(cofactor) } fn perform_pairing(&self, other: &G1Affine) -> Fq12 { diff --git a/pairing/src/bls12_381/mod.rs b/pairing/src/bls12_381/mod.rs index afa3aaf..0f18053 100644 --- a/pairing/src/bls12_381/mod.rs +++ b/pairing/src/bls12_381/mod.rs @@ -82,7 +82,7 @@ impl Engine for Bls12 { let mut f = Fq12::one(); let mut found_one = false; - for i in BitIterator::new(&[BLS_X >> 1]) { + for i in BitIterator::::new(&[BLS_X >> 1]) { if !found_one { found_one = i; continue; @@ -324,7 +324,7 @@ impl G2Prepared { let mut r: G2 = q.into(); let mut found_one = false; - for i in BitIterator::new([BLS_X >> 1]) { + for i in BitIterator::::new([BLS_X >> 1]) { if !found_one { found_one = i; continue; diff --git a/zcash_primitives/src/jubjub/edwards.rs b/zcash_primitives/src/jubjub/edwards.rs index 1b3ebc0..549d441 100644 --- a/zcash_primitives/src/jubjub/edwards.rs +++ b/zcash_primitives/src/jubjub/edwards.rs @@ -468,7 +468,7 @@ impl Point { let mut res = Self::zero(); - for b in BitIterator::new(scalar.into()) { + for b in BitIterator::::new(scalar.into()) { res = res.double(params); if b { diff --git a/zcash_primitives/src/jubjub/fs.rs b/zcash_primitives/src/jubjub/fs.rs index 466d4c5..e163adc 100644 --- a/zcash_primitives/src/jubjub/fs.rs +++ b/zcash_primitives/src/jubjub/fs.rs @@ -1,4 +1,3 @@ -use byteorder::{ByteOrder, LittleEndian}; use ff::{ adc, mac_with_carry, sbb, BitIterator, Field, PowVartime, PrimeField, PrimeFieldDecodingError, PrimeFieldRepr, SqrtField, @@ -721,7 +720,7 @@ impl Fs { self.reduce(); } - fn mul_bits>(&self, bits: BitIterator) -> Self { + fn mul_bits>(&self, bits: BitIterator) -> Self { let mut res = Self::zero(); for bit in bits { res = res.double(); @@ -741,9 +740,7 @@ impl ToUniform for Fs { /// Random Oracle output. fn to_uniform(digest: &[u8]) -> Self { assert_eq!(digest.len(), 64); - let mut repr: [u64; 8] = [0; 8]; - LittleEndian::read_u64_into(digest, &mut repr); - Self::one().mul_bits(BitIterator::new(repr)) + Self::one().mul_bits(BitIterator::::new(digest)) } } diff --git a/zcash_primitives/src/jubjub/montgomery.rs b/zcash_primitives/src/jubjub/montgomery.rs index 9cad803..efdb29d 100644 --- a/zcash_primitives/src/jubjub/montgomery.rs +++ b/zcash_primitives/src/jubjub/montgomery.rs @@ -304,7 +304,7 @@ impl Point { let mut res = Self::zero(); - for b in BitIterator::new(scalar.into()) { + for b in BitIterator::::new(scalar.into()) { res = res.double(params); if b { diff --git a/zcash_primitives/src/sapling.rs b/zcash_primitives/src/sapling.rs index da8b838..ecf5cd4 100644 --- a/zcash_primitives/src/sapling.rs +++ b/zcash_primitives/src/sapling.rs @@ -21,7 +21,7 @@ pub const SAPLING_COMMITMENT_TREE_DEPTH: usize = 32; pub fn merkle_hash(depth: usize, lhs: &FrRepr, rhs: &FrRepr) -> FrRepr { let lhs = { let mut tmp = [false; 256]; - for (a, b) in tmp.iter_mut().rev().zip(BitIterator::new(lhs)) { + for (a, b) in tmp.iter_mut().rev().zip(BitIterator::::new(lhs)) { *a = b; } tmp @@ -29,7 +29,7 @@ pub fn merkle_hash(depth: usize, lhs: &FrRepr, rhs: &FrRepr) -> FrRepr { let rhs = { let mut tmp = [false; 256]; - for (a, b) in tmp.iter_mut().rev().zip(BitIterator::new(rhs)) { + for (a, b) in tmp.iter_mut().rev().zip(BitIterator::::new(rhs)) { *a = b; } tmp diff --git a/zcash_proofs/src/circuit/ecc.rs b/zcash_proofs/src/circuit/ecc.rs index 05baf8b..01ed2d4 100644 --- a/zcash_proofs/src/circuit/ecc.rs +++ b/zcash_proofs/src/circuit/ecc.rs @@ -769,7 +769,7 @@ mod test { let q = p.mul(s, params); let (x1, y1) = q.to_xy(); - let mut s_bits = BitIterator::new(s.into_repr()).collect::>(); + let mut s_bits = BitIterator::::new(s.into_repr()).collect::>(); s_bits.reverse(); s_bits.truncate(Fs::NUM_BITS as usize); @@ -822,7 +822,7 @@ mod test { y: num_y0, }; - let mut s_bits = BitIterator::new(s.into_repr()).collect::>(); + let mut s_bits = BitIterator::::new(s.into_repr()).collect::>(); s_bits.reverse(); s_bits.truncate(Fs::NUM_BITS as usize); diff --git a/zcash_proofs/src/circuit/sapling.rs b/zcash_proofs/src/circuit/sapling.rs index 9782a4f..c3ddde9 100644 --- a/zcash_proofs/src/circuit/sapling.rs +++ b/zcash_proofs/src/circuit/sapling.rs @@ -615,8 +615,8 @@ fn test_input_circuit_with_bls12_381() { ::std::mem::swap(&mut lhs, &mut rhs); } - let mut lhs: Vec = BitIterator::new(lhs.into_repr()).collect(); - let mut rhs: Vec = BitIterator::new(rhs.into_repr()).collect(); + let mut lhs: Vec = BitIterator::::new(lhs.into_repr()).collect(); + let mut rhs: Vec = BitIterator::::new(rhs.into_repr()).collect(); lhs.reverse(); rhs.reverse(); @@ -799,8 +799,8 @@ fn test_input_circuit_with_bls12_381_external_test_vectors() { ::std::mem::swap(&mut lhs, &mut rhs); } - let mut lhs: Vec = BitIterator::new(lhs.into_repr()).collect(); - let mut rhs: Vec = BitIterator::new(rhs.into_repr()).collect(); + let mut lhs: Vec = BitIterator::::new(lhs.into_repr()).collect(); + let mut rhs: Vec = BitIterator::::new(rhs.into_repr()).collect(); lhs.reverse(); rhs.reverse();