#![allow(non_snake_case)]
#![doc(include = "../docs/inner-product-protocol.md")]
extern crate alloc;
use alloc::borrow::Borrow;
use alloc::vec::Vec;
use core::iter;
use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};
use curve25519_dalek::scalar::Scalar;
use curve25519_dalek::traits::VartimeMultiscalarMul;
use merlin::Transcript;
use errors::ProofError;
use transcript::TranscriptProtocol;
/// A proof produced by `InnerProductProof::create`.
///
/// The proof consists of one `(L, R)` pair of compressed points per halving round,
/// followed by the two scalars left over once the vectors have been folded to length 1.
#[derive(Clone, Debug)]
pub struct InnerProductProof {
    // `L` point of each round, in the order the rounds were executed.
    pub(crate) L_vec: Vec<CompressedRistretto>,
    // `R` point of each round, in the order the rounds were executed.
    pub(crate) R_vec: Vec<CompressedRistretto>,
    // Final (length-1) value of the folded `a` vector.
    pub(crate) a: Scalar,
    // Final (length-1) value of the folded `b` vector.
    pub(crate) b: Scalar,
}
impl InnerProductProof {
    /// Creates an inner-product proof.
    ///
    /// Writing `G'_i = G_factors[i] * G_vec[i]` and `H'_i = H_factors[i] * H_vec[i]`, the
    /// proof is for the statement that the prover knows `a_vec`, `b_vec` such that
    /// `P = <a_vec, G'> + <b_vec, H'> + <a_vec, b_vec> * Q`. The point `P` itself is not an
    /// input here; the verifier supplies it to `verify`.
    ///
    /// The vectors are halved `lg(n)` times. Each round appends its `L` and `R` points to
    /// `transcript` and then draws a challenge `u` from it, so the verifier must replay an
    /// identically-initialised transcript.
    ///
    /// # Panics
    /// If `G_vec`, `H_vec`, `a_vec`, `b_vec`, `G_factors` and `H_factors` do not all have
    /// the same length `n`, or if `n` is not a power of two (this includes `n == 0`).
    pub fn create(
        transcript: &mut Transcript,
        Q: &RistrettoPoint,
        G_factors: &[Scalar],
        H_factors: &[Scalar],
        mut G_vec: Vec<RistrettoPoint>,
        mut H_vec: Vec<RistrettoPoint>,
        mut a_vec: Vec<Scalar>,
        mut b_vec: Vec<Scalar>,
    ) -> InnerProductProof {
        // Work on slices so that each round can shrink the active window in place.
        let mut G = &mut G_vec[..];
        let mut H = &mut H_vec[..];
        let mut a = &mut a_vec[..];
        let mut b = &mut b_vec[..];

        let mut n = G.len();

        // All of the inputs must have the same length as `G`.
        assert_eq!(H.len(), n);
        assert_eq!(a.len(), n);
        assert_eq!(b.len(), n);
        assert_eq!(G_factors.len(), n);
        assert_eq!(H_factors.len(), n);

        // The length has to be repeatedly halvable down to 1.
        assert!(n.is_power_of_two());

        transcript.innerproduct_domain_sep(n as u64);

        let lg_n = n.next_power_of_two().trailing_zeros() as usize;
        let mut L_vec = Vec::with_capacity(lg_n);
        let mut R_vec = Vec::with_capacity(lg_n);

        // First round: the bases in G/H are still unscaled, so G_factors/H_factors are
        // applied to the scalars of L/R and folded into the new bases. After this round
        // the factors are absorbed into G/H and the later rounds no longer need them.
        if n != 1 {
            n = n / 2;
            let (a_L, a_R) = a.split_at_mut(n);
            let (b_L, b_R) = b.split_at_mut(n);
            let (G_L, G_R) = G.split_at_mut(n);
            let (H_L, H_R) = H.split_at_mut(n);

            let c_L = inner_product(&a_L, &b_R);
            let c_R = inner_product(&a_R, &b_L);

            let L = RistrettoPoint::vartime_multiscalar_mul(
                a_L.iter()
                    .zip(G_factors[n..2 * n].into_iter())
                    .map(|(a_L_i, g)| a_L_i * g)
                    .chain(
                        b_R.iter()
                            .zip(H_factors[0..n].into_iter())
                            .map(|(b_R_i, h)| b_R_i * h),
                    )
                    .chain(iter::once(c_L)),
                G_R.iter().chain(H_L.iter()).chain(iter::once(Q)),
            )
            .compress();

            let R = RistrettoPoint::vartime_multiscalar_mul(
                a_R.iter()
                    .zip(G_factors[0..n].into_iter())
                    .map(|(a_R_i, g)| a_R_i * g)
                    .chain(
                        b_L.iter()
                            .zip(H_factors[n..2 * n].into_iter())
                            .map(|(b_L_i, h)| b_L_i * h),
                    )
                    .chain(iter::once(c_R)),
                G_L.iter().chain(H_R.iter()).chain(iter::once(Q)),
            )
            .compress();

            L_vec.push(L);
            R_vec.push(R);

            transcript.append_point(b"L", &L);
            transcript.append_point(b"R", &R);

            let u = transcript.challenge_scalar(b"u");
            let u_inv = u.invert();

            // Fold the right half into the left half; the left halves become the inputs
            // of the next round.
            for i in 0..n {
                a_L[i] = a_L[i] * u + u_inv * a_R[i];
                b_L[i] = b_L[i] * u_inv + u * b_R[i];
                G_L[i] = RistrettoPoint::vartime_multiscalar_mul(
                    &[u_inv * G_factors[i], u * G_factors[n + i]],
                    &[G_L[i], G_R[i]],
                );
                H_L[i] = RistrettoPoint::vartime_multiscalar_mul(
                    &[u * H_factors[i], u_inv * H_factors[n + i]],
                    &[H_L[i], H_R[i]],
                )
            }

            a = a_L;
            b = b_L;
            G = G_L;
            H = H_L;
        }

        // Remaining rounds: same folding, with the factors already absorbed into G/H.
        while n != 1 {
            n = n / 2;
            let (a_L, a_R) = a.split_at_mut(n);
            let (b_L, b_R) = b.split_at_mut(n);
            let (G_L, G_R) = G.split_at_mut(n);
            let (H_L, H_R) = H.split_at_mut(n);

            let c_L = inner_product(&a_L, &b_R);
            let c_R = inner_product(&a_R, &b_L);

            let L = RistrettoPoint::vartime_multiscalar_mul(
                a_L.iter().chain(b_R.iter()).chain(iter::once(&c_L)),
                G_R.iter().chain(H_L.iter()).chain(iter::once(Q)),
            )
            .compress();

            let R = RistrettoPoint::vartime_multiscalar_mul(
                a_R.iter().chain(b_L.iter()).chain(iter::once(&c_R)),
                G_L.iter().chain(H_R.iter()).chain(iter::once(Q)),
            )
            .compress();

            L_vec.push(L);
            R_vec.push(R);

            transcript.append_point(b"L", &L);
            transcript.append_point(b"R", &R);

            let u = transcript.challenge_scalar(b"u");
            let u_inv = u.invert();

            for i in 0..n {
                a_L[i] = a_L[i] * u + u_inv * a_R[i];
                b_L[i] = b_L[i] * u_inv + u * b_R[i];
                G_L[i] = RistrettoPoint::vartime_multiscalar_mul(&[u_inv, u], &[G_L[i], G_R[i]]);
                H_L[i] = RistrettoPoint::vartime_multiscalar_mul(&[u, u_inv], &[H_L[i], H_R[i]]);
            }

            a = a_L;
            b = b_L;
            G = G_L;
            H = H_L;
        }

        InnerProductProof {
            L_vec,
            R_vec,
            a: a[0],
            b: b[0],
        }
    }

