#![allow(non_snake_case)]
use clear_on_drop::clear::Clear;
use core::mem;
use curve25519_dalek::ristretto::{CompressedRistretto, RistrettoPoint};
use curve25519_dalek::scalar::Scalar;
use curve25519_dalek::traits::{Identity, MultiscalarMul};
use merlin::Transcript;
use super::{
ConstraintSystem, LinearCombination, R1CSProof, RandomizableConstraintSystem,
RandomizedConstraintSystem, Variable,
};
use errors::R1CSError;
use generators::{BulletproofGens, PedersenGens};
use inner_product_proof::InnerProductProof;
use transcript::TranscriptProtocol;
/// Prover state for an R1CS proof.
///
/// Holds the constraints registered so far together with the secret
/// assignments for every multiplier and every committed variable. The proof
/// itself is produced by `prove`, which consumes the prover.
pub struct Prover<'t, 'g> {
/// Transcript that commitments are appended to and challenges are drawn from.
transcript: &'t mut Transcript,
/// Pedersen generators used for value commitments and blinding terms.
pc_gens: &'g PedersenGens,
/// Linear combinations registered through `constrain`, in insertion order.
constraints: Vec<LinearCombination>,
/// Assignments for `Variable::MultiplierLeft(i)`, indexed by `i`.
a_L: Vec<Scalar>,
/// Assignments for `Variable::MultiplierRight(i)`, indexed by `i`.
a_R: Vec<Scalar>,
/// Assignments for `Variable::MultiplierOutput(i)`, indexed by `i`.
a_O: Vec<Scalar>,
/// Values passed to `commit`; `Variable::Committed(i)` refers to `v[i]`.
v: Vec<Scalar>,
/// Blinding factors passed to `commit`, parallel to `v`.
v_blinding: Vec<Scalar>,
/// Callbacks registered by `specify_randomized_constraints`; they are run by
/// `create_randomized_constraints` during `prove`.
deferred_constraints: Vec<Box<Fn(&mut RandomizingProver<'t, 'g>) -> Result<(), R1CSError>>>,
/// Index of a multiplier whose left wire was filled by `allocate` while its
/// right wire is still unassigned; `None` when no such multiplier exists.
pending_multiplier: Option<usize>,
}
/// Wrapper around a `Prover` that is handed to the deferred callbacks.
///
/// It forwards the `ConstraintSystem` methods to the wrapped prover and
/// additionally implements `RandomizedConstraintSystem`, which exposes
/// challenge scalars drawn from the prover's transcript.
pub struct RandomizingProver<'t, 'g> {
/// The wrapped prover; moved in before the callbacks run and moved back out
/// afterwards by `create_randomized_constraints`.
prover: Prover<'t, 'g>,
}
impl<'t, 'g> Drop for Prover<'t, 'g> {
    /// Overwrites every secret scalar owned by the prover when it is dropped.
    ///
    /// This covers the committed values and their blinding factors (`v`,
    /// `v_blinding`) as well as the multiplier assignments (`a_L`, `a_R`,
    /// `a_O`).
    fn drop(&mut self) {
        // Calling `.clear()` directly on a `Vec` resolves to the inherent
        // `Vec::clear`, which only resets the length and leaves the scalars
        // untouched in the heap buffer. Each element is therefore wiped
        // individually through `Clear::clear`; the vector headers are left
        // as they are, since they hold no secret data.
        let secrets = self
            .v
            .iter_mut()
            .chain(self.v_blinding.iter_mut())
            .chain(self.a_L.iter_mut())
            .chain(self.a_R.iter_mut())
            .chain(self.a_O.iter_mut());
        for scalar in secrets {
            scalar.clear();
        }
    }
}
impl<'t, 'g> ConstraintSystem for Prover<'t, 'g> {
    /// Gives mutable access to the transcript this prover writes to.
    fn transcript(&mut self) -> &mut Transcript {
        &mut *self.transcript
    }

