sigma_rs/
composition.rs

//! # Protocol Composition with AND/OR Logic
//!
//! This module defines the [`Protocol`] enum, which generalizes [`SchnorrProof`]
//! by enabling compositional logic between multiple proof instances.
//!
//! Specifically, it supports:
//! - Simple atomic proofs (e.g., discrete logarithm, Pedersen commitments)
//! - Conjunctions (`And`) of multiple sub-protocols, where the prover knows a
//!   witness for every sub-protocol
//! - Disjunctions (`Or`) of multiple sub-protocols, where the prover knows a
//!   witness for at least one sub-protocol and simulates the others
//!
//! ## Example Composition
//!
//! ```ignore
//! And(
//!    Or( dleq, pedersen_commitment ),
//!    Simple( discrete_logarithm ),
//!    And( pedersen_commitment_dleq, bbs_blind_commitment_computation )
//! )
//! ```

21use ff::{Field, PrimeField};
22use group::{Group, GroupEncoding};
23use sha3::Digest;
24use sha3::Sha3_256;
25
26use crate::{
27    errors::Error,
28    linear_relation::LinearRelation,
29    schnorr_protocol::SchnorrProof,
30    serialization::{deserialize_scalars, serialize_scalars},
31    traits::{SigmaProtocol, SigmaProtocolSimulator},
32};
33
34/// A protocol proving knowledge of a witness for a composition of SchnorrProof's.
35///
36/// This implementation generalizes [`SchnorrProof`] by using AND/OR links.
37///
38/// # Type Parameters
39/// - `G`: A cryptographic group implementing [`Group`] and [`GroupEncoding`].
40#[derive(Clone)]
41pub enum Protocol<G: Group + GroupEncoding> {
42    Simple(SchnorrProof<G>),
43    And(Vec<Protocol<G>>),
44    Or(Vec<Protocol<G>>),
45}
46
47impl<G> From<SchnorrProof<G>> for Protocol<G>
48where
49    G: Group + GroupEncoding,
50{
51    fn from(value: SchnorrProof<G>) -> Self {
52        Protocol::Simple(value)
53    }
54}
55
56impl<G> From<LinearRelation<G>> for Protocol<G>
57where
58    G: Group + GroupEncoding,
59{
60    fn from(value: LinearRelation<G>) -> Self {
61        Self::from(SchnorrProof::from(value))
62    }
63}
64
65// Structure representing the Commitment type of Protocol as SigmaProtocol
66#[derive(Clone)]
67pub enum ProtocolCommitment<G: Group + GroupEncoding> {
68    Simple(<SchnorrProof<G> as SigmaProtocol>::Commitment),
69    And(Vec<ProtocolCommitment<G>>),
70    Or(Vec<ProtocolCommitment<G>>),
71}
72
73// Structure representing the ProverState type of Protocol as SigmaProtocol
74#[derive(Clone)]
75pub enum ProtocolProverState<G: Group + GroupEncoding> {
76    Simple(<SchnorrProof<G> as SigmaProtocol>::ProverState),
77    And(Vec<ProtocolProverState<G>>),
78    Or(
79        usize,                                                 // real index
80        Vec<ProtocolProverState<G>>,                           // real ProverState
81        (Vec<ProtocolChallenge<G>>, Vec<ProtocolResponse<G>>), // simulated transcripts
82    ),
83}
84
85// Structure representing the Response type of Protocol as SigmaProtocol
86#[derive(Clone)]
87pub enum ProtocolResponse<G: Group + GroupEncoding> {
88    Simple(<SchnorrProof<G> as SigmaProtocol>::Response),
89    And(Vec<ProtocolResponse<G>>),
90    Or(Vec<ProtocolChallenge<G>>, Vec<ProtocolResponse<G>>),
91}
92
93// Structure representing the Witness type of Protocol as SigmaProtocol
94pub enum ProtocolWitness<G: Group + GroupEncoding> {
95    Simple(<SchnorrProof<G> as SigmaProtocol>::Witness),
96    And(Vec<ProtocolWitness<G>>),
97    Or(usize, Vec<ProtocolWitness<G>>),
98}
99
100// Structure representing the Challenge type of Protocol as SigmaProtocol
101type ProtocolChallenge<G> = <SchnorrProof<G> as SigmaProtocol>::Challenge;
102
103impl<G: Group + GroupEncoding> SigmaProtocol for Protocol<G> {
104    type Commitment = ProtocolCommitment<G>;
105    type ProverState = ProtocolProverState<G>;
106    type Response = ProtocolResponse<G>;
107    type Witness = ProtocolWitness<G>;
108    type Challenge = ProtocolChallenge<G>;
109
110    fn prover_commit(
111        &self,
112        witness: &Self::Witness,
113        rng: &mut (impl rand::Rng + rand::CryptoRng),
114    ) -> Result<(Self::Commitment, Self::ProverState), Error> {
115        match (self, witness) {
116            (Protocol::Simple(p), ProtocolWitness::Simple(w)) => {
117                p.prover_commit(w, rng).map(|(c, s)| {
118                    (
119                        ProtocolCommitment::Simple(c),
120                        ProtocolProverState::Simple(s),
121                    )
122                })
123            }
124            (Protocol::And(ps), ProtocolWitness::And(ws)) => {
125                if ps.len() != ws.len() {
126                    return Err(Error::InvalidInstanceWitnessPair);
127                }
128                let mut commitments = Vec::with_capacity(ps.len());
129                let mut prover_states = Vec::with_capacity(ps.len());
130
131                for (p, w) in ps.iter().zip(ws.iter()) {
132                    let (c, s) = p.prover_commit(w, rng)?;
133                    commitments.push(c);
134                    prover_states.push(s);
135                }
136
137                Ok((
138                    ProtocolCommitment::And(commitments),
139                    ProtocolProverState::And(prover_states),
140                ))
141            }
142            (Protocol::Or(ps), ProtocolWitness::Or(w_index, w)) => {
143                let mut commitments = Vec::new();
144                let mut simulated_challenges = Vec::new();
145                let mut simulated_responses = Vec::new();
146
147                let (real_commitment, real_state) = ps[*w_index].prover_commit(&w[0], rng)?;
148
149                for i in (0..ps.len()).filter(|i| i != w_index) {
150                    let (commitment, challenge, response) = ps[i].simulate_transcript(rng)?;
151                    commitments.push(commitment);
152                    simulated_challenges.push(challenge);
153                    simulated_responses.push(response);
154                }
155                commitments.insert(*w_index, real_commitment);
156
157                Ok((
158                    ProtocolCommitment::Or(commitments),
159                    ProtocolProverState::Or(
160                        *w_index,
161                        vec![real_state],
162                        (simulated_challenges, simulated_responses),
163                    ),
164                ))
165            }
166            _ => unreachable!(),
167        }
168    }
169
170    fn prover_response(
171        &self,
172        state: Self::ProverState,
173        challenge: &Self::Challenge,
174    ) -> Result<Self::Response, Error> {
175        match (self, state) {
176            (Protocol::Simple(p), ProtocolProverState::Simple(state)) => p
177                .prover_response(state, challenge)
178                .map(ProtocolResponse::Simple),
179            (Protocol::And(ps), ProtocolProverState::And(states)) => {
180                if ps.len() != states.len() {
181                    return Err(Error::InvalidInstanceWitnessPair);
182                }
183                let responses: Result<Vec<_>, _> = ps
184                    .iter()
185                    .zip(states)
186                    .map(|(p, s)| p.prover_response(s, challenge))
187                    .collect();
188
189                Ok(ProtocolResponse::And(responses?))
190            }
191            (
192                Protocol::Or(ps),
193                ProtocolProverState::Or(
194                    w_index,
195                    real_state,
196                    (simulated_challenges, simulated_responses),
197                ),
198            ) => {
199                let mut challenges = Vec::with_capacity(ps.len());
200                let mut responses = Vec::with_capacity(ps.len());
201
202                let mut real_challenge = *challenge;
203                for ch in &simulated_challenges {
204                    real_challenge -= ch;
205                }
206                let real_response =
207                    ps[w_index].prover_response(real_state[0].clone(), &real_challenge)?;
208
209                for (i, _) in ps.iter().enumerate() {
210                    if i == w_index {
211                        challenges.push(real_challenge);
212                        responses.push(real_response.clone());
213                    } else {
214                        let simulated_index = if i < w_index { i } else { i - 1 };
215                        challenges.push(simulated_challenges[simulated_index]);
216                        responses.push(simulated_responses[simulated_index].clone());
217                    }
218                }
219                Ok(ProtocolResponse::Or(challenges, responses))
220            }
221            _ => panic!(),
222        }
223    }
224
225    fn verifier(
226        &self,
227        commitment: &Self::Commitment,
228        challenge: &Self::Challenge,
229        response: &Self::Response,
230    ) -> Result<(), Error> {
231        match (self, commitment, response) {
232            (Protocol::Simple(p), ProtocolCommitment::Simple(c), ProtocolResponse::Simple(r)) => {
233                p.verifier(c, challenge, r)
234            }
235            (
236                Protocol::And(ps),
237                ProtocolCommitment::And(commitments),
238                ProtocolResponse::And(responses),
239            ) => ps
240                .iter()
241                .zip(commitments)
242                .zip(responses)
243                .try_for_each(|((p, c), r)| p.verifier(c, challenge, r)),
244            (
245                Protocol::Or(ps),
246                ProtocolCommitment::Or(commitments),
247                ProtocolResponse::Or(challenges, responses),
248            ) => {
249                let mut expected_difference = *challenge;
250                for (i, p) in ps.iter().enumerate() {
251                    p.verifier(&commitments[i], &challenges[i], &responses[i])?;
252                    expected_difference -= challenges[i];
253                }
254                match expected_difference.is_zero_vartime() {
255                    true => Ok(()),
256                    false => Err(Error::VerificationFailure),
257                }
258            }
259            _ => panic!(),
260        }
261    }
262
263    fn serialize_commitment(&self, commitment: &Self::Commitment) -> Vec<u8> {
264        match (self, commitment) {
265            (Protocol::Simple(p), ProtocolCommitment::Simple(c)) => p.serialize_commitment(c),
266            (Protocol::And(ps), ProtocolCommitment::And(commitments))
267            | (Protocol::Or(ps), ProtocolCommitment::Or(commitments)) => ps
268                .iter()
269                .zip(commitments)
270                .flat_map(|(p, c)| p.serialize_commitment(c))
271                .collect(),
272            _ => panic!(),
273        }
274    }
275
276    fn serialize_challenge(&self, challenge: &Self::Challenge) -> Vec<u8> {
277        serialize_scalars::<G>(&[*challenge])
278    }
279
280    fn instance_label(&self) -> impl AsRef<[u8]> {
281        match self {
282            Protocol::Simple(p) => {
283                let label = p.instance_label();
284                label.as_ref().to_vec()
285            }
286            Protocol::And(ps) => {
287                let mut bytes = Vec::new();
288                for p in ps {
289                    bytes.extend(p.instance_label().as_ref());
290                }
291                bytes
292            }
293            Protocol::Or(ps) => {
294                let mut bytes = Vec::new();
295                for p in ps {
296                    bytes.extend(p.instance_label().as_ref());
297                }
298                bytes
299            }
300        }
301    }
302
303    fn protocol_identifier(&self) -> impl AsRef<[u8]> {
304        let mut hasher = Sha3_256::new();
305
306        match self {
307            Protocol::Simple(p) => {
308                // take the digest of the simple protocol id
309                hasher.update([0u8; 32]);
310                hasher.update(p.protocol_identifier());
311            }
312            Protocol::And(protocols) => {
313                let mut hasher = Sha3_256::new();
314                hasher.update([1u8; 32]);
315                for p in protocols {
316                    hasher.update(p.protocol_identifier());
317                }
318            }
319            Protocol::Or(protocols) => {
320                let mut hasher = Sha3_256::new();
321                hasher.update([2u8; 32]);
322                for p in protocols {
323                    hasher.update(p.protocol_identifier());
324                }
325            }
326        }
327
328        hasher.finalize()
329    }
330
331    fn serialize_response(&self, response: &Self::Response) -> Vec<u8> {
332        match (self, response) {
333            (Protocol::Simple(p), ProtocolResponse::Simple(r)) => p.serialize_response(r),
334            (Protocol::And(ps), ProtocolResponse::And(responses)) => {
335                let mut bytes = Vec::new();
336                for (i, p) in ps.iter().enumerate() {
337                    bytes.extend(p.serialize_response(&responses[i]));
338                }
339                bytes
340            }
341            (Protocol::Or(ps), ProtocolResponse::Or(challenges, responses)) => {
342                let mut bytes = Vec::new();
343                for (i, p) in ps.iter().enumerate() {
344                    bytes.extend(&serialize_scalars::<G>(&[challenges[i]]));
345                    bytes.extend(p.serialize_response(&responses[i]));
346                }
347                bytes
348            }
349            _ => panic!(),
350        }
351    }
352
353    fn deserialize_commitment(&self, data: &[u8]) -> Result<Self::Commitment, Error> {
354        match self {
355            Protocol::Simple(p) => {
356                let c = p.deserialize_commitment(data)?;
357                Ok(ProtocolCommitment::Simple(c))
358            }
359            Protocol::And(ps) | Protocol::Or(ps) => {
360                let mut cursor = 0;
361                let mut commitments = Vec::with_capacity(ps.len());
362
363                for p in ps {
364                    let c = p.deserialize_commitment(&data[cursor..])?;
365                    let size = p.serialize_commitment(&c).len();
366                    cursor += size;
367                    commitments.push(c);
368                }
369
370                Ok(match self {
371                    Protocol::And(_) => ProtocolCommitment::And(commitments),
372                    Protocol::Or(_) => ProtocolCommitment::Or(commitments),
373                    _ => unreachable!(),
374                })
375            }
376        }
377    }
378
379    fn deserialize_challenge(&self, data: &[u8]) -> Result<Self::Challenge, Error> {
380        let scalars = deserialize_scalars::<G>(data, 1).ok_or(Error::VerificationFailure)?;
381        Ok(scalars[0])
382    }
383
384    fn deserialize_response(&self, data: &[u8]) -> Result<Self::Response, Error> {
385        match self {
386            Protocol::Simple(p) => {
387                let r = p.deserialize_response(data)?;
388                Ok(ProtocolResponse::Simple(r))
389            }
390            Protocol::And(ps) => {
391                let mut cursor = 0;
392                let mut responses = Vec::with_capacity(ps.len());
393                for p in ps {
394                    let r = p.deserialize_response(&data[cursor..])?;
395                    let size = p.serialize_response(&r).len();
396                    cursor += size;
397                    responses.push(r);
398                }
399                Ok(ProtocolResponse::And(responses))
400            }
401            Protocol::Or(ps) => {
402                let ch_bytes_len = <<G as Group>::Scalar as PrimeField>::Repr::default()
403                    .as_ref()
404                    .len();
405                let mut cursor = 0;
406                let mut challenges = Vec::with_capacity(ps.len());
407                let mut responses = Vec::with_capacity(ps.len());
408                for p in ps {
409                    let ch_vec = deserialize_scalars::<G>(&data[cursor..cursor + ch_bytes_len], 1)
410                        .ok_or(Error::VerificationFailure)?;
411                    let ch = ch_vec[0];
412                    cursor += ch_bytes_len;
413                    let r = p.deserialize_response(&data[cursor..])?;
414                    let size = p.serialize_response(&r).len();
415                    cursor += size;
416                    challenges.push(ch);
417                    responses.push(r);
418                }
419                Ok(ProtocolResponse::Or(challenges, responses))
420            }
421        }
422    }
423}
424
425impl<G: Group + GroupEncoding> SigmaProtocolSimulator for Protocol<G> {
426    fn simulate_commitment(
427        &self,
428        challenge: &Self::Challenge,
429        response: &Self::Response,
430    ) -> Result<Self::Commitment, Error> {
431        match (self, response) {
432            (Protocol::Simple(p), ProtocolResponse::Simple(r)) => Ok(ProtocolCommitment::Simple(
433                p.simulate_commitment(challenge, r)?,
434            )),
435            (Protocol::And(ps), ProtocolResponse::And(rs)) => {
436                let commitments = ps
437                    .iter()
438                    .zip(rs)
439                    .map(|(p, r)| p.simulate_commitment(challenge, r))
440                    .collect::<Result<Vec<_>, _>>()?;
441                Ok(ProtocolCommitment::And(commitments))
442            }
443            (Protocol::Or(ps), ProtocolResponse::Or(challenges, rs)) => {
444                let commitments = ps
445                    .iter()
446                    .zip(challenges)
447                    .zip(rs)
448                    .map(|((p, ch), r)| p.simulate_commitment(ch, r))
449                    .collect::<Result<Vec<_>, _>>()?;
450                Ok(ProtocolCommitment::Or(commitments))
451            }
452            _ => panic!(),
453        }
454    }
455
456    fn simulate_response<R: rand::Rng + rand::CryptoRng>(&self, rng: &mut R) -> Self::Response {
457        match self {
458            Protocol::Simple(p) => ProtocolResponse::Simple(p.simulate_response(rng)),
459            Protocol::And(ps) => {
460                ProtocolResponse::And(ps.iter().map(|p| p.simulate_response(rng)).collect())
461            }
462            Protocol::Or(ps) => {
463                let mut challenges = Vec::with_capacity(ps.len());
464                let mut responses = Vec::with_capacity(ps.len());
465                for _ in 0..ps.len() {
466                    challenges.push(G::Scalar::random(&mut *rng));
467                }
468                for p in ps.iter() {
469                    responses.push(p.simulate_response(&mut *rng));
470                }
471                ProtocolResponse::Or(challenges, responses)
472            }
473        }
474    }
475
476    fn simulate_transcript<R: rand::Rng + rand::CryptoRng>(
477        &self,
478        rng: &mut R,
479    ) -> Result<(Self::Commitment, Self::Challenge, Self::Response), Error> {
480        match self {
481            Protocol::Simple(p) => {
482                let (c, ch, r) = p.simulate_transcript(rng)?;
483                Ok((
484                    ProtocolCommitment::Simple(c),
485                    ch,
486                    ProtocolResponse::Simple(r),
487                ))
488            }
489            Protocol::And(ps) => {
490                let challenge = G::Scalar::random(&mut *rng);
491                let mut responses = Vec::with_capacity(ps.len());
492                for p in ps.iter() {
493                    responses.push(p.simulate_response(&mut *rng));
494                }
495                let commitments = ps
496                    .iter()
497                    .enumerate()
498                    .map(|(i, p)| p.simulate_commitment(&challenge, &responses[i]))
499                    .collect::<Result<Vec<_>, Error>>()?;
500
501                Ok((
502                    ProtocolCommitment::And(commitments),
503                    challenge,
504                    ProtocolResponse::And(responses),
505                ))
506            }
507            Protocol::Or(ps) => {
508                let mut commitments = Vec::with_capacity(ps.len());
509                let mut challenges = Vec::with_capacity(ps.len());
510                let mut responses = Vec::with_capacity(ps.len());
511
512                for p in ps.iter() {
513                    let (c, ch, r) = p.simulate_transcript(rng)?;
514                    commitments.push(c);
515                    challenges.push(ch);
516                    responses.push(r);
517                }
518                let challenge = challenges.iter().sum();
519                Ok((
520                    ProtocolCommitment::Or(commitments),
521                    challenge,
522                    ProtocolResponse::Or(challenges, responses),
523                ))
524            }
525        }
526    }
527}