    /// Replays the proof's `L`/`R` points through `transcript` and derives the scalars
    /// used to check the proof in a single multiscalar multiplication.
    ///
    /// Returns `(u_sq, u_inv_sq, s)`:
    /// * `u_sq[j]` and `u_inv_sq[j]` are the squares of the challenge drawn in round `j`
    ///   and of its inverse;
    /// * `s` has length `n`. `s[0]` is the inverse of the product of all challenges, and
    ///   each later `s[i]` multiplies an earlier entry by one squared challenge.
    ///
    /// # Errors
    /// Returns `ProofError::VerificationError` in these cases:
    /// * `L_vec` and `R_vec` differ in length;
    /// * the proof has 32 or more rounds;
    /// * `n != 2^rounds`;
    /// * `validate_and_append_point` rejects an `L` or `R` point.
    pub(crate) fn verification_scalars(
        &self,
        n: usize,
        transcript: &mut Transcript,
    ) -> Result<(Vec<Scalar>, Vec<Scalar>, Vec<Scalar>), ProofError> {
        let lg_n = self.L_vec.len();
        // Every round contributes exactly one L and one R. A mismatch would make the zip
        // below silently drop rounds, and the `challenges[i]` indexing further down panic.
        if self.R_vec.len() != lg_n {
            return Err(ProofError::VerificationError);
        }
        if lg_n >= 32 {
            // Bounds the work done and keeps the `1 << lg_n` shift below from overflowing.
            return Err(ProofError::VerificationError);
        }
        if n != (1 << lg_n) {
            return Err(ProofError::VerificationError);
        }

        transcript.innerproduct_domain_sep(n as u64);

        // 1. Recompute the per-round challenges from the transcript, in round order.
        let mut challenges = Vec::with_capacity(lg_n);
        for (L, R) in self.L_vec.iter().zip(self.R_vec.iter()) {
            transcript.validate_and_append_point(b"L", L)?;
            transcript.validate_and_append_point(b"R", R)?;
            challenges.push(transcript.challenge_scalar(b"u"));
        }

        // 2. Invert all challenges at once; `allinv` is the inverse of their product.
        let mut challenges_inv = challenges.clone();
        let allinv = Scalar::batch_invert(&mut challenges_inv);

        // 3. Square the challenges and their inverses in place.
        for i in 0..lg_n {
            challenges[i] = challenges[i] * challenges[i];
            challenges_inv[i] = challenges_inv[i] * challenges_inv[i];
        }
        let challenges_sq = challenges;
        let challenges_inv_sq = challenges_inv;

        // 4. Build s inductively. With lg_i = floor(log2(i)) and k = 2^lg_i,
        //    s[i] = s[i - k] * u^2, where u is the challenge of round (lg_n - 1) - lg_i
        //    (challenges are stored in creation order).
        let mut s = Vec::with_capacity(n);
        s.push(allinv);
        for i in 1..n {
            let lg_i = (32 - 1 - (i as u32).leading_zeros()) as usize;
            let k = 1 << lg_i;
            let u_lg_i_sq = challenges_sq[(lg_n - 1) - lg_i];
            s.push(s[i - k] * u_lg_i_sq);
        }

        Ok((challenges_sq, challenges_inv_sq, s))
    }