    /// Allocates a fresh multiplier whose inputs are constrained to equal
    /// `left` and `right`, and returns its `(left, right, output)` variables.
    ///
    /// Both linear combinations are evaluated against the current witness to
    /// obtain the multiplier's assignments.
    fn multiply(
        &mut self,
        mut left: LinearCombination,
        mut right: LinearCombination,
    ) -> (Variable, Variable, Variable) {
        // Evaluate against the witness as it stands before the new multiplier
        // is appended.
        let left_value = self.eval(&left);
        let right_value = self.eval(&right);

        let wires = (
            Variable::MultiplierLeft(self.a_L.len()),
            Variable::MultiplierRight(self.a_R.len()),
            Variable::MultiplierOutput(self.a_O.len()),
        );
        self.a_L.push(left_value);
        self.a_R.push(right_value);
        self.a_O.push(left_value * right_value);

        // Register `left - l_var = 0` and `right - r_var = 0`.
        let minus_one = -Scalar::one();
        left.terms.push((wires.0, minus_one));
        right.terms.push((wires.1, minus_one));
        self.constraints.push(left);
        self.constraints.push(right);

        wires
    }

    /// Allocates a single low-level variable holding `assignment`.
    ///
    /// Variables are packed two per multiplier: the first call opens a new
    /// multiplier and returns its left wire, the next call fills the right
    /// wire of that same multiplier and updates its output.
    ///
    /// # Errors
    ///
    /// Returns `R1CSError::MissingAssignment` when `assignment` is `None`.
    fn allocate(&mut self, assignment: Option<Scalar>) -> Result<Variable, R1CSError> {
        let value = assignment.ok_or(R1CSError::MissingAssignment)?;

        // A half-filled multiplier exists: complete it.
        if let Some(slot) = self.pending_multiplier.take() {
            self.a_R[slot] = value;
            self.a_O[slot] = self.a_L[slot] * value;
            return Ok(Variable::MultiplierRight(slot));
        }

        // Otherwise open a new multiplier with zero placeholders on the right
        // and output wires.
        let slot = self.a_L.len();
        self.pending_multiplier = Some(slot);
        self.a_L.push(value);
        self.a_R.push(Scalar::zero());
        self.a_O.push(Scalar::zero());
        Ok(Variable::MultiplierLeft(slot))
    }

    /// Allocates a complete multiplier from the given input assignments and
    /// returns its `(left, right, output)` variables. No constraint is added.
    ///
    /// # Errors
    ///
    /// Returns `R1CSError::MissingAssignment` when `input_assignments` is
    /// `None`.
    fn allocate_multiplier(
        &mut self,
        input_assignments: Option<(Scalar, Scalar)>,
    ) -> Result<(Variable, Variable, Variable), R1CSError> {
        let (left_value, right_value) = match input_assignments {
            Some(pair) => pair,
            None => return Err(R1CSError::MissingAssignment),
        };

        let wires = (
            Variable::MultiplierLeft(self.a_L.len()),
            Variable::MultiplierRight(self.a_R.len()),
            Variable::MultiplierOutput(self.a_O.len()),
        );
        self.a_L.push(left_value);
        self.a_R.push(right_value);
        self.a_O.push(left_value * right_value);

        Ok(wires)
    }

    /// Number of multipliers allocated so far.
    fn multipliers_len(&self) -> usize {
        self.a_L.len()
    }

    /// Records `lc` as a constraint.
    fn constrain(&mut self, lc: LinearCombination) {
        self.constraints.push(lc);
    }
}
impl<'t, 'g> RandomizableConstraintSystem for Prover<'t, 'g> {
    type RandomizedCS = RandomizingProver<'t, 'g>;

