#![deny(missing_docs)]
#![allow(non_snake_case)]
use clear_on_drop::clear::Clear;
use curve25519_dalek::scalar::Scalar;
use inner_product_proof::inner_product;
/// A degree-1 vector polynomial `self.0 + self.1 * x`, where both coefficients
/// are vectors of scalars (evaluated element-wise by `VecPoly1::eval`).
/// The coefficient vectors are zeroed when the value is dropped.
pub struct VecPoly1(pub Vec<Scalar>, pub Vec<Scalar>);
/// A degree-3 vector polynomial
/// `self.0 + self.1 * x + self.2 * x^2 + self.3 * x^3`, with vector-of-scalar
/// coefficients (evaluated element-wise by `VecPoly3::eval`).
/// The coefficient vectors are zeroed when the value is dropped.
#[cfg(feature = "yoloproofs")]
pub struct VecPoly3(
/// Coefficient vector of `x^0`.
pub Vec<Scalar>,
/// Coefficient vector of `x^1`.
pub Vec<Scalar>,
/// Coefficient vector of `x^2`.
pub Vec<Scalar>,
/// Coefficient vector of `x^3`.
pub Vec<Scalar>,
);
/// A degree-2 scalar polynomial `self.0 + self.1 * x + self.2 * x^2`
/// (see `Poly2::eval`). The coefficients are zeroed when the value is dropped.
pub struct Poly2(pub Scalar, pub Scalar, pub Scalar);
/// A degree-6 scalar polynomial with no constant term:
/// `t1 * x + t2 * x^2 + ... + t6 * x^6` (see `Poly6::eval`).
/// The coefficients are zeroed when the value is dropped.
#[cfg(feature = "yoloproofs")]
pub struct Poly6 {
/// Coefficient of `x^1`.
pub t1: Scalar,
/// Coefficient of `x^2`.
pub t2: Scalar,
/// Coefficient of `x^3`.
pub t3: Scalar,
/// Coefficient of `x^4`.
pub t4: Scalar,
/// Coefficient of `x^5`.
pub t5: Scalar,
/// Coefficient of `x^6`.
pub t6: Scalar,
}
/// Unbounded iterator over the successive powers `1, x, x^2, x^3, ...` of a
/// scalar `x`. Construct it with `exp_iter`.
pub struct ScalarExp {
// The base whose powers are produced.
x: Scalar,
// The power that the next call to `next` will yield.
next_exp_x: Scalar,
}
impl Iterator for ScalarExp {
    type Item = Scalar;

    /// Yields the current power of `x`, then advances to the next power.
    ///
    /// This never returns `None`; the sequence has no end.
    fn next(&mut self) -> Option<Scalar> {
        let current = self.next_exp_x;
        self.next_exp_x = current * self.x;
        Some(current)
    }

    /// Reports an unbounded length: the lower bound is saturated and there is
    /// no upper bound.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let unbounded_upper: Option<usize> = None;
        (usize::max_value(), unbounded_upper)
    }
}
/// Returns an iterator over `1, x, x^2, x^3, ...`.
///
/// The iterator is infinite; bound it with `take` (or similar) before
/// collecting or summing.
pub fn exp_iter(x: Scalar) -> ScalarExp {
    ScalarExp {
        x,
        next_exp_x: Scalar::one(),
    }
}
pub fn add_vec(a: &[Scalar], b: &[Scalar]) -> Vec<Scalar> {
if a.len() != b.len() {
println!("lengths of vectors don't match for vector addition");
}
let mut out = vec![Scalar::zero(); b.len()];
for i in 0..a.len() {
out[i] = a[i] + b[i];
}
out
}
impl VecPoly1 {
    /// Returns the zero polynomial whose two coefficient vectors each hold `n`
    /// zero scalars.
    pub fn zero(n: usize) -> Self {
        let zeros = || vec![Scalar::zero(); n];
        VecPoly1(zeros(), zeros())
    }

    /// Computes the scalar polynomial `<self(x), rhs(x)>` as a `Poly2`.
    ///
    /// The middle coefficient uses the Karatsuba identity
    /// `<l0 + l1, r0 + r1> = t0 + t1 + t2`, which costs one extra inner
    /// product instead of two.
    pub fn inner_product(&self, rhs: &VecPoly1) -> Poly2 {
        let low = inner_product(&self.0, &rhs.0);
        let high = inner_product(&self.1, &rhs.1);
        let cross = {
            let lhs_sum = add_vec(&self.0, &self.1);
            let rhs_sum = add_vec(&rhs.0, &rhs.1);
            inner_product(&lhs_sum, &rhs_sum) - low - high
        };
        Poly2(low, cross, high)
    }