    /// Verifies the proof against the commitment `P`.
    ///
    /// The bases follow the same convention as `create`:
    /// `G'_i = G_factors_i * G_i` and `H'_i = H_factors_i * H_i`.
    /// `transcript` must be in the same state the prover's transcript was in when `create`
    /// was called.
    ///
    /// # Errors
    /// Returns `ProofError::VerificationError` in these cases:
    /// * `G` or `H` does not have length `n`;
    /// * `verification_scalars` rejects the proof;
    /// * an `L`/`R` point fails to decompress;
    /// * the recomputed commitment differs from `P`.
    ///
    /// # Panics
    /// `G_factors` and `H_factors` must yield at least `n` items. If either yields fewer,
    /// the scalar and point iterators fed to the multiscalar multiplication have different
    /// lengths, which curve25519-dalek rejects with a panic.
    // NOTE(review): presumably only exercised from tests inside this crate, hence the allow.
    #[allow(dead_code)]
    pub fn verify<IG, IH>(
        &self,
        n: usize,
        transcript: &mut Transcript,
        G_factors: IG,
        H_factors: IH,
        P: &RistrettoPoint,
        Q: &RistrettoPoint,
        G: &[RistrettoPoint],
        H: &[RistrettoPoint],
    ) -> Result<(), ProofError>
    where
        IG: IntoIterator,
        IG::Item: Borrow<Scalar>,
        IH: IntoIterator,
        IH::Item: Borrow<Scalar>,
    {
        // Generators of the wrong length would either panic inside the multiscalar
        // multiplication (length mismatch) or silently drop terms; report a failed
        // verification instead.
        if G.len() != n || H.len() != n {
            return Err(ProofError::VerificationError);
        }

        let (u_sq, u_inv_sq, s) = self.verification_scalars(n, transcript)?;

        let g_times_a_times_s = G_factors
            .into_iter()
            .zip(s.iter())
            .map(|(g_i, s_i)| (self.a * s_i) * g_i.borrow());

        // s[i] is a product of each challenge raised to +1 or -1, chosen by the bits of i.
        // Since i and n-1-i have complementary bits, s[n-1-i] == 1/s[i], so reversing s
        // yields the inverses.
        let inv_s = s.iter().rev();

        let h_times_b_div_s = H_factors
            .into_iter()
            .zip(inv_s)
            .map(|(h_i, s_i_inv)| (self.b * s_i_inv) * h_i.borrow());

        let neg_u_sq = u_sq.iter().map(|ui| -ui);
        let neg_u_inv_sq = u_inv_sq.iter().map(|ui| -ui);

        let Ls = self
            .L_vec
            .iter()
            .map(|p| p.decompress().ok_or(ProofError::VerificationError))
            .collect::<Result<Vec<_>, _>>()?;

        let Rs = self
            .R_vec
            .iter()
            .map(|p| p.decompress().ok_or(ProofError::VerificationError))
            .collect::<Result<Vec<_>, _>>()?;

        let expect_P = RistrettoPoint::vartime_multiscalar_mul(
            iter::once(self.a * self.b)
                .chain(g_times_a_times_s)
                .chain(h_times_b_div_s)
                .chain(neg_u_sq)
                .chain(neg_u_inv_sq),
            iter::once(Q)
                .chain(G.iter())
                .chain(H.iter())
                .chain(Ls.iter())
                .chain(Rs.iter()),
        );

        if expect_P == *P {
            Ok(())
        } else {
            Err(ProofError::VerificationError)
        }
    }

