1use ff::{Field, PrimeField};
22use group::{Group, GroupEncoding};
23use sha3::Digest;
24use sha3::Sha3_256;
25
26use crate::{
27 errors::Error,
28 linear_relation::LinearRelation,
29 schnorr_protocol::SchnorrProof,
30 serialization::{deserialize_scalars, serialize_scalars},
31 traits::{SigmaProtocol, SigmaProtocolSimulator},
32};
33
34#[derive(Clone)]
41pub enum Protocol<G: Group + GroupEncoding> {
42 Simple(SchnorrProof<G>),
43 And(Vec<Protocol<G>>),
44 Or(Vec<Protocol<G>>),
45}
46
47impl<G> From<SchnorrProof<G>> for Protocol<G>
48where
49 G: Group + GroupEncoding,
50{
51 fn from(value: SchnorrProof<G>) -> Self {
52 Protocol::Simple(value)
53 }
54}
55
56impl<G> From<LinearRelation<G>> for Protocol<G>
57where
58 G: Group + GroupEncoding,
59{
60 fn from(value: LinearRelation<G>) -> Self {
61 Self::from(SchnorrProof::from(value))
62 }
63}
64
65#[derive(Clone)]
67pub enum ProtocolCommitment<G: Group + GroupEncoding> {
68 Simple(<SchnorrProof<G> as SigmaProtocol>::Commitment),
69 And(Vec<ProtocolCommitment<G>>),
70 Or(Vec<ProtocolCommitment<G>>),
71}
72
73#[derive(Clone)]
75pub enum ProtocolProverState<G: Group + GroupEncoding> {
76 Simple(<SchnorrProof<G> as SigmaProtocol>::ProverState),
77 And(Vec<ProtocolProverState<G>>),
78 Or(
79 usize, Vec<ProtocolProverState<G>>, (Vec<ProtocolChallenge<G>>, Vec<ProtocolResponse<G>>), ),
83}
84
85#[derive(Clone)]
87pub enum ProtocolResponse<G: Group + GroupEncoding> {
88 Simple(<SchnorrProof<G> as SigmaProtocol>::Response),
89 And(Vec<ProtocolResponse<G>>),
90 Or(Vec<ProtocolChallenge<G>>, Vec<ProtocolResponse<G>>),
91}
92
93pub enum ProtocolWitness<G: Group + GroupEncoding> {
95 Simple(<SchnorrProof<G> as SigmaProtocol>::Witness),
96 And(Vec<ProtocolWitness<G>>),
97 Or(usize, Vec<ProtocolWitness<G>>),
98}
99
/// Challenge type used at every node of a composed protocol: the challenge
/// type of the underlying [`SchnorrProof`].
type ProtocolChallenge<G> = <SchnorrProof<G> as SigmaProtocol>::Challenge;
102
103impl<G: Group + GroupEncoding> SigmaProtocol for Protocol<G> {
104 type Commitment = ProtocolCommitment<G>;
105 type ProverState = ProtocolProverState<G>;
106 type Response = ProtocolResponse<G>;
107 type Witness = ProtocolWitness<G>;
108 type Challenge = ProtocolChallenge<G>;
109
110 fn prover_commit(
111 &self,
112 witness: &Self::Witness,
113 rng: &mut (impl rand::Rng + rand::CryptoRng),
114 ) -> Result<(Self::Commitment, Self::ProverState), Error> {
115 match (self, witness) {
116 (Protocol::Simple(p), ProtocolWitness::Simple(w)) => {
117 p.prover_commit(w, rng).map(|(c, s)| {
118 (
119 ProtocolCommitment::Simple(c),
120 ProtocolProverState::Simple(s),
121 )
122 })
123 }
124 (Protocol::And(ps), ProtocolWitness::And(ws)) => {
125 if ps.len() != ws.len() {
126 return Err(Error::InvalidInstanceWitnessPair);
127 }
128 let mut commitments = Vec::with_capacity(ps.len());
129 let mut prover_states = Vec::with_capacity(ps.len());
130
131 for (p, w) in ps.iter().zip(ws.iter()) {
132 let (c, s) = p.prover_commit(w, rng)?;
133 commitments.push(c);
134 prover_states.push(s);
135 }
136
137 Ok((
138 ProtocolCommitment::And(commitments),
139 ProtocolProverState::And(prover_states),
140 ))
141 }
142 (Protocol::Or(ps), ProtocolWitness::Or(w_index, w)) => {
143 let mut commitments = Vec::new();
144 let mut simulated_challenges = Vec::new();
145 let mut simulated_responses = Vec::new();
146
147 let (real_commitment, real_state) = ps[*w_index].prover_commit(&w[0], rng)?;
148
149 for i in (0..ps.len()).filter(|i| i != w_index) {
150 let (commitment, challenge, response) = ps[i].simulate_transcript(rng)?;
151 commitments.push(commitment);
152 simulated_challenges.push(challenge);
153 simulated_responses.push(response);
154 }
155 commitments.insert(*w_index, real_commitment);
156
157 Ok((
158 ProtocolCommitment::Or(commitments),
159 ProtocolProverState::Or(
160 *w_index,
161 vec![real_state],
162 (simulated_challenges, simulated_responses),
163 ),
164 ))
165 }
166 _ => unreachable!(),
167 }
168 }
169
170 fn prover_response(
171 &self,
172 state: Self::ProverState,
173 challenge: &Self::Challenge,
174 ) -> Result<Self::Response, Error> {
175 match (self, state) {
176 (Protocol::Simple(p), ProtocolProverState::Simple(state)) => p
177 .prover_response(state, challenge)
178 .map(ProtocolResponse::Simple),
179 (Protocol::And(ps), ProtocolProverState::And(states)) => {
180 if ps.len() != states.len() {
181 return Err(Error::InvalidInstanceWitnessPair);
182 }
183 let responses: Result<Vec<_>, _> = ps
184 .iter()
185 .zip(states)
186 .map(|(p, s)| p.prover_response(s, challenge))
187 .collect();
188
189 Ok(ProtocolResponse::And(responses?))
190 }
191 (
192 Protocol::Or(ps),
193 ProtocolProverState::Or(
194 w_index,
195 real_state,
196 (simulated_challenges, simulated_responses),
197 ),
198 ) => {
199 let mut challenges = Vec::with_capacity(ps.len());
200 let mut responses = Vec::with_capacity(ps.len());
201
202 let mut real_challenge = *challenge;
203 for ch in &simulated_challenges {
204 real_challenge -= ch;
205 }
206 let real_response =
207 ps[w_index].prover_response(real_state[0].clone(), &real_challenge)?;
208
209 for (i, _) in ps.iter().enumerate() {
210 if i == w_index {
211 challenges.push(real_challenge);
212 responses.push(real_response.clone());
213 } else {
214 let simulated_index = if i < w_index { i } else { i - 1 };
215 challenges.push(simulated_challenges[simulated_index]);
216 responses.push(simulated_responses[simulated_index].clone());
217 }
218 }
219 Ok(ProtocolResponse::Or(challenges, responses))
220 }
221 _ => panic!(),
222 }
223 }
224
225 fn verifier(
226 &self,
227 commitment: &Self::Commitment,
228 challenge: &Self::Challenge,
229 response: &Self::Response,
230 ) -> Result<(), Error> {
231 match (self, commitment, response) {
232 (Protocol::Simple(p), ProtocolCommitment::Simple(c), ProtocolResponse::Simple(r)) => {
233 p.verifier(c, challenge, r)
234 }
235 (
236 Protocol::And(ps),
237 ProtocolCommitment::And(commitments),
238 ProtocolResponse::And(responses),
239 ) => ps
240 .iter()
241 .zip(commitments)
242 .zip(responses)
243 .try_for_each(|((p, c), r)| p.verifier(c, challenge, r)),
244 (
245 Protocol::Or(ps),
246 ProtocolCommitment::Or(commitments),
247 ProtocolResponse::Or(challenges, responses),
248 ) => {
249 let mut expected_difference = *challenge;
250 for (i, p) in ps.iter().enumerate() {
251 p.verifier(&commitments[i], &challenges[i], &responses[i])?;
252 expected_difference -= challenges[i];
253 }
254 match expected_difference.is_zero_vartime() {
255 true => Ok(()),
256 false => Err(Error::VerificationFailure),
257 }
258 }
259 _ => panic!(),
260 }
261 }
262
263 fn serialize_commitment(&self, commitment: &Self::Commitment) -> Vec<u8> {
264 match (self, commitment) {
265 (Protocol::Simple(p), ProtocolCommitment::Simple(c)) => p.serialize_commitment(c),
266 (Protocol::And(ps), ProtocolCommitment::And(commitments))
267 | (Protocol::Or(ps), ProtocolCommitment::Or(commitments)) => ps
268 .iter()
269 .zip(commitments)
270 .flat_map(|(p, c)| p.serialize_commitment(c))
271 .collect(),
272 _ => panic!(),
273 }
274 }
275
276 fn serialize_challenge(&self, challenge: &Self::Challenge) -> Vec<u8> {
277 serialize_scalars::<G>(&[*challenge])
278 }
279
280 fn instance_label(&self) -> impl AsRef<[u8]> {
281 match self {
282 Protocol::Simple(p) => {
283 let label = p.instance_label();
284 label.as_ref().to_vec()
285 }
286 Protocol::And(ps) => {
287 let mut bytes = Vec::new();
288 for p in ps {
289 bytes.extend(p.instance_label().as_ref());
290 }
291 bytes
292 }
293 Protocol::Or(ps) => {
294 let mut bytes = Vec::new();
295 for p in ps {
296 bytes.extend(p.instance_label().as_ref());
297 }
298 bytes
299 }
300 }
301 }
302
303 fn protocol_identifier(&self) -> impl AsRef<[u8]> {
304 let mut hasher = Sha3_256::new();
305
306 match self {
307 Protocol::Simple(p) => {
308 hasher.update([0u8; 32]);
310 hasher.update(p.protocol_identifier());
311 }
312 Protocol::And(protocols) => {
313 let mut hasher = Sha3_256::new();
314 hasher.update([1u8; 32]);
315 for p in protocols {
316 hasher.update(p.protocol_identifier());
317 }
318 }
319 Protocol::Or(protocols) => {
320 let mut hasher = Sha3_256::new();
321 hasher.update([2u8; 32]);
322 for p in protocols {
323 hasher.update(p.protocol_identifier());
324 }
325 }
326 }
327
328 hasher.finalize()
329 }
330
331 fn serialize_response(&self, response: &Self::Response) -> Vec<u8> {
332 match (self, response) {
333 (Protocol::Simple(p), ProtocolResponse::Simple(r)) => p.serialize_response(r),
334 (Protocol::And(ps), ProtocolResponse::And(responses)) => {
335 let mut bytes = Vec::new();
336 for (i, p) in ps.iter().enumerate() {
337 bytes.extend(p.serialize_response(&responses[i]));
338 }
339 bytes
340 }
341 (Protocol::Or(ps), ProtocolResponse::Or(challenges, responses)) => {
342 let mut bytes = Vec::new();
343 for (i, p) in ps.iter().enumerate() {
344 bytes.extend(&serialize_scalars::<G>(&[challenges[i]]));
345 bytes.extend(p.serialize_response(&responses[i]));
346 }
347 bytes
348 }
349 _ => panic!(),
350 }
351 }
352
353 fn deserialize_commitment(&self, data: &[u8]) -> Result<Self::Commitment, Error> {
354 match self {
355 Protocol::Simple(p) => {
356 let c = p.deserialize_commitment(data)?;
357 Ok(ProtocolCommitment::Simple(c))
358 }
359 Protocol::And(ps) | Protocol::Or(ps) => {
360 let mut cursor = 0;
361 let mut commitments = Vec::with_capacity(ps.len());
362
363 for p in ps {
364 let c = p.deserialize_commitment(&data[cursor..])?;
365 let size = p.serialize_commitment(&c).len();
366 cursor += size;
367 commitments.push(c);
368 }
369
370 Ok(match self {
371 Protocol::And(_) => ProtocolCommitment::And(commitments),
372 Protocol::Or(_) => ProtocolCommitment::Or(commitments),
373 _ => unreachable!(),
374 })
375 }
376 }
377 }
378
379 fn deserialize_challenge(&self, data: &[u8]) -> Result<Self::Challenge, Error> {
380 let scalars = deserialize_scalars::<G>(data, 1).ok_or(Error::VerificationFailure)?;
381 Ok(scalars[0])
382 }
383
384 fn deserialize_response(&self, data: &[u8]) -> Result<Self::Response, Error> {
385 match self {
386 Protocol::Simple(p) => {
387 let r = p.deserialize_response(data)?;
388 Ok(ProtocolResponse::Simple(r))
389 }
390 Protocol::And(ps) => {
391 let mut cursor = 0;
392 let mut responses = Vec::with_capacity(ps.len());
393 for p in ps {
394 let r = p.deserialize_response(&data[cursor..])?;
395 let size = p.serialize_response(&r).len();
396 cursor += size;
397 responses.push(r);
398 }
399 Ok(ProtocolResponse::And(responses))
400 }
401 Protocol::Or(ps) => {
402 let ch_bytes_len = <<G as Group>::Scalar as PrimeField>::Repr::default()
403 .as_ref()
404 .len();
405 let mut cursor = 0;
406 let mut challenges = Vec::with_capacity(ps.len());
407 let mut responses = Vec::with_capacity(ps.len());
408 for p in ps {
409 let ch_vec = deserialize_scalars::<G>(&data[cursor..cursor + ch_bytes_len], 1)
410 .ok_or(Error::VerificationFailure)?;
411 let ch = ch_vec[0];
412 cursor += ch_bytes_len;
413 let r = p.deserialize_response(&data[cursor..])?;
414 let size = p.serialize_response(&r).len();
415 cursor += size;
416 challenges.push(ch);
417 responses.push(r);
418 }
419 Ok(ProtocolResponse::Or(challenges, responses))
420 }
421 }
422 }
423}
424
425impl<G: Group + GroupEncoding> SigmaProtocolSimulator for Protocol<G> {
426 fn simulate_commitment(
427 &self,
428 challenge: &Self::Challenge,
429 response: &Self::Response,
430 ) -> Result<Self::Commitment, Error> {
431 match (self, response) {
432 (Protocol::Simple(p), ProtocolResponse::Simple(r)) => Ok(ProtocolCommitment::Simple(
433 p.simulate_commitment(challenge, r)?,
434 )),
435 (Protocol::And(ps), ProtocolResponse::And(rs)) => {
436 let commitments = ps
437 .iter()
438 .zip(rs)
439 .map(|(p, r)| p.simulate_commitment(challenge, r))
440 .collect::<Result<Vec<_>, _>>()?;
441 Ok(ProtocolCommitment::And(commitments))
442 }
443 (Protocol::Or(ps), ProtocolResponse::Or(challenges, rs)) => {
444 let commitments = ps
445 .iter()
446 .zip(challenges)
447 .zip(rs)
448 .map(|((p, ch), r)| p.simulate_commitment(ch, r))
449 .collect::<Result<Vec<_>, _>>()?;
450 Ok(ProtocolCommitment::Or(commitments))
451 }
452 _ => panic!(),
453 }
454 }
455
456 fn simulate_response<R: rand::Rng + rand::CryptoRng>(&self, rng: &mut R) -> Self::Response {
457 match self {
458 Protocol::Simple(p) => ProtocolResponse::Simple(p.simulate_response(rng)),
459 Protocol::And(ps) => {
460 ProtocolResponse::And(ps.iter().map(|p| p.simulate_response(rng)).collect())
461 }
462 Protocol::Or(ps) => {
463 let mut challenges = Vec::with_capacity(ps.len());
464 let mut responses = Vec::with_capacity(ps.len());
465 for _ in 0..ps.len() {
466 challenges.push(G::Scalar::random(&mut *rng));
467 }
468 for p in ps.iter() {
469 responses.push(p.simulate_response(&mut *rng));
470 }
471 ProtocolResponse::Or(challenges, responses)
472 }
473 }
474 }
475
476 fn simulate_transcript<R: rand::Rng + rand::CryptoRng>(
477 &self,
478 rng: &mut R,
479 ) -> Result<(Self::Commitment, Self::Challenge, Self::Response), Error> {
480 match self {
481 Protocol::Simple(p) => {
482 let (c, ch, r) = p.simulate_transcript(rng)?;
483 Ok((
484 ProtocolCommitment::Simple(c),
485 ch,
486 ProtocolResponse::Simple(r),
487 ))
488 }
489 Protocol::And(ps) => {
490 let challenge = G::Scalar::random(&mut *rng);
491 let mut responses = Vec::with_capacity(ps.len());
492 for p in ps.iter() {
493 responses.push(p.simulate_response(&mut *rng));
494 }
495 let commitments = ps
496 .iter()
497 .enumerate()
498 .map(|(i, p)| p.simulate_commitment(&challenge, &responses[i]))
499 .collect::<Result<Vec<_>, Error>>()?;
500
501 Ok((
502 ProtocolCommitment::And(commitments),
503 challenge,
504 ProtocolResponse::And(responses),
505 ))
506 }
507 Protocol::Or(ps) => {
508 let mut commitments = Vec::with_capacity(ps.len());
509 let mut challenges = Vec::with_capacity(ps.len());
510 let mut responses = Vec::with_capacity(ps.len());
511
512 for p in ps.iter() {
513 let (c, ch, r) = p.simulate_transcript(rng)?;
514 commitments.push(c);
515 challenges.push(ch);
516 responses.push(r);
517 }
518 let challenge = challenges.iter().sum();
519 Ok((
520 ProtocolCommitment::Or(commitments),
521 challenge,
522 ProtocolResponse::Or(challenges, responses),
523 ))
524 }
525 }
526 }
527}