    /// Evaluates the polynomial at `x`, producing `self.0[i] + self.1[i] * x`
    /// for every index of `self.0`.
    ///
    /// Panics if `self.1` is shorter than `self.0`.
    pub fn eval(&self, x: Scalar) -> Vec<Scalar> {
        (0..self.0.len())
            .map(|i| self.0[i] + self.1[i] * x)
            .collect()
    }
}
#[cfg(feature = "yoloproofs")]
impl VecPoly3 {
    /// Returns the zero polynomial whose four coefficient vectors each hold
    /// `n` zero scalars.
    pub fn zero(n: usize) -> Self {
        let zeros = || vec![Scalar::zero(); n];
        VecPoly3(zeros(), zeros(), zeros(), zeros())
    }

    /// Computes the coefficients of `x^1` through `x^6` of `<lhs(x), rhs(x)>`.
    ///
    /// Every term that would involve `lhs.0` or `rhs.2` is omitted, and no
    /// `x^0` term is produced. The result therefore equals the full inner
    /// product only when `lhs.0` and `rhs.2` are zero vectors. Those
    /// preconditions are NOT checked here.
    pub fn special_inner_product(lhs: &Self, rhs: &Self) -> Poly6 {
        let (l1, l2, l3) = (&lhs.1, &lhs.2, &lhs.3);
        let (r0, r1, r3) = (&rhs.0, &rhs.1, &rhs.3);
        Poly6 {
            t1: inner_product(l1, r0),
            t2: inner_product(l1, r1) + inner_product(l2, r0),
            t3: inner_product(l2, r1) + inner_product(l3, r0),
            t4: inner_product(l1, r3) + inner_product(l3, r1),
            t5: inner_product(l2, r3),
            t6: inner_product(l3, r3),
        }
    }

