use core::iter;
use curve25519_dalek::ristretto::RistrettoPoint;
use curve25519_dalek::scalar::Scalar;
use merlin::Transcript;
use errors::MPCError;
use generators::{BulletproofGens, PedersenGens};
use inner_product_proof;
use range_proof::RangeProof;
use transcript::TranscriptProtocol;
use util;
use super::messages::*;
/// Entry point of the dealer's state machine.
///
/// The type carries no data; a session is started with `Dealer::new`, which
/// returns the first state, `DealerAwaitingBitCommitments`.
pub struct Dealer {}
impl Dealer {
    /// Starts a dealer session and returns the state that waits for the
    /// parties' bit commitments.
    ///
    /// `n` is the bitsize and `m` is the number of parties taking part in the
    /// aggregation. The range-proof domain separator for `(n, m)` is appended
    /// to `transcript` before returning.
    ///
    /// # Errors
    ///
    /// * `MPCError::InvalidBitsize` if `n` is not one of 8, 16, 32 or 64.
    /// * `MPCError::InvalidAggregation` if `m` is not a power of two.
    /// * `MPCError::InvalidGeneratorsLength` if `bp_gens` has fewer than `n`
    ///   generators per party or room for fewer than `m` parties.
    pub fn new<'a, 'b>(
        bp_gens: &'b BulletproofGens,
        pc_gens: &'b PedersenGens,
        transcript: &'a mut Transcript,
        n: usize,
        m: usize,
    ) -> Result<DealerAwaitingBitCommitments<'a, 'b>, MPCError> {
        match n {
            8 | 16 | 32 | 64 => {}
            _ => return Err(MPCError::InvalidBitsize),
        }
        if !m.is_power_of_two() {
            return Err(MPCError::InvalidAggregation);
        }
        // Both capacity shortfalls are reported with the same error.
        let enough_generators = bp_gens.gens_capacity >= n && bp_gens.party_capacity >= m;
        if !enough_generators {
            return Err(MPCError::InvalidGeneratorsLength);
        }

        // Snapshot the transcript *before* the domain separator is appended.
        // The copy travels through every state and is the transcript handed
        // to `verify_multiple` in `receive_shares`, so the dealer can check
        // the assembled proof without replaying the challenges by hand.
        let initial_transcript = transcript.clone();
        transcript.rangeproof_domain_sep(n as u64, m as u64);

        let first_state = DealerAwaitingBitCommitments {
            bp_gens,
            pc_gens,
            transcript,
            initial_transcript,
            n,
            m,
        };
        Ok(first_state)
    }
}
/// Dealer state produced by `Dealer::new`: waiting for the parties'
/// `BitCommitment`s.
pub struct DealerAwaitingBitCommitments<'a, 'b> {
    /// Bulletproof generators supplied to `Dealer::new`.
    bp_gens: &'b BulletproofGens,
    /// Pedersen generators supplied to `Dealer::new`.
    pc_gens: &'b PedersenGens,
    /// Caller-owned transcript that the protocol messages are appended to.
    transcript: &'a mut Transcript,
    /// Clone of `transcript` taken in `Dealer::new` before the range-proof
    /// domain separator was appended; carried along unchanged by this state.
    initial_transcript: Transcript,
    /// Bitsize; validated in `Dealer::new` to be 8, 16, 32 or 64.
    n: usize,
    /// Number of parties (a power of two); also the number of bit commitments
    /// `receive_bit_commitments` expects.
    m: usize,
}
impl<'a, 'b> DealerAwaitingBitCommitments<'a, 'b> {
    
