1pub use crate::duplex_sponge::keccak::KeccakDuplexSponge;
4use crate::duplex_sponge::{shake::ShakeDuplexSponge, DuplexSpongeInterface};
5use ff::PrimeField;
6use group::{Group, GroupEncoding};
7use num_bigint::BigUint;
8use num_traits::identities::One;
9
/// A transcript codec: prover messages are fed in, verifier challenges come out.
///
/// An implementation is bound either to a (protocol, session, instance) triple or
/// to a raw 32-byte initialization vector, and is then driven by calls to
/// [`Codec::prover_message`] and [`Codec::verifier_challenge`].
pub trait Codec {
    /// The value handed back by [`Codec::verifier_challenge`].
    type Challenge;

    /// Builds a codec directly from a 32-byte initialization vector.
    fn from_iv(iv: [u8; 32]) -> Self;

    /// Builds a codec bound to the given protocol identifier, session identifier
    /// and instance label.
    fn new(protocol_identifier: &[u8], session_identifier: &[u8], instance_label: &[u8]) -> Self;

    /// Feeds one prover message into the transcript.
    fn prover_message(&mut self, data: &[u8]);

    /// Derives the next verifier challenge from the current transcript state.
    fn verifier_challenge(&mut self) -> Self::Challenge;
}
36
37fn cardinal<F: PrimeField>() -> BigUint {
38 let bytes = (F::ZERO - F::ONE).to_repr();
39 BigUint::from_bytes_le(bytes.as_ref()) + BigUint::one()
40}
41
42#[derive(Clone)]
47pub struct ByteSchnorrCodec<G, H>
48where
49 G: Group + GroupEncoding,
50 H: DuplexSpongeInterface,
51{
52 hasher: H,
53 _marker: core::marker::PhantomData<G>,
54}
55
56impl<G, H> Codec for ByteSchnorrCodec<G, H>
57where
58 G: Group + GroupEncoding,
59 H: DuplexSpongeInterface,
60{
61 type Challenge = <G as Group>::Scalar;
62
63 fn new(protocol_id: &[u8], session_id: &[u8], instance_label: &[u8]) -> Self {
64 let iv = {
65 let mut tmp = H::new([0u8; 32]);
66 tmp.absorb(protocol_id);
67 tmp.ratchet();
68 tmp.absorb(session_id);
69 tmp.ratchet();
70 tmp.absorb(instance_label);
71 tmp.squeeze(32).try_into().unwrap()
72 };
73
74 Self::from_iv(iv)
75 }
76
77 fn from_iv(iv: [u8; 32]) -> Self {
78 Self {
79 hasher: H::new(iv),
80 _marker: core::marker::PhantomData,
81 }
82 }
83
84 fn prover_message(&mut self, data: &[u8]) {
85 self.hasher.absorb(data);
86 }
87
88 fn verifier_challenge(&mut self) -> G::Scalar {
89 #[allow(clippy::manual_div_ceil)]
90 let scalar_byte_length = (G::Scalar::NUM_BITS as usize + 7) / 8;
91
92 let uniform_bytes = self.hasher.squeeze(scalar_byte_length + 16);
93 let scalar = BigUint::from_bytes_be(&uniform_bytes);
94 let reduced = scalar % cardinal::<G::Scalar>();
95
96 let mut bytes = vec![0u8; scalar_byte_length];
97 let reduced_bytes = reduced.to_bytes_be();
98 let start = bytes.len() - reduced_bytes.len();
99 bytes[start..].copy_from_slice(&reduced_bytes);
100 bytes.reverse();
101
102 let mut repr = <<G as Group>::Scalar as PrimeField>::Repr::default();
103 repr.as_mut().copy_from_slice(&bytes);
104
105 <<G as Group>::Scalar as PrimeField>::from_repr(repr).expect("Error")
106 }
107}
108
109pub type KeccakByteSchnorrCodec<G> = ByteSchnorrCodec<G, KeccakDuplexSponge>;
112
113pub type ShakeCodec<G> = ByteSchnorrCodec<G, ShakeDuplexSponge>;