Implement changes to traits in ff_derive

This commit is contained in:
Jack Grigg
2018-06-28 15:07:35 -04:00
parent 58cb06ee92
commit 29a9161981

View File

@@ -140,6 +140,17 @@ fn prime_field_repr_impl(repr: &syn::Ident, limbs: usize) -> proc_macro2::TokenS
} }
} }
impl ::std::fmt::Display for #repr {
fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
try!(write!(f, "0x"));
for i in self.0.iter().rev() {
try!(write!(f, "{:016x}", *i));
}
Ok(())
}
}
impl AsRef<[u64]> for #repr { impl AsRef<[u64]> for #repr {
#[inline(always)] #[inline(always)]
fn as_ref(&self) -> &[u64] { fn as_ref(&self) -> &[u64] {
@@ -147,6 +158,13 @@ fn prime_field_repr_impl(repr: &syn::Ident, limbs: usize) -> proc_macro2::TokenS
} }
} }
impl AsMut<[u64]> for #repr {
#[inline(always)]
fn as_mut(&mut self) -> &mut [u64] {
&mut self.0
}
}
impl From<u64> for #repr { impl From<u64> for #repr {
#[inline(always)] #[inline(always)]
fn from(val: u64) -> #repr { fn from(val: u64) -> #repr {
@@ -207,6 +225,32 @@ fn prime_field_repr_impl(repr: &syn::Ident, limbs: usize) -> proc_macro2::TokenS
} }
} }
#[inline(always)]
fn shr(&mut self, mut n: u32) {
if n as usize >= 64 * #limbs {
*self = Self::from(0);
return;
}
while n >= 64 {
let mut t = 0;
for i in self.0.iter_mut().rev() {
::std::mem::swap(&mut t, i);
}
n -= 64;
}
if n > 0 {
let mut t = 0;
for i in self.0.iter_mut().rev() {
let t2 = *i << (64 - n);
*i >>= n;
*i |= t;
t = t2;
}
}
}
#[inline(always)] #[inline(always)]
fn mul2(&mut self) { fn mul2(&mut self) {
let mut last = 0; let mut last = 0;
@@ -218,6 +262,32 @@ fn prime_field_repr_impl(repr: &syn::Ident, limbs: usize) -> proc_macro2::TokenS
} }
} }
#[inline(always)]
fn shl(&mut self, mut n: u32) {
if n as usize >= 64 * #limbs {
*self = Self::from(0);
return;
}
while n >= 64 {
let mut t = 0;
for i in &mut self.0 {
::std::mem::swap(&mut t, i);
}
n -= 64;
}
if n > 0 {
let mut t = 0;
for i in &mut self.0 {
let t2 = *i >> (64 - n);
*i <<= n;
*i |= t;
t = t2;
}
}
}
#[inline(always)] #[inline(always)]
fn num_bits(&self) -> u32 { fn num_bits(&self) -> u32 {
let mut ret = (#limbs as u32) * 64; let mut ret = (#limbs as u32) * 64;
@@ -233,25 +303,21 @@ fn prime_field_repr_impl(repr: &syn::Ident, limbs: usize) -> proc_macro2::TokenS
} }
#[inline(always)] #[inline(always)]
fn add_nocarry(&mut self, other: &#repr) -> bool { fn add_nocarry(&mut self, other: &#repr) {
let mut carry = 0; let mut carry = 0;
for (a, b) in self.0.iter_mut().zip(other.0.iter()) { for (a, b) in self.0.iter_mut().zip(other.0.iter()) {
*a = ::ff::adc(*a, *b, &mut carry); *a = ::ff::adc(*a, *b, &mut carry);
} }
carry != 0
} }
#[inline(always)] #[inline(always)]
fn sub_noborrow(&mut self, other: &#repr) -> bool { fn sub_noborrow(&mut self, other: &#repr) {
let mut borrow = 0; let mut borrow = 0;
for (a, b) in self.0.iter_mut().zip(other.0.iter()) { for (a, b) in self.0.iter_mut().zip(other.0.iter()) {
*a = ::ff::sbb(*a, *b, &mut borrow); *a = ::ff::sbb(*a, *b, &mut borrow);
} }
borrow != 0
} }
} }
} }
@@ -345,7 +411,7 @@ fn prime_field_constants_and_sqrt(
let r = (BigUint::one() << (limbs * 64)) % &modulus; let r = (BigUint::one() << (limbs * 64)) % &modulus;
// modulus - 1 = 2^s * t // modulus - 1 = 2^s * t
let mut s: usize = 0; let mut s: u32 = 0;
let mut t = &modulus - BigUint::from_str("1").unwrap(); let mut t = &modulus - BigUint::from_str("1").unwrap();
while t.is_even() { while t.is_even() {
t = t >> 1; t = t >> 1;
@@ -359,6 +425,22 @@ fn prime_field_constants_and_sqrt(
); );
let generator = biguint_to_u64_vec((generator.clone() * &r) % &modulus, limbs); let generator = biguint_to_u64_vec((generator.clone() * &r) % &modulus, limbs);
let mod_minus_1_over_2 =
biguint_to_u64_vec((&modulus - BigUint::from_str("1").unwrap()) >> 1, limbs);
let legendre_impl = quote!{
fn legendre(&self) -> ::ff::LegendreSymbol {
// s = self^((modulus - 1) // 2)
let s = self.pow(#mod_minus_1_over_2);
if s == Self::zero() {
::ff::LegendreSymbol::Zero
} else if s == Self::one() {
::ff::LegendreSymbol::QuadraticResidue
} else {
::ff::LegendreSymbol::QuadraticNonResidue
}
}
};
let sqrt_impl = let sqrt_impl =
if (&modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() { if (&modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() {
let mod_minus_3_over_4 = let mod_minus_3_over_4 =
@@ -369,6 +451,8 @@ fn prime_field_constants_and_sqrt(
quote!{ quote!{
impl ::ff::SqrtField for #name { impl ::ff::SqrtField for #name {
#legendre_impl
fn sqrt(&self) -> Option<Self> { fn sqrt(&self) -> Option<Self> {
// Shank's algorithm for q mod 4 = 3 // Shank's algorithm for q mod 4 = 3
// https://eprint.iacr.org/2012/685.pdf (page 9, algorithm 2) // https://eprint.iacr.org/2012/685.pdf (page 9, algorithm 2)
@@ -389,13 +473,13 @@ fn prime_field_constants_and_sqrt(
} }
} }
} else if (&modulus % BigUint::from_str("16").unwrap()) == BigUint::from_str("1").unwrap() { } else if (&modulus % BigUint::from_str("16").unwrap()) == BigUint::from_str("1").unwrap() {
let mod_minus_1_over_2 =
biguint_to_u64_vec((&modulus - BigUint::from_str("1").unwrap()) >> 1, limbs);
let t_plus_1_over_2 = biguint_to_u64_vec((&t + BigUint::one()) >> 1, limbs); let t_plus_1_over_2 = biguint_to_u64_vec((&t + BigUint::one()) >> 1, limbs);
let t = biguint_to_u64_vec(t.clone(), limbs); let t = biguint_to_u64_vec(t.clone(), limbs);
quote!{ quote!{
impl ::ff::SqrtField for #name { impl ::ff::SqrtField for #name {
#legendre_impl
fn sqrt(&self) -> Option<Self> { fn sqrt(&self) -> Option<Self> {
// Tonelli-Shank's algorithm for q mod 16 = 1 // Tonelli-Shank's algorithm for q mod 16 = 1
// https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5) // https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5)
@@ -483,7 +567,7 @@ fn prime_field_constants_and_sqrt(
const GENERATOR: #repr = #repr(#generator); const GENERATOR: #repr = #repr(#generator);
/// 2^s * t = MODULUS - 1 with t odd /// 2^s * t = MODULUS - 1 with t odd
const S: usize = #s; const S: u32 = #s;
/// 2^s root of unity computed by GENERATOR^t /// 2^s root of unity computed by GENERATOR^t
const ROOT_OF_UNITY: #repr = #repr(#root_of_unity); const ROOT_OF_UNITY: #repr = #repr(#root_of_unity);
@@ -736,6 +820,27 @@ fn prime_field_impl(
} }
} }
/// Elements are ordered lexicographically.
impl Ord for #name {
#[inline(always)]
fn cmp(&self, other: &#name) -> ::std::cmp::Ordering {
self.into_repr().cmp(&other.into_repr())
}
}
impl PartialOrd for #name {
#[inline(always)]
fn partial_cmp(&self, other: &#name) -> Option<::std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl ::std::fmt::Display for #name {
fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result {
write!(f, "{}({})", stringify!(#name), self.into_repr())
}
}
impl ::rand::Rand for #name { impl ::rand::Rand for #name {
/// Computes a uniformly random element using rejection sampling. /// Computes a uniformly random element using rejection sampling.
fn rand<R: ::rand::Rng>(rng: &mut R) -> Self { fn rand<R: ::rand::Rng>(rng: &mut R) -> Self {
@@ -751,17 +856,23 @@ fn prime_field_impl(
} }
} }
impl From<#name> for #repr {
fn from(e: #name) -> #repr {
e.into_repr()
}
}
impl ::ff::PrimeField for #name { impl ::ff::PrimeField for #name {
type Repr = #repr; type Repr = #repr;
fn from_repr(r: #repr) -> Result<#name, ()> { fn from_repr(r: #repr) -> Result<#name, PrimeFieldDecodingError> {
let mut r = #name(r); let mut r = #name(r);
if r.is_valid() { if r.is_valid() {
r.mul_assign(&#name(R2)); r.mul_assign(&#name(R2));
Ok(r) Ok(r)
} else { } else {
Err(()) Err(PrimeFieldDecodingError::NotInField(format!("{}", r.0)))
} }
} }
@@ -778,21 +889,15 @@ fn prime_field_impl(
MODULUS MODULUS
} }
fn num_bits() -> u32 { const NUM_BITS: u32 = MODULUS_BITS;
MODULUS_BITS
}
fn capacity() -> u32 { const CAPACITY: u32 = Self::NUM_BITS - 1;
Self::num_bits() - 1
}
fn multiplicative_generator() -> Self { fn multiplicative_generator() -> Self {
#name(GENERATOR) #name(GENERATOR)
} }
fn s() -> usize { const S: u32 = S;
S
}
fn root_of_unity() -> Self { fn root_of_unity() -> Self {
#name(ROOT_OF_UNITY) #name(ROOT_OF_UNITY)