    /// Stores `callback` so it can be run later, during `prove`, against a
    /// `RandomizingProver`. The callback is not invoked here and this method
    /// always returns `Ok(())`.
    fn specify_randomized_constraints<F>(&mut self, callback: F) -> Result<(), R1CSError>
    where
        F: 'static + Fn(&mut Self::RandomizedCS) -> Result<(), R1CSError>,
    {
        let deferred = Box::new(callback);
        self.deferred_constraints.push(deferred);
        Ok(())
    }
}
impl<'t, 'g> ConstraintSystem for RandomizingProver<'t, 'g> {
    /// Transcript of the wrapped prover.
    fn transcript(&mut self) -> &mut Transcript {
        &mut *self.prover.transcript
    }

    /// Forwards to `Prover::multiply` on the wrapped prover.
    fn multiply(
        &mut self,
        left: LinearCombination,
        right: LinearCombination,
    ) -> (Variable, Variable, Variable) {
        let inner = &mut self.prover;
        inner.multiply(left, right)
    }

    /// Forwards to `Prover::allocate` on the wrapped prover.
    fn allocate(&mut self, assignment: Option<Scalar>) -> Result<Variable, R1CSError> {
        let inner = &mut self.prover;
        inner.allocate(assignment)
    }

    /// Forwards to `Prover::allocate_multiplier` on the wrapped prover.
    fn allocate_multiplier(
        &mut self,
        input_assignments: Option<(Scalar, Scalar)>,
    ) -> Result<(Variable, Variable, Variable), R1CSError> {
        let inner = &mut self.prover;
        inner.allocate_multiplier(input_assignments)
    }

    /// Number of multipliers held by the wrapped prover.
    fn multipliers_len(&self) -> usize {
        let inner = &self.prover;
        inner.multipliers_len()
    }