    pub fn receive_bit_commitments(
        self,
        bit_commitments: Vec<BitCommitment>,
    ) -> Result<(DealerAwaitingPolyCommitments<'a, 'b>, BitChallenge), MPCError> {
        if self.m != bit_commitments.len() {
            return Err(MPCError::WrongNumBitCommitments);
        }
        
        for vc in bit_commitments.iter() {
            self.transcript.append_point(b"V", &vc.V_j);
        }
        
        let A: RistrettoPoint = bit_commitments.iter().map(|vc| vc.A_j).sum();
        self.transcript.append_point(b"A", &A.compress());
        let S: RistrettoPoint = bit_commitments.iter().map(|vc| vc.S_j).sum();
        self.transcript.append_point(b"S", &S.compress());
        let y = self.transcript.challenge_scalar(b"y");
        let z = self.transcript.challenge_scalar(b"z");
        let bit_challenge = BitChallenge { y, z };
        Ok((
            DealerAwaitingPolyCommitments {
                n: self.n,
                m: self.m,
                transcript: self.transcript,
                initial_transcript: self.initial_transcript,
                bp_gens: self.bp_gens,
                pc_gens: self.pc_gens,
                bit_challenge,
                bit_commitments,
                A,
                S,
            },
            bit_challenge,
        ))
    }
}
/// Dealer state produced by `receive_bit_commitments`: waiting for the
/// parties' `PolyCommitment`s.
pub struct DealerAwaitingPolyCommitments<'a, 'b> {
    /// Bitsize, carried over from the previous state.
    n: usize,
    /// Number of parties; also the number of poly commitments expected.
    m: usize,
    /// Caller-owned transcript that the protocol messages are appended to.
    transcript: &'a mut Transcript,
    /// Transcript snapshot taken in `Dealer::new`, carried along unchanged.
    initial_transcript: Transcript,
    /// Bulletproof generators, carried over from the previous state.
    bp_gens: &'b BulletproofGens,
    /// Pedersen generators, carried over from the previous state.
    pc_gens: &'b PedersenGens,
    /// Challenge (`y`, `z`) drawn after the bit commitments were appended.
    bit_challenge: BitChallenge,
    /// The bit commitments received from the parties, in the order given.
    bit_commitments: Vec<BitCommitment>,
    /// Sum of the parties' `A_j` points.
    A: RistrettoPoint,
    /// Sum of the parties' `S_j` points.
    S: RistrettoPoint,
}
impl<'a, 'b> DealerAwaitingPolyCommitments<'a, 'b> {
    
    
    pub fn receive_poly_commitments(
        self,
        poly_commitments: Vec<PolyCommitment>,
    ) -> Result<(DealerAwaitingProofShares<'a, 'b>, PolyChallenge), MPCError> {
        if self.m != poly_commitments.len() {
            return Err(MPCError::WrongNumPolyCommitments);
        }
        
        let T_1: RistrettoPoint = poly_commitments.iter().map(|pc| pc.T_1_j).sum();
        let T_2: RistrettoPoint = poly_commitments.iter().map(|pc| pc.T_2_j).sum();
        self.transcript.append_point(b"T_1", &T_1.compress());
        self.transcript.append_point(b"T_2", &T_2.compress());
        let x = self.transcript.challenge_scalar(b"x");
        let poly_challenge = PolyChallenge { x };
        Ok((
            DealerAwaitingProofShares {
                n: self.n,
                m: self.m,
                transcript: self.transcript,
                initial_transcript: self.initial_transcript,
                bp_gens: self.bp_gens,
                pc_gens: self.pc_gens,
                bit_challenge: self.bit_challenge,
                bit_commitments: self.bit_commitments,
                A: self.A,
                S: self.S,
                poly_challenge,
                poly_commitments,
                T_1,
                T_2,
            },
            poly_challenge,
        ))
    }
}
/// Dealer state produced by `receive_poly_commitments`: waiting for the
/// parties' `ProofShare`s, from which the final `RangeProof` is assembled.
pub struct DealerAwaitingProofShares<'a, 'b> {
    /// Bitsize, carried over from the previous state.
    n: usize,
    /// Number of parties; also the number of proof shares expected.
    m: usize,
    /// Caller-owned transcript that the protocol messages are appended to.
    transcript: &'a mut Transcript,
    /// Transcript snapshot taken in `Dealer::new`; `receive_shares` passes it
    /// to `verify_multiple` when checking the assembled proof.
    initial_transcript: Transcript,
    /// Bulletproof generators, carried over from the previous state.
    bp_gens: &'b BulletproofGens,
    /// Pedersen generators, carried over from the previous state.
    pc_gens: &'b PedersenGens,
    /// Challenge (`y`, `z`) drawn after the bit commitments were appended.
    bit_challenge: BitChallenge,
    /// The bit commitments received from the parties, in the order given.
    bit_commitments: Vec<BitCommitment>,
    /// Challenge (`x`) drawn after the poly commitments were appended.
    poly_challenge: PolyChallenge,
    /// The poly commitments received from the parties, in the order given.
    poly_commitments: Vec<PolyCommitment>,
    /// Sum of the parties' `A_j` points.
    A: RistrettoPoint,
    /// Sum of the parties' `S_j` points.
    S: RistrettoPoint,
    /// Sum of the parties' `T_1_j` points.
    T_1: RistrettoPoint,
    /// Sum of the parties' `T_2_j` points.
    T_2: RistrettoPoint,
}
impl<'a, 'b> DealerAwaitingProofShares<'a, 'b> {
    /// Combines the parties' proof shares into a single `RangeProof` without
    /// checking that the resulting proof verifies.
    ///
    /// The scalars `t_x`, `t_x_blinding` and `e_blinding` are summed over all
    /// shares and appended to the transcript, the challenge `w` is drawn, and
    /// an inner-product proof is created over the concatenation of the
    /// shares' `l_vec` and `r_vec`.
    ///
    /// # Errors
    ///
    /// * `MPCError::WrongNumProofShares` if the number of shares is not `m`.
    /// * `MPCError::MalformedProofShares` if any share's `l_vec` or `r_vec`
    ///   does not have exactly `n` entries; `bad_shares` lists the indices of
    ///   those shares. Nothing is appended to the transcript in that case.
    fn assemble_shares(&mut self, proof_shares: &[ProofShare]) -> Result<RangeProof, MPCError> {
        if self.m != proof_shares.len() {
            return Err(MPCError::WrongNumProofShares);
        }

        // The generators and factor vectors below are sized to exactly
        // `n * m`, so each of the `m` shares has to contribute exactly `n`
        // entries to the concatenated `l_vec` / `r_vec`. Shares come from
        // other parties, so reject mis-sized ones here with a proper error.
        // NOTE(review): the inner-product prover presumably requires all of
        // its input vectors to have equal length -- confirm against its docs.
        let n = self.n;
        let bad_shares: Vec<usize> = proof_shares
            .iter()
            .enumerate()
            .filter(|&(_, share)| share.l_vec.len() != n || share.r_vec.len() != n)
            .map(|(j, _)| j)
            .collect();
        if !bad_shares.is_empty() {
            return Err(MPCError::MalformedProofShares { bad_shares });
        }

        let t_x: Scalar = proof_shares.iter().map(|ps| ps.t_x).sum();
        let t_x_blinding: Scalar = proof_shares.iter().map(|ps| ps.t_x_blinding).sum();
        let e_blinding: Scalar = proof_shares.iter().map(|ps| ps.e_blinding).sum();

        self.transcript.append_scalar(b"t_x", &t_x);
        self.transcript
            .append_scalar(b"t_x_blinding", &t_x_blinding);
        self.transcript.append_scalar(b"e_blinding", &e_blinding);

        // Challenge used to derive the extra base point for the inner-product
        // argument.
        let w = self.transcript.challenge_scalar(b"w");
        let Q = w * self.pc_gens.B;

        let total_len = self.n * self.m;
        let G_factors: Vec<Scalar> = iter::repeat(Scalar::one()).take(total_len).collect();
        let H_factors: Vec<Scalar> = util::exp_iter(self.bit_challenge.y.invert())
            .take(total_len)
            .collect();

        // Concatenate the per-party vectors in party order, copying scalars
        // straight out of the shares instead of cloning each vector first.
        let l_vec: Vec<Scalar> = proof_shares
            .iter()
            .flat_map(|ps| ps.l_vec.iter().cloned())
            .collect();
        let r_vec: Vec<Scalar> = proof_shares
            .iter()
            .flat_map(|ps| ps.r_vec.iter().cloned())
            .collect();

        let ipp_proof = inner_product_proof::InnerProductProof::create(
            self.transcript,
            &Q,
            &G_factors,
            &H_factors,
            self.bp_gens.G(self.n, self.m).cloned().collect(),
            self.bp_gens.H(self.n, self.m).cloned().collect(),
            l_vec,
            r_vec,
        );

        Ok(RangeProof {
            A: self.A.compress(),
            S: self.S.compress(),
            T_1: self.T_1.compress(),
            T_2: self.T_2.compress(),
            t_x,
            t_x_blinding,
            e_blinding,
            ipp_proof,
        })
    }

