cargo fmt

This commit is contained in:
Jack Grigg
2018-06-26 10:48:27 -04:00
committed by str4d
parent c7252a43bf
commit 755fc7aba8
2 changed files with 164 additions and 193 deletions

View File

@@ -1,4 +1,4 @@
#![recursion_limit="1024"]
#![recursion_limit = "1024"]
extern crate proc_macro;
extern crate syn;
@@ -6,39 +6,38 @@ extern crate syn;
extern crate quote;
extern crate num_bigint;
extern crate num_traits;
extern crate num_integer;
extern crate num_traits;
use num_integer::Integer;
use num_traits::{Zero, One, ToPrimitive};
use num_bigint::BigUint;
use num_integer::Integer;
use num_traits::{One, ToPrimitive, Zero};
use std::str::FromStr;
#[proc_macro_derive(PrimeField, attributes(PrimeFieldModulus, PrimeFieldGenerator))]
pub fn prime_field(
input: proc_macro::TokenStream
) -> proc_macro::TokenStream
{
pub fn prime_field(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
// Construct a string representation of the type definition
let s = input.to_string();
// Parse the string representation
let ast = syn::parse_derive_input(&s).unwrap();
// The struct we're deriving for is a wrapper around a "Repr" type we must construct.
let repr_ident = fetch_wrapped_ident(&ast.body)
.expect("PrimeField derive only operates over tuple structs of a single item");
.expect("PrimeField derive only operates over tuple structs of a single item");
// We're given the modulus p of the prime field
let modulus: BigUint = fetch_attr("PrimeFieldModulus", &ast.attrs)
.expect("Please supply a PrimeFieldModulus attribute")
.parse().expect("PrimeFieldModulus should be a number");
.expect("Please supply a PrimeFieldModulus attribute")
.parse()
.expect("PrimeFieldModulus should be a number");
// We may be provided with a generator of p - 1 order. It is required that this generator be quadratic
// nonresidue.
let generator: BigUint = fetch_attr("PrimeFieldGenerator", &ast.attrs)
.expect("Please supply a PrimeFieldGenerator attribute")
.parse().expect("PrimeFieldGenerator should be a number");
.expect("Please supply a PrimeFieldGenerator attribute")
.parse()
.expect("PrimeFieldGenerator should be a number");
// The arithmetic in this library only works if the modulus*2 is smaller than the backing
// representation. Compute the number of limbs we need.
@@ -55,18 +54,21 @@ pub fn prime_field(
let mut gen = quote::Tokens::new();
gen.append(prime_field_repr_impl(&repr_ident, limbs));
gen.append(prime_field_constants_and_sqrt(&ast.ident, &repr_ident, modulus, limbs, generator));
gen.append(prime_field_constants_and_sqrt(
&ast.ident,
&repr_ident,
modulus,
limbs,
generator,
));
gen.append(prime_field_impl(&ast.ident, &repr_ident, limbs));
// Return the generated impl
gen.parse().unwrap()
}
/// Fetches the ident being wrapped by the type we're deriving.
fn fetch_wrapped_ident(
body: &syn::Body
) -> Option<syn::Ident>
{
fn fetch_wrapped_ident(body: &syn::Body) -> Option<syn::Ident> {
match body {
&syn::Body::Struct(ref variant_data) => {
let fields = variant_data.fields();
@@ -76,11 +78,11 @@ fn fetch_wrapped_ident(
if path.segments.len() == 1 {
return Some(path.segments[0].ident.clone());
}
},
}
_ => {}
}
}
},
}
_ => {}
};
@@ -88,22 +90,14 @@ fn fetch_wrapped_ident(
}
/// Fetch an attribute string from the derived struct.
fn fetch_attr(
name: &str,
attrs: &[syn::Attribute]
) -> Option<String>
{
fn fetch_attr(name: &str, attrs: &[syn::Attribute]) -> Option<String> {
for attr in attrs {
if attr.name() == name {
match attr.value {
syn::MetaItem::NameValue(_, ref val) => {
match val {
&syn::Lit::Str(ref s, _) => {
return Some(s.clone())
},
_ => {
panic!("attribute {} should be a string", name);
}
syn::MetaItem::NameValue(_, ref val) => match val {
&syn::Lit::Str(ref s, _) => return Some(s.clone()),
_ => {
panic!("attribute {} should be a string", name);
}
},
_ => {
@@ -117,11 +111,7 @@ fn fetch_attr(
}
// Implement PrimeFieldRepr for the wrapped ident `repr` with `limbs` limbs.
fn prime_field_repr_impl(
repr: &syn::Ident,
limbs: usize
) -> quote::Tokens
{
fn prime_field_repr_impl(repr: &syn::Ident, limbs: usize) -> quote::Tokens {
quote! {
#[derive(Copy, Clone, PartialEq, Eq, Default)]
pub struct #repr(pub [u64; #limbs]);
@@ -263,11 +253,7 @@ fn prime_field_repr_impl(
}
/// Convert BigUint into a vector of 64-bit limbs.
fn biguint_to_u64_vec(
mut v: BigUint,
limbs: usize
) -> Vec<u64>
{
fn biguint_to_u64_vec(mut v: BigUint, limbs: usize) -> Vec<u64> {
let m = BigUint::one() << 64;
let mut ret = vec![];
@@ -285,10 +271,7 @@ fn biguint_to_u64_vec(
ret
}
fn biguint_num_bits(
mut v: BigUint
) -> u32
{
fn biguint_num_bits(mut v: BigUint) -> u32 {
let mut bits = 0;
while v != BigUint::zero() {
@@ -300,17 +283,12 @@ fn biguint_num_bits(
}
/// BigUint modular exponentiation by square-and-multiply.
fn exp(
base: BigUint,
exp: &BigUint,
modulus: &BigUint
) -> BigUint
{
fn exp(base: BigUint, exp: &BigUint, modulus: &BigUint) -> BigUint {
let mut ret = BigUint::one();
for i in exp.to_bytes_be()
.into_iter()
.flat_map(|x| (0..8).rev().map(move |i| (x >> i).is_odd()))
.into_iter()
.flat_map(|x| (0..8).rev().map(move |i| (x >> i).is_odd()))
{
ret = (&ret * &ret) % modulus;
if i {
@@ -327,9 +305,13 @@ fn test_exp() {
exp(
BigUint::from_str("4398572349857239485729348572983472345").unwrap(),
&BigUint::from_str("5489673498567349856734895").unwrap(),
&BigUint::from_str("52435875175126190479447740508185965837690552500527637822603658699938581184513").unwrap()
&BigUint::from_str(
"52435875175126190479447740508185965837690552500527637822603658699938581184513"
).unwrap()
),
BigUint::from_str("4371221214068404307866768905142520595925044802278091865033317963560480051536").unwrap()
BigUint::from_str(
"4371221214068404307866768905142520595925044802278091865033317963560480051536"
).unwrap()
);
}
@@ -338,14 +320,13 @@ fn prime_field_constants_and_sqrt(
repr: &syn::Ident,
modulus: BigUint,
limbs: usize,
generator: BigUint
) -> quote::Tokens
{
generator: BigUint,
) -> quote::Tokens {
let modulus_num_bits = biguint_num_bits(modulus.clone());
// The number of bits we should "shave" from a randomly sampled reputation, i.e.,
// if our modulus is 381 bits and our representation is 384 bits, we should shave
// 3 bits from the beginning of a randomly sampled 384 bit representation to
// 3 bits from the beginning of a randomly sampled 384 bit representation to
// reduce the cost of rejection sampling.
let repr_shave_bits = (64 * limbs as u32) - biguint_num_bits(modulus.clone());
@@ -361,91 +342,96 @@ fn prime_field_constants_and_sqrt(
}
// Compute 2^s root of unity given the generator
let root_of_unity = biguint_to_u64_vec((exp(generator.clone(), &t, &modulus) * &r) % &modulus, limbs);
let root_of_unity = biguint_to_u64_vec(
(exp(generator.clone(), &t, &modulus) * &r) % &modulus,
limbs,
);
let generator = biguint_to_u64_vec((generator.clone() * &r) % &modulus, limbs);
let sqrt_impl =
if (&modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() {
let mod_minus_3_over_4 = biguint_to_u64_vec((&modulus - BigUint::from_str("3").unwrap()) >> 2, limbs);
if (&modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() {
let mod_minus_3_over_4 =
biguint_to_u64_vec((&modulus - BigUint::from_str("3").unwrap()) >> 2, limbs);
// Compute -R as (m - r)
let rneg = biguint_to_u64_vec(&modulus - &r, limbs);
// Compute -R as (m - r)
let rneg = biguint_to_u64_vec(&modulus - &r, limbs);
quote!{
impl ::ff::SqrtField for #name {
fn sqrt(&self) -> Option<Self> {
// Shank's algorithm for q mod 4 = 3
// https://eprint.iacr.org/2012/685.pdf (page 9, algorithm 2)
quote!{
impl ::ff::SqrtField for #name {
fn sqrt(&self) -> Option<Self> {
// Shank's algorithm for q mod 4 = 3
// https://eprint.iacr.org/2012/685.pdf (page 9, algorithm 2)
let mut a1 = self.pow(#mod_minus_3_over_4);
let mut a1 = self.pow(#mod_minus_3_over_4);
let mut a0 = a1;
a0.square();
a0.mul_assign(self);
let mut a0 = a1;
a0.square();
a0.mul_assign(self);
if a0.0 == #repr(#rneg) {
None
} else {
a1.mul_assign(self);
Some(a1)
if a0.0 == #repr(#rneg) {
None
} else {
a1.mul_assign(self);
Some(a1)
}
}
}
}
}
} else if (&modulus % BigUint::from_str("16").unwrap()) == BigUint::from_str("1").unwrap() {
let mod_minus_1_over_2 = biguint_to_u64_vec((&modulus - BigUint::from_str("1").unwrap()) >> 1, limbs);
let t_plus_1_over_2 = biguint_to_u64_vec((&t + BigUint::one()) >> 1, limbs);
let t = biguint_to_u64_vec(t.clone(), limbs);
} else if (&modulus % BigUint::from_str("16").unwrap()) == BigUint::from_str("1").unwrap() {
let mod_minus_1_over_2 =
biguint_to_u64_vec((&modulus - BigUint::from_str("1").unwrap()) >> 1, limbs);
let t_plus_1_over_2 = biguint_to_u64_vec((&t + BigUint::one()) >> 1, limbs);
let t = biguint_to_u64_vec(t.clone(), limbs);
quote!{
impl ::ff::SqrtField for #name {
fn sqrt(&self) -> Option<Self> {
// Tonelli-Shank's algorithm for q mod 16 = 1
// https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5)
quote!{
impl ::ff::SqrtField for #name {
fn sqrt(&self) -> Option<Self> {
// Tonelli-Shank's algorithm for q mod 16 = 1
// https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5)
if self.is_zero() {
return Some(*self);
}
if self.pow(#mod_minus_1_over_2) != Self::one() {
None
} else {
let mut c = #name(#repr(#root_of_unity));
let mut r = self.pow(#t_plus_1_over_2);
let mut t = self.pow(#t);
let mut m = #s;
while t != Self::one() {
let mut i = 1;
{
let mut t2i = t;
t2i.square();
loop {
if t2i == Self::one() {
break;
}
t2i.square();
i += 1;
}
}
for _ in 0..(m - i - 1) {
c.square();
}
r.mul_assign(&c);
c.square();
t.mul_assign(&c);
m = i;
if self.is_zero() {
return Some(*self);
}
Some(r)
if self.pow(#mod_minus_1_over_2) != Self::one() {
None
} else {
let mut c = #name(#repr(#root_of_unity));
let mut r = self.pow(#t_plus_1_over_2);
let mut t = self.pow(#t);
let mut m = #s;
while t != Self::one() {
let mut i = 1;
{
let mut t2i = t;
t2i.square();
loop {
if t2i == Self::one() {
break;
}
t2i.square();
i += 1;
}
}
for _ in 0..(m - i - 1) {
c.square();
}
r.mul_assign(&c);
c.square();
t.mul_assign(&c);
m = i;
}
Some(r)
}
}
}
}
}
} else {
quote!{}
};
} else {
quote!{}
};
// Compute R^2 mod m
let r2 = biguint_to_u64_vec((&r * &r) % &modulus, limbs);
@@ -496,12 +482,7 @@ fn prime_field_constants_and_sqrt(
}
/// Implement PrimeField for the derived type.
fn prime_field_impl(
name: &syn::Ident,
repr: &syn::Ident,
limbs: usize
) -> quote::Tokens
{
fn prime_field_impl(name: &syn::Ident, repr: &syn::Ident, limbs: usize) -> quote::Tokens {
// Returns r{n} as an ident.
fn get_temp(n: usize) -> syn::Ident {
syn::Ident::from(format!("r{}", n))
@@ -511,20 +492,18 @@ fn prime_field_impl(
// r0: u64, mut r1: u64, mut r2: u64, ...
let mut mont_paramlist = quote::Tokens::new();
mont_paramlist.append_separated(
(0..(limbs*2)).map(|i| (i, get_temp(i)))
.map(|(i, x)| {
if i != 0 {
quote!{mut #x: u64}
} else {
quote!{#x: u64}
}
}),
","
(0..(limbs * 2)).map(|i| (i, get_temp(i))).map(|(i, x)| {
if i != 0 {
quote!{mut #x: u64}
} else {
quote!{#x: u64}
}
}),
",",
);
// Implement montgomery reduction for some number of limbs
fn mont_impl(limbs: usize) -> quote::Tokens
{
fn mont_impl(limbs: usize) -> quote::Tokens {
let mut gen = quote::Tokens::new();
for i in 0..limbs {
@@ -574,16 +553,15 @@ fn prime_field_impl(
gen
}
fn sqr_impl(a: quote::Tokens, limbs: usize) -> quote::Tokens
{
fn sqr_impl(a: quote::Tokens, limbs: usize) -> quote::Tokens {
let mut gen = quote::Tokens::new();
for i in 0..(limbs-1) {
for i in 0..(limbs - 1) {
gen.append(quote!{
let mut carry = 0;
});
for j in (i+1)..limbs {
for j in (i + 1)..limbs {
let temp = get_temp(i + j);
if i == 0 {
gen.append(quote!{
@@ -603,7 +581,7 @@ fn prime_field_impl(
});
}
for i in 1..(limbs*2) {
for i in 1..(limbs * 2) {
let k = get_temp(i);
if i == 1 {
@@ -611,7 +589,7 @@ fn prime_field_impl(
let tmp0 = #k >> 63;
let #k = #k << 1;
});
} else if i == (limbs*2 - 1) {
} else if i == (limbs * 2 - 1) {
gen.append(quote!{
let #k = tmp0;
});
@@ -648,7 +626,7 @@ fn prime_field_impl(
}
let mut mont_calling = quote::Tokens::new();
mont_calling.append_separated((0..(limbs*2)).map(|i| get_temp(i)), ",");
mont_calling.append_separated((0..(limbs * 2)).map(|i| get_temp(i)), ",");
gen.append(quote!{
self.mont_reduce(#mont_calling);
@@ -657,8 +635,7 @@ fn prime_field_impl(
gen
}
fn mul_impl(a: quote::Tokens, b: quote::Tokens, limbs: usize) -> quote::Tokens
{
fn mul_impl(a: quote::Tokens, b: quote::Tokens, limbs: usize) -> quote::Tokens {
let mut gen = quote::Tokens::new();
for i in 0..limbs {
@@ -688,7 +665,7 @@ fn prime_field_impl(
}
let mut mont_calling = quote::Tokens::new();
mont_calling.append_separated((0..(limbs*2)).map(|i| get_temp(i)), ",");
mont_calling.append_separated((0..(limbs * 2)).map(|i| get_temp(i)), ",");
gen.append(quote!{
self.mont_reduce(#mont_calling);
@@ -704,9 +681,10 @@ fn prime_field_impl(
// (self.0).0[0], (self.0).0[1], ..., 0, 0, 0, 0, ...
let mut into_repr_params = quote::Tokens::new();
into_repr_params.append_separated(
(0..limbs).map(|i| quote!{ (self.0).0[#i] })
.chain((0..limbs).map(|_| quote!{0})),
","
(0..limbs)
.map(|i| quote!{ (self.0).0[#i] })
.chain((0..limbs).map(|_| quote!{0})),
",",
);
quote!{

View File

@@ -11,15 +11,8 @@ pub use ff_derive::*;
use std::fmt;
/// This trait represents an element of a field.
pub trait Field: Sized +
Eq +
Copy +
Clone +
Send +
Sync +
fmt::Debug +
'static +
rand::Rand
pub trait Field:
Sized + Eq + Copy + Clone + Send + Sync + fmt::Debug + 'static + rand::Rand
{
/// Returns the zero element of the field, the additive identity.
fn zero() -> Self;
@@ -57,8 +50,7 @@ pub trait Field: Sized +
/// Exponentiates this element by a number represented with `u64` limbs,
/// least significant digit first.
fn pow<S: AsRef<[u64]>>(&self, exp: S) -> Self
{
fn pow<S: AsRef<[u64]>>(&self, exp: S) -> Self {
let mut res = Self::one();
for i in BitIterator::new(exp) {
@@ -73,8 +65,7 @@ pub trait Field: Sized +
}
/// This trait represents an element of a field that has a square root operation described for it.
pub trait SqrtField: Field
{
pub trait SqrtField: Field {
/// Returns the square root of the field element, if it is
/// quadratic residue.
fn sqrt(&self) -> Option<Self>;
@@ -83,18 +74,19 @@ pub trait SqrtField: Field
/// This trait represents a wrapper around a biginteger which can encode any element of a particular
/// prime field. It is a smart wrapper around a sequence of `u64` limbs, least-significant digit
/// first.
pub trait PrimeFieldRepr: Sized +
Copy +
Clone +
Eq +
Ord +
Send +
Sync +
fmt::Debug +
'static +
rand::Rand +
AsRef<[u64]> +
From<u64>
pub trait PrimeFieldRepr:
Sized
+ Copy
+ Clone
+ Eq
+ Ord
+ Send
+ Sync
+ fmt::Debug
+ 'static
+ rand::Rand
+ AsRef<[u64]>
+ From<u64>
{
/// Subtract another reprensetation from this one, returning the borrow bit.
fn sub_noborrow(&mut self, other: &Self) -> bool;
@@ -124,8 +116,7 @@ pub trait PrimeFieldRepr: Sized +
}
/// This represents an element of a prime field.
pub trait PrimeField: Field
{
pub trait PrimeField: Field {
/// The prime field can be converted back and forth into this biginteger
/// representation.
type Repr: PrimeFieldRepr;
@@ -162,17 +153,14 @@ pub trait PrimeField: Field
pub struct BitIterator<E> {
t: E,
n: usize
n: usize,
}
impl<E: AsRef<[u64]>> BitIterator<E> {
fn new(t: E) -> Self {
let n = t.as_ref().len() * 64;
BitIterator {
t: t,
n: n
}
BitIterator { t: t, n: n }
}
}
@@ -205,7 +193,12 @@ fn test_bit_iterator() {
let expected = "1010010101111110101010000101101011101000011101110101001000011001100100100011011010001011011011010001011011101100110100111011010010110001000011110100110001100110011101101000101100011100100100100100001010011101010111110011101011000011101000111011011101011001";
let mut a = BitIterator::new([0x429d5f3ac3a3b759, 0xb10f4c66768b1c92, 0x92368b6d16ecd3b4, 0xa57ea85ae8775219]);
let mut a = BitIterator::new([
0x429d5f3ac3a3b759,
0xb10f4c66768b1c92,
0x92368b6d16ecd3b4,
0xa57ea85ae8775219,
]);
for e in expected.chars() {
assert!(a.next().unwrap() == (e == '1'));