    /// Forwards to `Prover::constrain` on the wrapped prover.
    fn constrain(&mut self, lc: LinearCombination) {
        let inner = &mut self.prover;
        inner.constrain(lc)
    }
}
impl<'t, 'g> RandomizedConstraintSystem for RandomizingProver<'t, 'g> {
    /// Draws a challenge scalar for `label` from the wrapped prover's
    /// transcript.
    fn challenge_scalar(&mut self, label: &'static [u8]) -> Scalar {
        let transcript = &mut *self.prover.transcript;
        transcript.challenge_scalar(label)
    }
}
impl<'t, 'g> Prover<'t, 'g> {
/// Creates an empty prover bound to `transcript` and `pc_gens`.
///
/// The R1CS domain separator is written to the transcript before the prover
/// is constructed; all witness and constraint storage starts out empty.
pub fn new(pc_gens: &'g PedersenGens, transcript: &'t mut Transcript) -> Self {
    transcript.r1cs_domain_sep();

    Prover {
        transcript: transcript,
        pc_gens: pc_gens,
        constraints: Vec::new(),
        deferred_constraints: Vec::new(),
        pending_multiplier: None,
        a_L: Vec::new(),
        a_R: Vec::new(),
        a_O: Vec::new(),
        v: Vec::new(),
        v_blinding: Vec::new(),
    }
}
/// Commits to `v` with blinding factor `v_blinding`.
///
/// The value and its blinding factor are stored in the prover, and the
/// compressed commitment is appended to the transcript under the label `"V"`.
/// Returns that commitment together with the `Variable::Committed` handle
/// referring to the stored value.
pub fn commit(&mut self, v: Scalar, v_blinding: Scalar) -> (CompressedRistretto, Variable) {
    // The handle indexes the slot the value is about to occupy.
    let handle = Variable::Committed(self.v.len());
    let commitment = self.pc_gens.commit(v, v_blinding).compress();

    self.v.push(v);
    self.v_blinding.push(v_blinding);
    self.transcript.append_point(b"V", &commitment);

    (commitment, handle)
}
/// Collapses all registered constraints into four weight vectors
/// `(wL, wR, wO, wV)` using powers of `z`.
///
/// The k-th constraint (counting from zero) is weighted by `z^(k+1)`. For each
/// of its terms the weighted coefficient is added to the entry of `wL`, `wR`
/// or `wO` selected by the multiplier index, or subtracted from the entry of
/// `wV` selected by the committed-variable index. Constant terms are skipped.
fn flattened_constraints(
    &mut self,
    z: &Scalar,
) -> (Vec<Scalar>, Vec<Scalar>, Vec<Scalar>, Vec<Scalar>) {
    let multiplier_count = self.a_L.len();
    let committed_count = self.v.len();

    let mut wL = vec![Scalar::zero(); multiplier_count];
    let mut wR = vec![Scalar::zero(); multiplier_count];
    let mut wO = vec![Scalar::zero(); multiplier_count];
    let mut wV = vec![Scalar::zero(); committed_count];

    // Starts at z^1 and advances by one power per constraint.
    let mut z_power = *z;
    for constraint in &self.constraints {
        for &(var, coeff) in constraint.terms.iter() {
            let weighted = z_power * coeff;
            match var {
                Variable::MultiplierLeft(i) => wL[i] += weighted,
                Variable::MultiplierRight(i) => wR[i] += weighted,
                Variable::MultiplierOutput(i) => wO[i] += weighted,
                Variable::Committed(i) => wV[i] -= weighted,
                // Constant terms contribute to none of the weight vectors.
                Variable::One() => {}
            }
        }
        z_power *= z;
    }

    (wL, wR, wO, wV)
}
/// Evaluates `lc` against the prover's current assignments.
///
/// Each term contributes its coefficient times the value assigned to its
/// variable; `Variable::One()` stands for the constant one.
fn eval(&self, lc: &LinearCombination) -> Scalar {
    let mut total = Scalar::zero();
    for &(var, coeff) in lc.terms.iter() {
        let assigned = match var {
            Variable::MultiplierLeft(i) => self.a_L[i],
            Variable::MultiplierRight(i) => self.a_R[i],
            Variable::MultiplierOutput(i) => self.a_O[i],
            Variable::Committed(i) => self.v[i],
            Variable::One() => Scalar::one(),
        };
        total += coeff * assigned;
    }
    total
}
/// Runs every deferred callback and returns the prover afterwards.
///
/// Any pending half-filled multiplier is forgotten first. With no callbacks
/// registered, only the one-phase domain separator is written to the
/// transcript. Otherwise the two-phase separator is written and each callback
/// is invoked, in registration order, on a `RandomizingProver` wrapping
/// `self`.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by a callback.
fn create_randomized_constraints(mut self) -> Result<Self, R1CSError> {
    self.pending_multiplier = None;

    if self.deferred_constraints.is_empty() {
        self.transcript.r1cs_1phase_domain_sep();
        return Ok(self);
    }

    self.transcript.r1cs_2phase_domain_sep();

    // Take the callbacks out of `self` so that `self` can be moved into the
    // wrapper they operate on, then move it back out once they have all run.
    let callbacks = mem::replace(&mut self.deferred_constraints, Vec::new());
    let mut randomizer = RandomizingProver { prover: self };
    for callback in callbacks {
        callback(&mut randomizer)?;
    }
    Ok(randomizer.prover)
}
/// Consumes the prover and produces an `R1CSProof` for the constraints
/// registered so far.
///
/// The multipliers allocated before this call are committed first (`A_I1`,
/// `A_O1`, `S1`). The deferred callbacks are then run through
/// `create_randomized_constraints`, and any multipliers they add are
/// committed separately (`A_I2`, `A_O2`, `S2`).
///
/// # Errors
///
/// Returns `R1CSError::InvalidGeneratorsLength` when `bp_gens.gens_capacity`
/// is smaller than the number of first-phase multipliers, or smaller than the
/// total multiplier count rounded up to the next power of two. Errors from
/// `create_randomized_constraints` are propagated unchanged.
pub fn prove(mut self, bp_gens: &BulletproofGens) -> Result<R1CSProof, R1CSError> {
use std::iter;
use util;
// Bind the number of committed variables into the transcript.
self.transcript.append_u64(b"m", self.v.len() as u64);
// Build the RNG from the transcript, rekeyed with every blinding factor in
// `v_blinding` and finalized with `thread_rng()`.
let mut rng = {
let mut builder = self.transcript.build_rng();
for v_b in &self.v_blinding {
builder = builder.rekey_with_witness_bytes(b"v_blinding", v_b.as_bytes());
}
use rand::thread_rng;
builder.finalize(&mut thread_rng())
};
// n1: number of multipliers allocated before the deferred callbacks run.
let n1 = self.a_L.len();
if bp_gens.gens_capacity < n1 {
return Err(R1CSError::InvalidGeneratorsLength);
}
// All generators below are taken from share index 0 of `bp_gens`.
let gens = bp_gens.share(0);
// Blinding factors for the three first-phase commitments.
let i_blinding1 = Scalar::random(&mut rng);
let o_blinding1 = Scalar::random(&mut rng);
let s_blinding1 = Scalar::random(&mut rng);
// Random vectors of length n1, committed in S1 and zeroed before returning Ok.
let mut s_L1: Vec<Scalar> = (0..n1).map(|_| Scalar::random(&mut rng)).collect();
let mut s_R1: Vec<Scalar> = (0..n1).map(|_| Scalar::random(&mut rng)).collect();
// A_I1 = i_blinding1 * B_blinding + <a_L, G> + <a_R, H>
let A_I1 = RistrettoPoint::multiscalar_mul(
iter::once(&i_blinding1)
.chain(self.a_L.iter())
.chain(self.a_R.iter()),
iter::once(&self.pc_gens.B_blinding)
.chain(gens.G(n1))
.chain(gens.H(n1)),
)
.compress();
// A_O1 = o_blinding1 * B_blinding + <a_O, G>
let A_O1 = RistrettoPoint::multiscalar_mul(
iter::once(&o_blinding1).chain(self.a_O.iter()),
iter::once(&self.pc_gens.B_blinding).chain(gens.G(n1)),
)
.compress();
// S1 = s_blinding1 * B_blinding + <s_L1, G> + <s_R1, H>
let S1 = RistrettoPoint::multiscalar_mul(
iter::once(&s_blinding1)
.chain(s_L1.iter())
.chain(s_R1.iter()),
iter::once(&self.pc_gens.B_blinding)
.chain(gens.G(n1))
.chain(gens.H(n1)),
)
.compress();
self.transcript.append_point(b"A_I1", &A_I1);
self.transcript.append_point(b"A_O1", &A_O1);
self.transcript.append_point(b"S1", &S1);
// Run the deferred callbacks; they may allocate further multipliers and
// add further constraints.
self = self.create_randomized_constraints()?;
// n: total multipliers, n2: those added by the callbacks,
// padded_n: n rounded up to a power of two, pad: number of padding slots.
let n = self.a_L.len();
let n2 = n - n1;
let padded_n = self.a_L.len().next_power_of_two();
let pad = padded_n - n;
if bp_gens.gens_capacity < padded_n {
return Err(R1CSError::InvalidGeneratorsLength);
}
// Second-phase blinding factors are random only when the callbacks added
// multipliers; otherwise they are zero.
let has_2nd_phase_commitments = n2 > 0;
let (i_blinding2, o_blinding2, s_blinding2) = if has_2nd_phase_commitments {
(
Scalar::random(&mut rng),
Scalar::random(&mut rng),
Scalar::random(&mut rng),
)
} else {
(Scalar::zero(), Scalar::zero(), Scalar::zero())
};
// Random vectors of length n2 (empty when there is no second phase).
let mut s_L2: Vec<Scalar> = (0..n2).map(|_| Scalar::random(&mut rng)).collect();
let mut s_R2: Vec<Scalar> = (0..n2).map(|_| Scalar::random(&mut rng)).collect();
// Second-phase commitments have the same shape as the first-phase ones but
// cover only positions n1..n of the witness and generators. Without a
// second phase they are the identity point.
let (A_I2, A_O2, S2) = if has_2nd_phase_commitments {
(
RistrettoPoint::multiscalar_mul(
iter::once(&i_blinding2)
.chain(self.a_L.iter().skip(n1))
.chain(self.a_R.iter().skip(n1)),
iter::once(&self.pc_gens.B_blinding)
.chain(gens.G(n).skip(n1))
.chain(gens.H(n).skip(n1)),
)
.compress(),
RistrettoPoint::multiscalar_mul(
iter::once(&o_blinding2).chain(self.a_O.iter().skip(n1)),
iter::once(&self.pc_gens.B_blinding).chain(gens.G(n).skip(n1)),
)
.compress(),
RistrettoPoint::multiscalar_mul(
iter::once(&s_blinding2)
.chain(s_L2.iter())
.chain(s_R2.iter()),
iter::once(&self.pc_gens.B_blinding)
.chain(gens.G(n).skip(n1))
.chain(gens.H(n).skip(n1)),
)
.compress(),
)
} else {
(
CompressedRistretto::identity(),
CompressedRistretto::identity(),
CompressedRistretto::identity(),
)
};
self.transcript.append_point(b"A_I2", &A_I2);
self.transcript.append_point(b"A_O2", &A_O2);
self.transcript.append_point(b"S2", &S2);
// Challenges used to build the vector polynomials.
let y = self.transcript.challenge_scalar(b"y");
let z = self.transcript.challenge_scalar(b"z");
// Fold every constraint into weight vectors using powers of z.
let (wL, wR, wO, wV) = self.flattened_constraints(&z);
let mut l_poly = util::VecPoly3::zero(n);
let mut r_poly = util::VecPoly3::zero(n);
// exp_y holds y^i at the top of iteration i.
let mut exp_y = Scalar::one();
let y_inv = y.invert();
// NOTE(review): exp_y_inv[i] is expected to be y_inv^i; `util::exp_iter`
// is defined elsewhere, confirm there.
let exp_y_inv = util::exp_iter(y_inv).take(padded_n).collect::<Vec<_>>();
// Pair up the first- and second-phase random vectors, position by position.
let sLsR = s_L1
.iter()
.chain(s_L2.iter())
.zip(s_R1.iter().chain(s_R2.iter()));
for (i, (sl, sr)) in sLsR.enumerate() {
// l_poly.1 = a_L + exp_y_inv * wR
l_poly.1[i] = self.a_L[i] + exp_y_inv[i] * wR[i];
// l_poly.2 = a_O
l_poly.2[i] = self.a_O[i];
// l_poly.3 = s_L
l_poly.3[i] = *sl;
// r_poly.0 = wO - y^i
r_poly.0[i] = wO[i] - exp_y;
// r_poly.1 = y^i * a_R + wL
r_poly.1[i] = exp_y * self.a_R[i] + wL[i];
// r_poly.3 = y^i * s_R
r_poly.3[i] = exp_y * sr;
// y^i -> y^(i+1)
exp_y = exp_y * y;
}
let t_poly = util::VecPoly3::special_inner_product(&l_poly, &r_poly);
// Blinding factors for the committed coefficients of t_poly. There is no
// random blinding for t2; it is derived from `wV` and `v_blinding` below.
let t_1_blinding = Scalar::random(&mut rng);
let t_3_blinding = Scalar::random(&mut rng);
let t_4_blinding = Scalar::random(&mut rng);
let t_5_blinding = Scalar::random(&mut rng);
let t_6_blinding = Scalar::random(&mut rng);
let T_1 = self.pc_gens.commit(t_poly.t1, t_1_blinding).compress();
let T_3 = self.pc_gens.commit(t_poly.t3, t_3_blinding).compress();
let T_4 = self.pc_gens.commit(t_poly.t4, t_4_blinding).compress();
let T_5 = self.pc_gens.commit(t_poly.t5, t_5_blinding).compress();
let T_6 = self.pc_gens.commit(t_poly.t6, t_6_blinding).compress();
self.transcript.append_point(b"T_1", &T_1);
self.transcript.append_point(b"T_3", &T_3);
self.transcript.append_point(b"T_4", &T_4);
self.transcript.append_point(b"T_5", &T_5);
self.transcript.append_point(b"T_6", &T_6);
// u combines the two commitment phases; x is the evaluation point.
let u = self.transcript.challenge_scalar(b"u");
let x = self.transcript.challenge_scalar(b"x");
// t_2_blinding = sum over j of wV[j] * v_blinding[j]
let t_2_blinding = wV
.iter()
.zip(self.v_blinding.iter())
.map(|(c, v_blinding)| c * v_blinding)
.sum();
let t_blinding_poly = util::Poly6 {
t1: t_1_blinding,
t2: t_2_blinding,
t3: t_3_blinding,
t4: t_4_blinding,
t5: t_5_blinding,
t6: t_6_blinding,
};
// Evaluate the polynomials at x, then extend both vectors with `pad`
// zeros so their length is padded_n.
let t_x = t_poly.eval(x);
let t_x_blinding = t_blinding_poly.eval(x);
let mut l_vec = l_poly.eval(x);
l_vec.append(&mut vec![Scalar::zero(); pad]);
let mut r_vec = r_poly.eval(x);
r_vec.append(&mut vec![Scalar::zero(); pad]);
// In the padding positions r_vec[i] = -y^i; exp_y continues from y^n,
// where the loop above left it.
for i in n..padded_n {
r_vec[i] = -exp_y;
exp_y = exp_y * y;
}
// Combine the per-phase blinding factors using u.
let i_blinding = i_blinding1 + u * i_blinding2;
let o_blinding = o_blinding1 + u * o_blinding2;
let s_blinding = s_blinding1 + u * s_blinding2;
// e_blinding = x * i_blinding + x^2 * o_blinding + x^3 * s_blinding
let e_blinding = x * (i_blinding + x * (o_blinding + x * s_blinding));
self.transcript.append_scalar(b"t_x", &t_x);
self.transcript
.append_scalar(b"t_x_blinding", &t_x_blinding);
self.transcript.append_scalar(b"e_blinding", &e_blinding);
// Challenge w scales the base point B to obtain Q for the inner-product proof.
let w = self.transcript.challenge_scalar(b"w");
let Q = w * self.pc_gens.B;
// G_factors: 1 for the n1 first-phase positions, u for the remaining
// n2 + pad positions.
let G_factors = iter::repeat(Scalar::one())
.take(n1)
.chain(iter::repeat(u).take(n2 + pad))
.collect::<Vec<_>>();
// H_factors[i] = exp_y_inv[i] * G_factors[i]
let H_factors = exp_y_inv
.into_iter()
.zip(G_factors.iter())
.map(|(y, u_or_1)| y * u_or_1)
.collect::<Vec<_>>();
let ipp_proof = InnerProductProof::create(
self.transcript,
&Q,
&G_factors,
&H_factors,
gens.G(padded_n).cloned().collect(),
gens.H(padded_n).cloned().collect(),
l_vec,
r_vec,
);
// Zero the random vectors element by element. The early returns above do
// not reach this point.
for scalar in s_L1
.iter_mut()
.chain(s_L2.iter_mut())
.chain(s_R1.iter_mut())
.chain(s_R2.iter_mut())
{
scalar.clear();
}
Ok(R1CSProof {
A_I1,
A_O1,
S1,
A_I2,
A_O2,
S2,
T_1,
T_3,
T_4,
T_5,
T_6,
t_x,
t_x_blinding,
e_blinding,
ipp_proof,
})
}
}