    /// Assembles the proof shares into a `RangeProof` and verifies it.
    ///
    /// The assembled proof is checked with `verify_multiple` against the
    /// transcript snapshot taken in `Dealer::new` and the parties' `V_j`
    /// commitments. If that check fails, every share is audited individually
    /// with `audit_share` to find out which ones are at fault.
    ///
    /// # Errors
    ///
    /// * Any error returned by the assembly step (wrong number of shares, or
    ///   shares with mis-sized vectors).
    /// * `MPCError::MalformedProofShares` if the assembled proof does not
    ///   verify; `bad_shares` lists the indices of the shares that failed
    ///   their individual audit.
    pub fn receive_shares(mut self, proof_shares: &[ProofShare]) -> Result<RangeProof, MPCError> {
        let proof = self.assemble_shares(proof_shares)?;

        let value_commitments: Vec<_> = self.bit_commitments.iter().map(|bc| bc.V_j).collect();

        // Verification replays the protocol from the untouched snapshot.
        let verified = proof
            .verify_multiple(
                self.bp_gens,
                self.pc_gens,
                &mut self.initial_transcript,
                &value_commitments,
                self.n,
            )
            .is_ok();
        if verified {
            return Ok(proof);
        }

        // The aggregated proof is invalid: audit share by share to assign
        // blame. All three collections hold exactly `m` entries (checked
        // when each of them was received), so zipping visits every party.
        let mut bad_shares = Vec::new();
        let per_party = proof_shares
            .iter()
            .zip(self.bit_commitments.iter())
            .zip(self.poly_commitments.iter());
        for (j, ((share, bit_commitment), poly_commitment)) in per_party.enumerate() {
            let audit = share.audit_share(
                &self.bp_gens,
                &self.pc_gens,
                j,
                bit_commitment,
                &self.bit_challenge,
                poly_commitment,
                &self.poly_challenge,
            );
            if audit.is_err() {
                bad_shares.push(j);
            }
        }
        Err(MPCError::MalformedProofShares { bad_shares })
    }

    /// Assembles the proof shares into a `RangeProof` without verifying it.
    ///
    /// Unlike `receive_shares`, the assembled proof is not checked and the
    /// shares are not audited, so a returned proof may fail verification if
    /// any share is invalid. Use this only when all parties are trusted.
    ///
    /// # Errors
    ///
    /// * `MPCError::WrongNumProofShares` if the number of shares is not `m`.
    /// * `MPCError::MalformedProofShares` if any share's `l_vec` or `r_vec`
    ///   does not have exactly `n` entries.
    pub fn receive_trusted_shares(
        mut self,
        proof_shares: &[ProofShare],
    ) -> Result<RangeProof, MPCError> {
        self.assemble_shares(proof_shares)
    }
}