    /// Evaluates the polynomial at `x` element-wise with Horner's scheme.
    ///
    /// Panics if any other coefficient vector is shorter than `self.0`.
    pub fn eval(&self, x: Scalar) -> Vec<Scalar> {
        (0..self.0.len())
            .map(|i| self.0[i] + x * (self.1[i] + x * (self.2[i] + x * self.3[i])))
            .collect()
    }
}
impl Poly2 {
    /// Evaluates `self.0 + self.1 * x + self.2 * x^2` with Horner's scheme.
    pub fn eval(&self, x: Scalar) -> Scalar {
        let (c0, c1, c2) = (self.0, self.1, self.2);
        let tail = c1 + x * c2;
        c0 + x * tail
    }
}
#[cfg(feature = "yoloproofs")]
impl Poly6 {
    /// Evaluates `t1 * x + t2 * x^2 + ... + t6 * x^6` with Horner's scheme,
    /// starting from the `x^6` coefficient and working down to `x^1`.
    pub fn eval(&self, x: Scalar) -> Scalar {
        let lower = [self.t1, self.t2, self.t3, self.t4, self.t5];
        lower
            .iter()
            .rev()
            .fold(x * self.t6, |acc, &coeff| x * (coeff + acc))
    }
}
impl Drop for VecPoly1 {
    /// Overwrites every coefficient with zero before the vectors are freed.
    fn drop(&mut self) {
        self.0
            .iter_mut()
            .chain(self.1.iter_mut())
            .for_each(|coeff| coeff.clear());
    }
}
impl Drop for Poly2 {
fn drop(&mut self) {
self.0.clear();
self.1.clear();
self.2.clear();
}
}
#[cfg(feature = "yoloproofs")]
impl Drop for VecPoly3 {
    /// Overwrites every coefficient of all four vectors with zero, in order
    /// from `self.0` to `self.3`.
    fn drop(&mut self) {
        self.0
            .iter_mut()
            .chain(self.1.iter_mut())
            .chain(self.2.iter_mut())
            .chain(self.3.iter_mut())
            .for_each(|coeff| coeff.clear());
    }
}
#[cfg(feature = "yoloproofs")]
impl Drop for Poly6 {
    /// Overwrites all six coefficients with zero.
    fn drop(&mut self) {
        let Poly6 {
            ref mut t1,
            ref mut t2,
            ref mut t3,
            ref mut t4,
            ref mut t5,
            ref mut t6,
        } = *self;
        t1.clear();
        t2.clear();
        t3.clear();
        t4.clear();
        t5.clear();
        t6.clear();
    }
}
/// Raises `x` to the power `n` by right-to-left square-and-multiply.
///
/// Variable time: the multiplications performed depend on the bits of `n`,
/// so the running time reveals information about `n`. Returns one when
/// `n == 0`.
pub fn scalar_exp_vartime(x: &Scalar, mut n: u64) -> Scalar {
    let mut acc = Scalar::one();
    let mut base = *x;
    while n != 0 {
        if n & 1 == 1 {
            acc = acc * base;
        }
        base = base * base;
        n >>= 1;
    }
    acc
}
/// Computes `1 + x + x^2 + ... + x^(n-1)`, the sum of the first `n` powers of `x`.
///
/// When `n` is a power of two, this uses `O(log n)` multiplications via the
/// doubling identity `S(2k) = S(k) + x^k * S(k)`. Every other `n` (including
/// `n == 0`, which yields the empty sum, zero) is summed term by term.
pub fn sum_of_powers(x: &Scalar, n: usize) -> Scalar {
    // `0.is_power_of_two()` is false, so `n == 0` is handled here as well.
    if !n.is_power_of_two() {
        return sum_of_powers_slow(x, n);
    }
    if n == 1 {
        return Scalar::one();
    }
    // Invariant at the top of each iteration: `result == S(k)` and
    // `factor == x^(k/2)`, where `k == n / m * 2`; initially `k == 2`.
    let mut m = n;
    let mut result = Scalar::one() + x;
    let mut factor = *x;
    while m > 2 {
        factor = factor * factor;
        result = result + factor * result;
        m /= 2;
    }
    result
}
fn sum_of_powers_slow(x: &Scalar, n: usize) -> Scalar {
exp_iter(*x).take(n).sum()
}
/// Copies the first 32 bytes of `data` into a fixed-size array.
///
/// Any bytes after the first 32 are ignored.
///
/// # Panics
///
/// Panics if `data` is shorter than 32 bytes.
pub fn read32(data: &[u8]) -> [u8; 32] {
    let head = &data[..32];
    let mut out = [0u8; 32];
    out.copy_from_slice(head);
    out
}
#[cfg(test)]
mod tests {
use super::*;
// `exp_iter` must start at x^0 = 1 and multiply by x at each step.
#[test]
fn exp_2_is_powers_of_2() {
let exp_2: Vec<_> = exp_iter(Scalar::from(2u64)).take(4).collect();
assert_eq!(exp_2[0], Scalar::from(1u64));
assert_eq!(exp_2[1], Scalar::from(2u64));
assert_eq!(exp_2[2], Scalar::from(4u64));
assert_eq!(exp_2[3], Scalar::from(8u64));
}
// 1*2 + 2*3 + 3*4 + 4*5 = 40.
#[test]
fn test_inner_product() {
let a = vec![
Scalar::from(1u64),
Scalar::from(2u64),
Scalar::from(3u64),
Scalar::from(4u64),
];
let b = vec![
Scalar::from(2u64),
Scalar::from(3u64),
Scalar::from(4u64),
Scalar::from(5u64),
];
assert_eq!(Scalar::from(40u64), inner_product(&a, &b));
}
// Reference implementation: n repeated multiplications by x.
fn scalar_exp_vartime_slow(x: &Scalar, n: u64) -> Scalar {
let mut result = Scalar::one();
for _ in 0..n {
result = result * x;
}
result
}
// Checks square-and-multiply against explicit products and the slow reference,
// including exponent 0 and a multi-bit exponent (0b11001010).
#[test]
fn test_scalar_exp() {
let x = Scalar::from_bits(
*b"\x84\xfc\xbcOx\x12\xa0\x06\xd7\x91\xd9z:'\xdd\x1e!CE\xf7\xb1\xb9Vz\x810sD\x96\x85\xb5\x07",
);
assert_eq!(scalar_exp_vartime(&x, 0), Scalar::one());
assert_eq!(scalar_exp_vartime(&x, 1), x);
assert_eq!(scalar_exp_vartime(&x, 2), x * x);
assert_eq!(scalar_exp_vartime(&x, 3), x * x * x);
assert_eq!(scalar_exp_vartime(&x, 4), x * x * x * x);
assert_eq!(scalar_exp_vartime(&x, 5), x * x * x * x * x);
assert_eq!(scalar_exp_vartime(&x, 64), scalar_exp_vartime_slow(&x, 64));
assert_eq!(
scalar_exp_vartime(&x, 0b11001010),
scalar_exp_vartime_slow(&x, 0b11001010)
);
}
// The fast path (powers of two) must agree with term-by-term summation;
// n = 0 exercises the non-power-of-two fallback.
#[test]
fn test_sum_of_powers() {
let x = Scalar::from(10u64);
assert_eq!(sum_of_powers_slow(&x, 0), sum_of_powers(&x, 0));
assert_eq!(sum_of_powers_slow(&x, 1), sum_of_powers(&x, 1));
assert_eq!(sum_of_powers_slow(&x, 2), sum_of_powers(&x, 2));
assert_eq!(sum_of_powers_slow(&x, 4), sum_of_powers(&x, 4));
assert_eq!(sum_of_powers_slow(&x, 8), sum_of_powers(&x, 8));
assert_eq!(sum_of_powers_slow(&x, 16), sum_of_powers(&x, 16));
assert_eq!(sum_of_powers_slow(&x, 32), sum_of_powers(&x, 32));
assert_eq!(sum_of_powers_slow(&x, 64), sum_of_powers(&x, 64));
}
// With x = 10, the sum of the first n powers is the decimal repunit 11...1.
#[test]
fn test_sum_of_powers_slow() {
let x = Scalar::from(10u64);
assert_eq!(sum_of_powers_slow(&x, 0), Scalar::zero());
assert_eq!(sum_of_powers_slow(&x, 1), Scalar::one());
assert_eq!(sum_of_powers_slow(&x, 2), Scalar::from(11u64));
assert_eq!(sum_of_powers_slow(&x, 3), Scalar::from(111u64));
assert_eq!(sum_of_powers_slow(&x, 4), Scalar::from(1111u64));
assert_eq!(sum_of_powers_slow(&x, 5), Scalar::from(11111u64));
assert_eq!(sum_of_powers_slow(&x, 6), Scalar::from(111111u64));
}
// Checks that `Clear::clear` zeroes the raw bytes of each Scalar in a Vec.
// NOTE(review): despite the name, this calls `clear` directly; it does not
// exercise any `Drop` impl.
#[test]
fn vec_of_scalars_clear_on_drop() {
let mut v = vec![Scalar::from(24u64), Scalar::from(42u64)];
for e in v.iter_mut() {
e.clear();
}
// Views a slice's backing memory as raw bytes.
fn flat_slice<T>(x: &[T]) -> &[u8] {
use core::mem;
use core::slice;
// SAFETY: the pointer and `size_of_val(x)` describe exactly the memory
// borrowed by `x`, which stays alive for the returned lifetime, and `u8`
// has alignment 1. Reading it as bytes assumes `T` has no padding
// (uninitialised) bytes; presumably true for `Scalar` — the 64-byte
// comparison below only holds if two Scalars occupy 64 bytes.
unsafe { slice::from_raw_parts(x.as_ptr() as *const u8, mem::size_of_val(x)) }
}
assert_eq!(flat_slice(&v.as_slice()), &[0u8; 64][..]);
assert_eq!(v[0], Scalar::zero());
assert_eq!(v[1], Scalar::zero());
}
// Checks that clearing each field of a Poly2 zeroes all of its raw bytes.
// NOTE(review): despite the name, this calls `clear` on each field directly;
// `Poly2`'s `Drop` impl runs only at the end of the test, after the assertions.
#[test]
fn tuple_of_scalars_clear_on_drop() {
let mut v = Poly2(
Scalar::from(24u64),
Scalar::from(42u64),
Scalar::from(255u64),
);
v.0.clear();
v.1.clear();
v.2.clear();
// Views a value's memory as raw bytes.
fn as_bytes<T>(x: &T) -> &[u8] {
use core::mem;
use core::slice;
// SAFETY: `x` is a valid reference to `size_of_val(x)` bytes that outlive
// the returned slice, and `u8` has alignment 1. Reading it as bytes assumes
// `T` has no padding bytes; presumably true for `Poly2` (three Scalars),
// which the 96-byte comparison below also relies on.
unsafe { slice::from_raw_parts(x as *const T as *const u8, mem::size_of_val(x)) }
}
assert_eq!(as_bytes(&v), &[0u8; 96][..]);
assert_eq!(v.0, Scalar::zero());
assert_eq!(v.1, Scalar::zero());
assert_eq!(v.2, Scalar::zero());
}
}