diff --git a/ff_derive/src/lib.rs b/ff_derive/src/lib.rs index e507cd7..9621548 100644 --- a/ff_derive/src/lib.rs +++ b/ff_derive/src/lib.rs @@ -1,4 +1,4 @@ -#![recursion_limit="1024"] +#![recursion_limit = "1024"] extern crate proc_macro; extern crate syn; @@ -6,39 +6,38 @@ extern crate syn; extern crate quote; extern crate num_bigint; -extern crate num_traits; extern crate num_integer; +extern crate num_traits; -use num_integer::Integer; -use num_traits::{Zero, One, ToPrimitive}; use num_bigint::BigUint; +use num_integer::Integer; +use num_traits::{One, ToPrimitive, Zero}; use std::str::FromStr; #[proc_macro_derive(PrimeField, attributes(PrimeFieldModulus, PrimeFieldGenerator))] -pub fn prime_field( - input: proc_macro::TokenStream -) -> proc_macro::TokenStream -{ +pub fn prime_field(input: proc_macro::TokenStream) -> proc_macro::TokenStream { // Construct a string representation of the type definition let s = input.to_string(); - + // Parse the string representation let ast = syn::parse_derive_input(&s).unwrap(); // The struct we're deriving for is a wrapper around a "Repr" type we must construct. let repr_ident = fetch_wrapped_ident(&ast.body) - .expect("PrimeField derive only operates over tuple structs of a single item"); + .expect("PrimeField derive only operates over tuple structs of a single item"); // We're given the modulus p of the prime field let modulus: BigUint = fetch_attr("PrimeFieldModulus", &ast.attrs) - .expect("Please supply a PrimeFieldModulus attribute") - .parse().expect("PrimeFieldModulus should be a number"); + .expect("Please supply a PrimeFieldModulus attribute") + .parse() + .expect("PrimeFieldModulus should be a number"); // We may be provided with a generator of p - 1 order. It is required that this generator be quadratic // nonresidue. let generator: BigUint = fetch_attr("PrimeFieldGenerator", &ast.attrs) - .expect("Please supply a PrimeFieldGenerator attribute") - .parse().expect("PrimeFieldGenerator should be a number"); + .expect("Please supply a PrimeFieldGenerator attribute") + .parse() + .expect("PrimeFieldGenerator should be a number"); // The arithmetic in this library only works if the modulus*2 is smaller than the backing // representation. Compute the number of limbs we need. @@ -55,18 +54,21 @@ pub fn prime_field( let mut gen = quote::Tokens::new(); gen.append(prime_field_repr_impl(&repr_ident, limbs)); - gen.append(prime_field_constants_and_sqrt(&ast.ident, &repr_ident, modulus, limbs, generator)); + gen.append(prime_field_constants_and_sqrt( + &ast.ident, + &repr_ident, + modulus, + limbs, + generator, + )); gen.append(prime_field_impl(&ast.ident, &repr_ident, limbs)); - + // Return the generated impl gen.parse().unwrap() } /// Fetches the ident being wrapped by the type we're deriving. -fn fetch_wrapped_ident( - body: &syn::Body -) -> Option -{ +fn fetch_wrapped_ident(body: &syn::Body) -> Option { match body { &syn::Body::Struct(ref variant_data) => { let fields = variant_data.fields(); @@ -76,11 +78,11 @@ fn fetch_wrapped_ident( if path.segments.len() == 1 { return Some(path.segments[0].ident.clone()); } - }, + } _ => {} } } - }, + } _ => {} }; @@ -88,22 +90,14 @@ fn fetch_wrapped_ident( } /// Fetch an attribute string from the derived struct. -fn fetch_attr( - name: &str, - attrs: &[syn::Attribute] -) -> Option -{ +fn fetch_attr(name: &str, attrs: &[syn::Attribute]) -> Option { for attr in attrs { if attr.name() == name { match attr.value { - syn::MetaItem::NameValue(_, ref val) => { - match val { - &syn::Lit::Str(ref s, _) => { - return Some(s.clone()) - }, - _ => { - panic!("attribute {} should be a string", name); - } + syn::MetaItem::NameValue(_, ref val) => match val { + &syn::Lit::Str(ref s, _) => return Some(s.clone()), + _ => { + panic!("attribute {} should be a string", name); } }, _ => { @@ -117,11 +111,7 @@ fn fetch_attr( } // Implement PrimeFieldRepr for the wrapped ident `repr` with `limbs` limbs. -fn prime_field_repr_impl( - repr: &syn::Ident, - limbs: usize -) -> quote::Tokens -{ +fn prime_field_repr_impl(repr: &syn::Ident, limbs: usize) -> quote::Tokens { quote! { #[derive(Copy, Clone, PartialEq, Eq, Default)] pub struct #repr(pub [u64; #limbs]); @@ -263,11 +253,7 @@ fn prime_field_repr_impl( } /// Convert BigUint into a vector of 64-bit limbs. -fn biguint_to_u64_vec( - mut v: BigUint, - limbs: usize -) -> Vec -{ +fn biguint_to_u64_vec(mut v: BigUint, limbs: usize) -> Vec { let m = BigUint::one() << 64; let mut ret = vec![]; @@ -285,10 +271,7 @@ fn biguint_to_u64_vec( ret } -fn biguint_num_bits( - mut v: BigUint -) -> u32 -{ +fn biguint_num_bits(mut v: BigUint) -> u32 { let mut bits = 0; while v != BigUint::zero() { @@ -300,17 +283,12 @@ fn biguint_num_bits( } /// BigUint modular exponentiation by square-and-multiply. -fn exp( - base: BigUint, - exp: &BigUint, - modulus: &BigUint -) -> BigUint -{ +fn exp(base: BigUint, exp: &BigUint, modulus: &BigUint) -> BigUint { let mut ret = BigUint::one(); for i in exp.to_bytes_be() - .into_iter() - .flat_map(|x| (0..8).rev().map(move |i| (x >> i).is_odd())) + .into_iter() + .flat_map(|x| (0..8).rev().map(move |i| (x >> i).is_odd())) { ret = (&ret * &ret) % modulus; if i { @@ -327,9 +305,13 @@ fn test_exp() { exp( BigUint::from_str("4398572349857239485729348572983472345").unwrap(), &BigUint::from_str("5489673498567349856734895").unwrap(), - &BigUint::from_str("52435875175126190479447740508185965837690552500527637822603658699938581184513").unwrap() + &BigUint::from_str( + "52435875175126190479447740508185965837690552500527637822603658699938581184513" + ).unwrap() ), - BigUint::from_str("4371221214068404307866768905142520595925044802278091865033317963560480051536").unwrap() + BigUint::from_str( + "4371221214068404307866768905142520595925044802278091865033317963560480051536" + ).unwrap() ); } @@ -338,14 +320,13 @@ fn prime_field_constants_and_sqrt( repr: &syn::Ident, modulus: BigUint, limbs: usize, - generator: BigUint -) -> quote::Tokens -{ + generator: BigUint, +) -> quote::Tokens { let modulus_num_bits = biguint_num_bits(modulus.clone()); // The number of bits we should "shave" from a randomly sampled reputation, i.e., // if our modulus is 381 bits and our representation is 384 bits, we should shave - // 3 bits from the beginning of a randomly sampled 384 bit representation to + // 3 bits from the beginning of a randomly sampled 384 bit representation to // reduce the cost of rejection sampling. let repr_shave_bits = (64 * limbs as u32) - biguint_num_bits(modulus.clone()); @@ -361,91 +342,96 @@ fn prime_field_constants_and_sqrt( } // Compute 2^s root of unity given the generator - let root_of_unity = biguint_to_u64_vec((exp(generator.clone(), &t, &modulus) * &r) % &modulus, limbs); + let root_of_unity = biguint_to_u64_vec( + (exp(generator.clone(), &t, &modulus) * &r) % &modulus, + limbs, + ); let generator = biguint_to_u64_vec((generator.clone() * &r) % &modulus, limbs); let sqrt_impl = - if (&modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() { - let mod_minus_3_over_4 = biguint_to_u64_vec((&modulus - BigUint::from_str("3").unwrap()) >> 2, limbs); + if (&modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() { + let mod_minus_3_over_4 = + biguint_to_u64_vec((&modulus - BigUint::from_str("3").unwrap()) >> 2, limbs); - // Compute -R as (m - r) - let rneg = biguint_to_u64_vec(&modulus - &r, limbs); + // Compute -R as (m - r) + let rneg = biguint_to_u64_vec(&modulus - &r, limbs); - quote!{ - impl ::ff::SqrtField for #name { - fn sqrt(&self) -> Option { - // Shank's algorithm for q mod 4 = 3 - // https://eprint.iacr.org/2012/685.pdf (page 9, algorithm 2) + quote!{ + impl ::ff::SqrtField for #name { + fn sqrt(&self) -> Option { + // Shank's algorithm for q mod 4 = 3 + // https://eprint.iacr.org/2012/685.pdf (page 9, algorithm 2) - let mut a1 = self.pow(#mod_minus_3_over_4); + let mut a1 = self.pow(#mod_minus_3_over_4); - let mut a0 = a1; - a0.square(); - a0.mul_assign(self); + let mut a0 = a1; + a0.square(); + a0.mul_assign(self); - if a0.0 == #repr(#rneg) { - None - } else { - a1.mul_assign(self); - Some(a1) + if a0.0 == #repr(#rneg) { + None + } else { + a1.mul_assign(self); + Some(a1) + } } } } - } - } else if (&modulus % BigUint::from_str("16").unwrap()) == BigUint::from_str("1").unwrap() { - let mod_minus_1_over_2 = biguint_to_u64_vec((&modulus - BigUint::from_str("1").unwrap()) >> 1, limbs); - let t_plus_1_over_2 = biguint_to_u64_vec((&t + BigUint::one()) >> 1, limbs); - let t = biguint_to_u64_vec(t.clone(), limbs); + } else if (&modulus % BigUint::from_str("16").unwrap()) == BigUint::from_str("1").unwrap() { + let mod_minus_1_over_2 = + biguint_to_u64_vec((&modulus - BigUint::from_str("1").unwrap()) >> 1, limbs); + let t_plus_1_over_2 = biguint_to_u64_vec((&t + BigUint::one()) >> 1, limbs); + let t = biguint_to_u64_vec(t.clone(), limbs); - quote!{ - impl ::ff::SqrtField for #name { - fn sqrt(&self) -> Option { - // Tonelli-Shank's algorithm for q mod 16 = 1 - // https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5) + quote!{ + impl ::ff::SqrtField for #name { + fn sqrt(&self) -> Option { + // Tonelli-Shank's algorithm for q mod 16 = 1 + // https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5) - if self.is_zero() { - return Some(*self); - } - - if self.pow(#mod_minus_1_over_2) != Self::one() { - None - } else { - let mut c = #name(#repr(#root_of_unity)); - let mut r = self.pow(#t_plus_1_over_2); - let mut t = self.pow(#t); - let mut m = #s; - - while t != Self::one() { - let mut i = 1; - { - let mut t2i = t; - t2i.square(); - loop { - if t2i == Self::one() { - break; - } - t2i.square(); - i += 1; - } - } - - for _ in 0..(m - i - 1) { - c.square(); - } - r.mul_assign(&c); - c.square(); - t.mul_assign(&c); - m = i; + if self.is_zero() { + return Some(*self); } - Some(r) + if self.pow(#mod_minus_1_over_2) != Self::one() { + None + } else { + let mut c = #name(#repr(#root_of_unity)); + let mut r = self.pow(#t_plus_1_over_2); + let mut t = self.pow(#t); + let mut m = #s; + + while t != Self::one() { + let mut i = 1; + { + let mut t2i = t; + t2i.square(); + loop { + if t2i == Self::one() { + break; + } + t2i.square(); + i += 1; + } + } + + for _ in 0..(m - i - 1) { + c.square(); + } + r.mul_assign(&c); + c.square(); + t.mul_assign(&c); + m = i; + } + + Some(r) + } } } } - } - } else { - quote!{} - }; + } else { + quote!{} + }; // Compute R^2 mod m let r2 = biguint_to_u64_vec((&r * &r) % &modulus, limbs); @@ -496,12 +482,7 @@ fn prime_field_constants_and_sqrt( } /// Implement PrimeField for the derived type. -fn prime_field_impl( - name: &syn::Ident, - repr: &syn::Ident, - limbs: usize -) -> quote::Tokens -{ +fn prime_field_impl(name: &syn::Ident, repr: &syn::Ident, limbs: usize) -> quote::Tokens { // Returns r{n} as an ident. fn get_temp(n: usize) -> syn::Ident { syn::Ident::from(format!("r{}", n)) @@ -511,20 +492,18 @@ fn prime_field_impl( // r0: u64, mut r1: u64, mut r2: u64, ... let mut mont_paramlist = quote::Tokens::new(); mont_paramlist.append_separated( - (0..(limbs*2)).map(|i| (i, get_temp(i))) - .map(|(i, x)| { - if i != 0 { - quote!{mut #x: u64} - } else { - quote!{#x: u64} - } - }), - "," + (0..(limbs * 2)).map(|i| (i, get_temp(i))).map(|(i, x)| { + if i != 0 { + quote!{mut #x: u64} + } else { + quote!{#x: u64} + } + }), + ",", ); // Implement montgomery reduction for some number of limbs - fn mont_impl(limbs: usize) -> quote::Tokens - { + fn mont_impl(limbs: usize) -> quote::Tokens { let mut gen = quote::Tokens::new(); for i in 0..limbs { @@ -574,16 +553,15 @@ fn prime_field_impl( gen } - fn sqr_impl(a: quote::Tokens, limbs: usize) -> quote::Tokens - { + fn sqr_impl(a: quote::Tokens, limbs: usize) -> quote::Tokens { let mut gen = quote::Tokens::new(); - for i in 0..(limbs-1) { + for i in 0..(limbs - 1) { gen.append(quote!{ let mut carry = 0; }); - for j in (i+1)..limbs { + for j in (i + 1)..limbs { let temp = get_temp(i + j); if i == 0 { gen.append(quote!{ @@ -603,7 +581,7 @@ fn prime_field_impl( }); } - for i in 1..(limbs*2) { + for i in 1..(limbs * 2) { let k = get_temp(i); if i == 1 { @@ -611,7 +589,7 @@ fn prime_field_impl( let tmp0 = #k >> 63; let #k = #k << 1; }); - } else if i == (limbs*2 - 1) { + } else if i == (limbs * 2 - 1) { gen.append(quote!{ let #k = tmp0; }); @@ -648,7 +626,7 @@ fn prime_field_impl( } let mut mont_calling = quote::Tokens::new(); - mont_calling.append_separated((0..(limbs*2)).map(|i| get_temp(i)), ","); + mont_calling.append_separated((0..(limbs * 2)).map(|i| get_temp(i)), ","); gen.append(quote!{ self.mont_reduce(#mont_calling); @@ -657,8 +635,7 @@ fn prime_field_impl( gen } - fn mul_impl(a: quote::Tokens, b: quote::Tokens, limbs: usize) -> quote::Tokens - { + fn mul_impl(a: quote::Tokens, b: quote::Tokens, limbs: usize) -> quote::Tokens { let mut gen = quote::Tokens::new(); for i in 0..limbs { @@ -688,7 +665,7 @@ fn prime_field_impl( } let mut mont_calling = quote::Tokens::new(); - mont_calling.append_separated((0..(limbs*2)).map(|i| get_temp(i)), ","); + mont_calling.append_separated((0..(limbs * 2)).map(|i| get_temp(i)), ","); gen.append(quote!{ self.mont_reduce(#mont_calling); @@ -704,9 +681,10 @@ fn prime_field_impl( // (self.0).0[0], (self.0).0[1], ..., 0, 0, 0, 0, ... let mut into_repr_params = quote::Tokens::new(); into_repr_params.append_separated( - (0..limbs).map(|i| quote!{ (self.0).0[#i] }) - .chain((0..limbs).map(|_| quote!{0})), - "," + (0..limbs) + .map(|i| quote!{ (self.0).0[#i] }) + .chain((0..limbs).map(|_| quote!{0})), + ",", ); quote!{ diff --git a/src/lib.rs b/src/lib.rs index 99a7e7c..88696ac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,15 +11,8 @@ pub use ff_derive::*; use std::fmt; /// This trait represents an element of a field. -pub trait Field: Sized + - Eq + - Copy + - Clone + - Send + - Sync + - fmt::Debug + - 'static + - rand::Rand +pub trait Field: + Sized + Eq + Copy + Clone + Send + Sync + fmt::Debug + 'static + rand::Rand { /// Returns the zero element of the field, the additive identity. fn zero() -> Self; @@ -57,8 +50,7 @@ pub trait Field: Sized + /// Exponentiates this element by a number represented with `u64` limbs, /// least significant digit first. - fn pow>(&self, exp: S) -> Self - { + fn pow>(&self, exp: S) -> Self { let mut res = Self::one(); for i in BitIterator::new(exp) { @@ -73,8 +65,7 @@ pub trait Field: Sized + } /// This trait represents an element of a field that has a square root operation described for it. -pub trait SqrtField: Field -{ +pub trait SqrtField: Field { /// Returns the square root of the field element, if it is /// quadratic residue. fn sqrt(&self) -> Option; @@ -83,18 +74,19 @@ pub trait SqrtField: Field /// This trait represents a wrapper around a biginteger which can encode any element of a particular /// prime field. It is a smart wrapper around a sequence of `u64` limbs, least-significant digit /// first. -pub trait PrimeFieldRepr: Sized + - Copy + - Clone + - Eq + - Ord + - Send + - Sync + - fmt::Debug + - 'static + - rand::Rand + - AsRef<[u64]> + - From +pub trait PrimeFieldRepr: + Sized + + Copy + + Clone + + Eq + + Ord + + Send + + Sync + + fmt::Debug + + 'static + + rand::Rand + + AsRef<[u64]> + + From { /// Subtract another reprensetation from this one, returning the borrow bit. fn sub_noborrow(&mut self, other: &Self) -> bool; @@ -124,8 +116,7 @@ pub trait PrimeFieldRepr: Sized + } /// This represents an element of a prime field. -pub trait PrimeField: Field -{ +pub trait PrimeField: Field { /// The prime field can be converted back and forth into this biginteger /// representation. type Repr: PrimeFieldRepr; @@ -162,17 +153,14 @@ pub trait PrimeField: Field pub struct BitIterator { t: E, - n: usize + n: usize, } impl> BitIterator { fn new(t: E) -> Self { let n = t.as_ref().len() * 64; - BitIterator { - t: t, - n: n - } + BitIterator { t: t, n: n } } } @@ -205,7 +193,12 @@ fn test_bit_iterator() { let expected = "1010010101111110101010000101101011101000011101110101001000011001100100100011011010001011011011010001011011101100110100111011010010110001000011110100110001100110011101101000101100011100100100100100001010011101010111110011101011000011101000111011011101011001"; - let mut a = BitIterator::new([0x429d5f3ac3a3b759, 0xb10f4c66768b1c92, 0x92368b6d16ecd3b4, 0xa57ea85ae8775219]); + let mut a = BitIterator::new([ + 0x429d5f3ac3a3b759, + 0xb10f4c66768b1c92, + 0x92368b6d16ecd3b4, + 0xa57ea85ae8775219, + ]); for e in expected.chars() { assert!(a.next().unwrap() == (e == '1'));