    /// Returns the length in bytes of `to_bytes()`.
    ///
    /// The encoding holds 32 bytes for each `L` and each `R` point, plus 32 bytes each
    /// for `a` and `b`.
    pub fn serialized_size(&self) -> usize {
        (self.L_vec.len() * 2 + 2) * 32
    }

    /// Serializes the proof as `L_0 || R_0 || ... || L_{k-1} || R_{k-1} || a || b`, where
    /// every element occupies 32 bytes.
    pub fn to_bytes(&self) -> Vec<u8> {
        let mut buf = Vec::with_capacity(self.serialized_size());
        for (l, r) in self.L_vec.iter().zip(self.R_vec.iter()) {
            buf.extend_from_slice(l.as_bytes());
            buf.extend_from_slice(r.as_bytes());
        }
        buf.extend_from_slice(self.a.as_bytes());
        buf.extend_from_slice(self.b.as_bytes());
        buf
    }

    /// Yields the same bytes as `to_bytes`, in the same order, without allocating a buffer.
    #[inline]
    pub(crate) fn to_bytes_iter(&self) -> impl Iterator<Item = u8> + '_ {
        self.L_vec
            .iter()
            .zip(self.R_vec.iter())
            .flat_map(|(l, r)| l.as_bytes().iter().chain(r.as_bytes()))
            .chain(self.a.as_bytes())
            .chain(self.b.as_bytes())
            .copied()
    }

    /// Deserializes a proof in the format produced by `to_bytes`.
    ///
    /// The `L`/`R` points are stored compressed and are not decompressed or validated
    /// here; that happens during verification.
    ///
    /// # Errors
    /// Returns `ProofError::FormatError` in these cases:
    /// * the length is not a multiple of 32;
    /// * the input holds fewer than two elements;
    /// * the number of points is odd;
    /// * the proof would have 32 or more rounds;
    /// * `a` or `b` is not a canonical scalar encoding.
    pub fn from_bytes(slice: &[u8]) -> Result<InnerProductProof, ProofError> {
        let b = slice.len();
        if b % 32 != 0 {
            return Err(ProofError::FormatError);
        }
        let num_elements = b / 32;
        if num_elements < 2 {
            return Err(ProofError::FormatError);
        }
        // After a and b, the remaining elements must form (L, R) pairs.
        if (num_elements - 2) % 2 != 0 {
            return Err(ProofError::FormatError);
        }
        let lg_n = (num_elements - 2) / 2;
        // Same round limit that `verification_scalars` enforces.
        if lg_n >= 32 {
            return Err(ProofError::FormatError);
        }

        use util::read32;

        let mut L_vec: Vec<CompressedRistretto> = Vec::with_capacity(lg_n);
        let mut R_vec: Vec<CompressedRistretto> = Vec::with_capacity(lg_n);
        for i in 0..lg_n {
            let pos = 2 * i * 32;
            L_vec.push(CompressedRistretto(read32(&slice[pos..])));
            R_vec.push(CompressedRistretto(read32(&slice[pos + 32..])));
        }

        let pos = 2 * lg_n * 32;
        let a =
            Scalar::from_canonical_bytes(read32(&slice[pos..])).ok_or(ProofError::FormatError)?;
        let b = Scalar::from_canonical_bytes(read32(&slice[pos + 32..]))
            .ok_or(ProofError::FormatError)?;

        Ok(InnerProductProof { L_vec, R_vec, a, b })
    }
}
/// Computes the inner product `a[0]*b[0] + a[1]*b[1] + ...` of two scalar vectors.
///
/// # Panics
/// If `a` and `b` have different lengths.
pub fn inner_product(a: &[Scalar], b: &[Scalar]) -> Scalar {
    if a.len() != b.len() {
        panic!("inner_product(a,b): lengths of vectors do not match");
    }
    a.iter()
        .zip(b.iter())
        .fold(Scalar::zero(), |acc, (a_i, b_i)| acc + a_i * b_i)
}
#[cfg(test)]
mod tests {
    use super::*;
    use sha3::Sha3_512;
    use util;

    /// Proves a random inner-product relation of length `n` and checks two things: that
    /// the proof verifies, and that it still verifies after a `to_bytes`/`from_bytes`
    /// round trip.
    ///
    /// All `G_factors` are one and `H_factors` are successive powers of a random `y_inv`,
    /// so the commitment is `P = <a, G> + <b * y_inv^i, H> + <a, b> * Q`.
    fn test_helper_create(n: usize) {
        use generators::BulletproofGens;

        let mut rng = rand::thread_rng();

        let bp_gens = BulletproofGens::new(n, 1);
        let G: Vec<RistrettoPoint> = bp_gens.share(0).G(n).cloned().collect();
        let H: Vec<RistrettoPoint> = bp_gens.share(0).H(n).cloned().collect();
        let Q = RistrettoPoint::hash_from_bytes::<Sha3_512>(b"test point");

        let a: Vec<_> = (0..n).map(|_| Scalar::random(&mut rng)).collect();
        let b: Vec<_> = (0..n).map(|_| Scalar::random(&mut rng)).collect();
        let c = inner_product(&a, &b);

        let G_factors: Vec<Scalar> = iter::repeat(Scalar::one()).take(n).collect();
        let y_inv = Scalar::random(&mut rng);
        let H_factors: Vec<Scalar> = util::exp_iter(y_inv).take(n).collect();

        let scaled_b = b
            .iter()
            .zip(util::exp_iter(y_inv))
            .map(|(b_i, y_i)| b_i * y_i);
        let P = RistrettoPoint::vartime_multiscalar_mul(
            a.iter().cloned().chain(scaled_b).chain(iter::once(c)),
            G.iter().chain(H.iter()).chain(iter::once(&Q)),
        );

        let mut prover_transcript = Transcript::new(b"innerproducttest");
        let proof = InnerProductProof::create(
            &mut prover_transcript,
            &Q,
            &G_factors,
            &H_factors,
            G.clone(),
            H.clone(),
            a.clone(),
            b.clone(),
        );

        // Every verification starts from a fresh transcript with the prover's label.
        let verifies = |candidate: &InnerProductProof| -> bool {
            let mut verifier = Transcript::new(b"innerproducttest");
            candidate
                .verify(
                    n,
                    &mut verifier,
                    iter::repeat(Scalar::one()).take(n),
                    util::exp_iter(y_inv).take(n),
                    &P,
                    &Q,
                    &G,
                    &H,
                )
                .is_ok()
        };

        assert!(verifies(&proof));

        let decoded = InnerProductProof::from_bytes(proof.to_bytes().as_slice()).unwrap();
        assert!(verifies(&decoded));
    }

    #[test]
    fn make_ipp_1() {
        test_helper_create(1);
    }

    #[test]
    fn make_ipp_2() {
        test_helper_create(2);
    }

    #[test]
    fn make_ipp_4() {
        test_helper_create(4);
    }

    #[test]
    fn make_ipp_32() {
        test_helper_create(32);
    }

    #[test]
    fn make_ipp_64() {
        test_helper_create(64);
    }

    #[test]
    fn test_inner_product() {
        let a: Vec<Scalar> = [1u64, 2, 3, 4].iter().map(|&x| Scalar::from(x)).collect();
        let b: Vec<Scalar> = [2u64, 3, 4, 5].iter().map(|&x| Scalar::from(x)).collect();
        // 1*2 + 2*3 + 3*4 + 4*5 = 40
        assert_eq!(Scalar::from(40u64), inner_product(&a, &b));
    }
}