sigma_rs/linear_relation/
ops.rs

1use core::ops::{Add, Mul, Neg, Sub};
2use ff::Field;
3use group::Group;
4
5use super::{GroupVar, ScalarTerm, ScalarVar, Sum, Term, Weighted};
6
7mod add {
8    use super::*;
9
10    macro_rules! impl_add_term {
11        ($($type:ty),+) => {
12            $(
13            impl<G> Add<$type> for $type {
14                type Output = Sum<$type>;
15
16                fn add(self, rhs: $type) -> Self::Output {
17                    Sum(vec![self, rhs])
18                }
19            }
20            )+
21        };
22    }
23
24    impl_add_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
25
26    impl<T> Add<T> for Sum<T> {
27        type Output = Sum<T>;
28
29        fn add(mut self, rhs: T) -> Self::Output {
30            self.0.push(rhs);
31            self
32        }
33    }
34
35    macro_rules! impl_add_sum_term {
36        ($($type:ty),+) => {
37            $(
38            impl<G> Add<Sum<$type>> for $type {
39                type Output = Sum<$type>;
40
41                fn add(self, rhs: Sum<$type>) -> Self::Output {
42                    rhs + self
43                }
44            }
45            )+
46        };
47    }
48
49    impl_add_sum_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
50
51    impl<T> Add<Sum<T>> for Sum<T> {
52        type Output = Sum<T>;
53
54        fn add(mut self, rhs: Sum<T>) -> Self::Output {
55            self.0.extend(rhs.0);
56            self
57        }
58    }
59
60    impl<T, F> Add<Weighted<T, F>> for Weighted<T, F> {
61        type Output = Sum<Weighted<T, F>>;
62
63        fn add(self, rhs: Weighted<T, F>) -> Self::Output {
64            Sum(vec![self, rhs])
65        }
66    }
67
68    impl<T, F: Field> Add<T> for Weighted<T, F> {
69        type Output = Sum<Weighted<T, F>>;
70
71        fn add(self, rhs: T) -> Self::Output {
72            Sum(vec![self, rhs.into()])
73        }
74    }
75
76    macro_rules! impl_add_weighted_term {
77        ($($type:ty),+) => {
78            $(
79            impl<G: Group> Add<Weighted<$type, G::Scalar>> for $type {
80                type Output = Sum<Weighted<$type, G::Scalar>>;
81
82                fn add(self, rhs: Weighted<$type, G::Scalar>) -> Self::Output {
83                    rhs + self
84                }
85            }
86            )+
87        };
88    }
89
90    impl_add_weighted_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
91
92    impl<T, F: Field> Add<T> for Sum<Weighted<T, F>> {
93        type Output = Sum<Weighted<T, F>>;
94
95        fn add(mut self, rhs: T) -> Self::Output {
96            self.0.push(rhs.into());
97            self
98        }
99    }
100
101    macro_rules! impl_add_weighted_sum_term {
102        ($($type:ty),+) => {
103            $(
104            impl<G: Group> Add<Sum<Weighted<$type, G::Scalar>>> for $type {
105                type Output = Sum<Weighted<$type, G::Scalar>>;
106
107                fn add(self, rhs: Sum<Weighted<$type, G::Scalar>>) -> Self::Output {
108                    rhs + self
109                }
110            }
111            )+
112        };
113    }
114
115    impl_add_weighted_sum_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
116
117    impl<T, F: Field> Add<Sum<T>> for Sum<Weighted<T, F>> {
118        type Output = Sum<Weighted<T, F>>;
119
120        fn add(self, rhs: Sum<T>) -> Self::Output {
121            self + Self::from(rhs)
122        }
123    }
124
125    impl<T, F: Field> Add<Sum<Weighted<T, F>>> for Sum<T> {
126        type Output = Sum<Weighted<T, F>>;
127
128        fn add(self, rhs: Sum<Weighted<T, F>>) -> Self::Output {
129            rhs + self
130        }
131    }
132
133    impl<T, F: Field> Add<Weighted<T, F>> for Sum<T> {
134        type Output = Sum<Weighted<T, F>>;
135
136        fn add(self, rhs: Weighted<T, F>) -> Self::Output {
137            Self::Output::from(self) + rhs
138        }
139    }
140
141    impl<T, F: Field> Add<Sum<T>> for Weighted<T, F> {
142        type Output = Sum<Weighted<T, F>>;
143
144        fn add(self, rhs: Sum<T>) -> Self::Output {
145            rhs + self
146        }
147    }
148
149    impl<G> Add<ScalarVar<G>> for ScalarTerm<G> {
150        type Output = Sum<ScalarTerm<G>>;
151
152        fn add(self, rhs: ScalarVar<G>) -> Self::Output {
153            self + ScalarTerm::from(rhs)
154        }
155    }
156
157    impl<G> Add<ScalarTerm<G>> for ScalarVar<G> {
158        type Output = Sum<ScalarTerm<G>>;
159
160        fn add(self, rhs: ScalarTerm<G>) -> Self::Output {
161            rhs + self
162        }
163    }
164
165    impl<T: Field + Into<G::Scalar>, G: Group> Add<T> for Weighted<ScalarTerm<G>, G::Scalar> {
166        type Output = Sum<Weighted<ScalarTerm<G>, G::Scalar>>;
167
168        fn add(self, rhs: T) -> Self::Output {
169            self + Self::from(rhs.into())
170        }
171    }
172
173    impl<T: Field + Into<G::Scalar>, G: Group> Add<T> for Weighted<ScalarVar<G>, G::Scalar> {
174        type Output = Sum<Weighted<ScalarTerm<G>, G::Scalar>>;
175
176        fn add(self, rhs: T) -> Self::Output {
177            <Weighted<ScalarTerm<G>, G::Scalar>>::from(self) + rhs.into()
178        }
179    }
180
181    impl<T: Field + Into<G::Scalar>, G: Group> Add<T> for ScalarVar<G> {
182        type Output = Sum<Weighted<ScalarTerm<G>, G::Scalar>>;
183
184        fn add(self, rhs: T) -> Self::Output {
185            Weighted::from(ScalarTerm::from(self)) + rhs.into()
186        }
187    }
188
189    impl<G: Group> Add<GroupVar<G>> for Sum<Weighted<Term<G>, G::Scalar>> {
190        type Output = Sum<Weighted<Term<G>, G::Scalar>>;
191
192        fn add(self, rhs: GroupVar<G>) -> Self::Output {
193            self + Self::from(rhs)
194        }
195    }
196
197    impl<G: Group> Add<GroupVar<G>> for Sum<Term<G>> {
198        type Output = Sum<Term<G>>;
199
200        fn add(self, rhs: GroupVar<G>) -> Self::Output {
201            self + Self::from(rhs)
202        }
203    }
204
205    impl<G: Group> Add<GroupVar<G>> for Weighted<Term<G>, G::Scalar> {
206        type Output = Sum<Weighted<Term<G>, G::Scalar>>;
207
208        fn add(self, rhs: GroupVar<G>) -> Self::Output {
209            self + Self::from(rhs)
210        }
211    }
212
213    impl<G: Group> Add<GroupVar<G>> for Term<G> {
214        type Output = Sum<Term<G>>;
215
216        fn add(self, rhs: GroupVar<G>) -> Self::Output {
217            self + Self::from(rhs)
218        }
219    }
220}
221
222mod mul {
223    use super::*;
224
225    impl<G> Mul<ScalarVar<G>> for GroupVar<G> {
226        type Output = Term<G>;
227
228        /// Multiply a [ScalarVar] by a [GroupVar] to form a new [Term].
229        fn mul(self, rhs: ScalarVar<G>) -> Term<G> {
230            Term {
231                elem: self,
232                scalar: rhs.into(),
233            }
234        }
235    }
236
237    impl<G> Mul<GroupVar<G>> for ScalarVar<G> {
238        type Output = Term<G>;
239
240        /// Multiply a [ScalarVar] by a [GroupVar] to form a new [Term].
241        fn mul(self, rhs: GroupVar<G>) -> Term<G> {
242            rhs * self
243        }
244    }
245
246    impl<G> Mul<ScalarTerm<G>> for GroupVar<G> {
247        type Output = Term<G>;
248
249        fn mul(self, rhs: ScalarTerm<G>) -> Term<G> {
250            Term {
251                elem: self,
252                scalar: rhs,
253            }
254        }
255    }
256
257    impl<G> Mul<GroupVar<G>> for ScalarTerm<G> {
258        type Output = Term<G>;
259
260        fn mul(self, rhs: GroupVar<G>) -> Term<G> {
261            rhs * self
262        }
263    }
264
265    impl<Rhs: Clone, Lhs: Mul<Rhs>> Mul<Rhs> for Sum<Lhs> {
266        type Output = Sum<<Lhs as Mul<Rhs>>::Output>;
267
268        /// Multiplication of the sum by a term, implemented as a general distributive property.
269        fn mul(self, rhs: Rhs) -> Self::Output {
270            Sum(self.0.into_iter().map(|x| x * rhs.clone()).collect())
271        }
272    }
273
274    // NOTE: Rust forbids implementation of foreign traits (e.g. Mul) over bare generic types (e.g. F:
275    // Field). It can be implemented over specific types (e.g. curve25519_dalek::Scalar or u64). As a
276    // result, this generic implements `var * scalar`, but not `scalar * var`.
277
278    macro_rules! impl_scalar_mul_term {
279        ($($type:ty),+) => {
280            $(
281            // NOTE: Rust does not like this impl when F is replaced by G::Scalar.
282            impl<F: Field + Into<G::Scalar>, G: Group> Mul<F> for $type {
283                type Output = Weighted<$type, G::Scalar>;
284
285                fn mul(self, rhs: F) -> Self::Output {
286                    Weighted {
287                        term: self,
288                        weight: rhs.into(),
289                    }
290                }
291            }
292            )+
293        };
294    }
295
296    impl_scalar_mul_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
297
298    impl<T, F: Field> Mul<F> for Weighted<T, F> {
299        type Output = Weighted<T, F>;
300
301        fn mul(self, rhs: F) -> Self::Output {
302            Weighted {
303                term: self.term,
304                weight: self.weight * rhs,
305            }
306        }
307    }
308
309    impl<G: Group> Mul<ScalarVar<G>> for Weighted<GroupVar<G>, G::Scalar> {
310        type Output = Weighted<Term<G>, G::Scalar>;
311
312        fn mul(self, rhs: ScalarVar<G>) -> Self::Output {
313            Weighted {
314                term: self.term * rhs,
315                weight: self.weight,
316            }
317        }
318    }
319
320    impl<G: Group> Mul<Weighted<GroupVar<G>, G::Scalar>> for ScalarVar<G> {
321        type Output = Weighted<Term<G>, G::Scalar>;
322
323        fn mul(self, rhs: Weighted<GroupVar<G>, G::Scalar>) -> Self::Output {
324            rhs * self
325        }
326    }
327
328    impl<G: Group> Mul<GroupVar<G>> for Weighted<ScalarVar<G>, G::Scalar> {
329        type Output = Weighted<Term<G>, G::Scalar>;
330
331        fn mul(self, rhs: GroupVar<G>) -> Self::Output {
332            Weighted {
333                term: self.term * rhs,
334                weight: self.weight,
335            }
336        }
337    }
338
339    impl<G: Group> Mul<Weighted<ScalarVar<G>, G::Scalar>> for GroupVar<G> {
340        type Output = Weighted<Term<G>, G::Scalar>;
341
342        fn mul(self, rhs: Weighted<ScalarVar<G>, G::Scalar>) -> Self::Output {
343            rhs * self
344        }
345    }
346
347    impl<G: Group> Mul<ScalarTerm<G>> for Weighted<GroupVar<G>, G::Scalar> {
348        type Output = Weighted<Term<G>, G::Scalar>;
349
350        fn mul(self, rhs: ScalarTerm<G>) -> Self::Output {
351            Weighted {
352                term: self.term * rhs,
353                weight: self.weight,
354            }
355        }
356    }
357
358    impl<G: Group> Mul<Weighted<GroupVar<G>, G::Scalar>> for ScalarTerm<G> {
359        type Output = Weighted<Term<G>, G::Scalar>;
360
361        fn mul(self, rhs: Weighted<GroupVar<G>, G::Scalar>) -> Self::Output {
362            rhs * self
363        }
364    }
365
366    impl<G: Group> Mul<GroupVar<G>> for Weighted<ScalarTerm<G>, G::Scalar> {
367        type Output = Weighted<Term<G>, G::Scalar>;
368
369        fn mul(self, rhs: GroupVar<G>) -> Self::Output {
370            Weighted {
371                term: self.term * rhs,
372                weight: self.weight,
373            }
374        }
375    }
376
377    impl<G: Group> Mul<Weighted<ScalarTerm<G>, G::Scalar>> for GroupVar<G> {
378        type Output = Weighted<Term<G>, G::Scalar>;
379
380        fn mul(self, rhs: Weighted<ScalarTerm<G>, G::Scalar>) -> Self::Output {
381            rhs * self
382        }
383    }
384}
385
386mod neg {
387    use super::*;
388
389    impl<T: Neg> Neg for Sum<T> {
390        type Output = Sum<<T as Neg>::Output>;
391
392        /// Negation a sum, implemented as a general distributive property.
393        fn neg(self) -> Self::Output {
394            Sum(self.0.into_iter().map(|x| x.neg()).collect())
395        }
396    }
397
398    impl<T, F: Field> Neg for Weighted<T, F> {
399        type Output = Weighted<T, F>;
400
401        /// Negation of a weighted term, implemented as negation of its weight.
402        fn neg(self) -> Self::Output {
403            Weighted {
404                term: self.term,
405                weight: -self.weight,
406            }
407        }
408    }
409
410    macro_rules! impl_neg_term {
411        ($($type:ty),+) => {
412            $(
413            impl<G: Group> Neg for $type {
414                type Output = Weighted<$type, G::Scalar>;
415
416                fn neg(self) -> Self::Output {
417                    Weighted {
418                        term: self,
419                        weight: -G::Scalar::ONE,
420                    }
421                }
422            }
423            )+
424        };
425    }
426
427    impl_neg_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
428}
429
430mod sub {
431    use super::*;
432
433    impl<T, Rhs> Sub<Rhs> for Sum<T>
434    where
435        Rhs: Neg,
436        <Rhs as Neg>::Output: Add<Self>,
437    {
438        type Output = <<Rhs as Neg>::Output as Add<Self>>::Output;
439
440        #[allow(clippy::suspicious_arithmetic_impl)]
441        fn sub(self, rhs: Rhs) -> Self::Output {
442            rhs.neg() + self
443        }
444    }
445
446    impl<T, F, Rhs> Sub<Rhs> for Weighted<T, F>
447    where
448        Rhs: Neg,
449        <Rhs as Neg>::Output: Add<Self>,
450    {
451        type Output = <<Rhs as Neg>::Output as Add<Self>>::Output;
452
453        #[allow(clippy::suspicious_arithmetic_impl)]
454        fn sub(self, rhs: Rhs) -> Self::Output {
455            rhs.neg() + self
456        }
457    }
458
459    macro_rules! impl_sub_as_neg_add {
460        ($($type:ty),+) => {
461            $(
462            impl<G, Rhs> Sub<Rhs> for $type
463            where
464                Rhs: Neg,
465                <Rhs as Neg>::Output: Add<Self>,
466            {
467                type Output = <<Rhs as Neg>::Output as Add<Self>>::Output;
468
469                #[allow(clippy::suspicious_arithmetic_impl)]
470                fn sub(self, rhs: Rhs) -> Self::Output {
471                    rhs.neg() + self
472                }
473            }
474            )+
475        };
476    }
477
478    impl_sub_as_neg_add!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
479}
480
#[cfg(test)]
mod tests {
    use crate::linear_relation::{GroupVar, ScalarTerm, ScalarVar, Term};
    use curve25519_dalek::{RistrettoPoint as G, Scalar};
    use std::marker::PhantomData;

    /// Scalar variable with index `i` over the Ristretto group.
    fn scalar_var(i: usize) -> ScalarVar<G> {
        ScalarVar(i, PhantomData)
    }

    /// Group variable with index `i` over the Ristretto group.
    fn group_var(i: usize) -> GroupVar<G> {
        GroupVar(i, PhantomData)
    }

    /// Shorthand for a constant scalar coefficient.
    fn scalar(n: u64) -> Scalar {
        Scalar::from(n)
    }

    #[test]
    fn test_scalar_var_addition() {
        let (x, y) = (scalar_var(0), scalar_var(1));

        let total = x + y;
        let terms = total.terms();
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0], x);
        assert_eq!(terms[1], y);
    }

    #[test]
    fn test_scalar_var_scalar_addition() {
        let x = scalar_var(0);
        let five = scalar(5);

        let total = x + five;
        let terms = total.terms();
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0].term, x.into());
        assert_eq!(terms[0].weight, Scalar::ONE);
        assert_eq!(terms[1].term, ScalarTerm::Unit);
        assert_eq!(terms[1].weight, five);
    }

    #[test]
    fn test_scalar_var_scalar_addition_mul_group() {
        let x = scalar_var(0);
        let g = group_var(0);
        let five = scalar(5);

        let product = (x + five) * g;

        let expected_var_term = Term {
            scalar: x.into(),
            elem: g,
        };
        let expected_unit_term = Term {
            scalar: ScalarTerm::Unit,
            elem: g,
        };

        let terms = product.terms();
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0].term, expected_var_term);
        assert_eq!(terms[0].weight, Scalar::ONE);
        assert_eq!(terms[1].term, expected_unit_term);
        assert_eq!(terms[1].weight, five);
    }

    #[test]
    fn test_group_var_addition() {
        let (g, h) = (group_var(0), group_var(1));

        let total = g + h;
        let terms = total.terms();
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0], g);
        assert_eq!(terms[1], h);
    }

    #[test]
    fn test_term_addition() {
        let (x, y) = (scalar_var(0), scalar_var(1));
        let (g, h) = (group_var(0), group_var(1));

        let first = Term {
            scalar: x.into(),
            elem: g,
        };
        let second = Term {
            scalar: y.into(),
            elem: h,
        };

        let total = first + second;
        let terms = total.terms();
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0], first);
        assert_eq!(terms[1], second);
    }

    #[test]
    fn test_term_group_var_addition() {
        let x = scalar_var(0);
        let g = group_var(0);

        let total = (x * g) + g;

        let expected_var_term = Term {
            scalar: x.into(),
            elem: g,
        };
        let expected_unit_term = Term {
            scalar: ScalarTerm::Unit,
            elem: g,
        };

        let terms = total.terms();
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0], expected_var_term);
        assert_eq!(terms[1], expected_unit_term);
    }

    #[test]
    fn test_scalar_group_multiplication() {
        let x = scalar_var(0);
        let g = group_var(0);

        // Both operand orders must produce the same term.
        let scalar_first = x * g;
        let group_first = g * x;

        assert_eq!(scalar_first.scalar, x.into());
        assert_eq!(scalar_first.elem, g);
        assert_eq!(group_first.scalar, x.into());
        assert_eq!(group_first.elem, g);
    }

    #[test]
    fn test_scalar_coefficient_multiplication() {
        let x = scalar_var(0);
        let five = scalar(5);

        let scaled = x * five;

        assert_eq!(scaled.term, x);
        assert_eq!(scaled.weight, five);
    }

    #[test]
    fn test_group_coefficient_multiplication() {
        let g = group_var(0);
        let three = scalar(3);

        let scaled = g * three;

        assert_eq!(scaled.term, g);
        assert_eq!(scaled.weight, three);
    }

    #[test]
    fn test_term_coefficient_multiplication() {
        let base = Term {
            scalar: scalar_var(0).into(),
            elem: group_var(0),
        };
        let seven = scalar(7);

        let scaled = base * seven;

        assert_eq!(scaled.term, base);
        assert_eq!(scaled.weight, seven);
    }

    #[test]
    fn test_scalar_var_negation() {
        let x = scalar_var(0);

        let negated = -x;

        assert_eq!(negated.term, x);
        assert_eq!(negated.weight, -Scalar::ONE);
    }

    #[test]
    fn test_group_var_negation() {
        let g = group_var(0);

        let negated = -g;

        assert_eq!(negated.term, g);
        assert_eq!(negated.weight, -Scalar::ONE);
    }

    #[test]
    fn test_term_negation() {
        let base = Term {
            scalar: scalar_var(0).into(),
            elem: group_var(0),
        };

        let negated = -base;

        assert_eq!(negated.term, base);
        assert_eq!(negated.weight, -Scalar::ONE);
    }

    #[test]
    fn test_weighted_negation() {
        let x = scalar_var(0);
        let five = scalar(5);

        let negated = -(x * five);

        assert_eq!(negated.term, x);
        assert_eq!(negated.weight, -five);
    }

    #[test]
    fn test_scalar_var_subtraction() {
        let (x, y) = (scalar_var(0), scalar_var(1));

        let difference = x - y;

        // Subtraction is `(-rhs) + lhs`, so the negated operand is listed first.
        let terms = difference.terms();
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0].term, y);
        assert_eq!(terms[0].weight, -Scalar::ONE);
        assert_eq!(terms[1].term, x);
        assert_eq!(terms[1].weight, Scalar::ONE);
    }

    #[test]
    fn test_group_var_subtraction() {
        let (g, h) = (group_var(0), group_var(1));

        let difference = g - h;

        // Subtraction is `(-rhs) + lhs`, so the negated operand is listed first.
        let terms = difference.terms();
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0].term, h);
        assert_eq!(terms[0].weight, -Scalar::ONE);
        assert_eq!(terms[1].term, g);
        assert_eq!(terms[1].weight, Scalar::ONE);
    }

    #[test]
    fn test_term_subtraction() {
        let (x, y) = (scalar_var(0), scalar_var(1));
        let (g, h) = (group_var(0), group_var(1));

        let minuend = Term {
            scalar: x.into(),
            elem: g,
        };
        let subtrahend = Term {
            scalar: y.into(),
            elem: h,
        };

        let difference = minuend - subtrahend;

        // Subtraction is `(-rhs) + lhs`, so the negated operand is listed first.
        let terms = difference.terms();
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0].term, subtrahend);
        assert_eq!(terms[0].weight, -Scalar::ONE);
        assert_eq!(terms[1].term, minuend);
        assert_eq!(terms[1].weight, Scalar::ONE);
    }

    #[test]
    fn test_sum_addition_chaining() {
        let (x, y, z) = (scalar_var(0), scalar_var(1), scalar_var(2));

        let total = x + y + z;

        let terms = total.terms();
        assert_eq!(terms.len(), 3);
        assert_eq!(terms[0], x);
        assert_eq!(terms[1], y);
        assert_eq!(terms[2], z);
    }

    #[test]
    fn test_sum_plus_scalar_var() {
        let (x, y, z) = (scalar_var(0), scalar_var(1), scalar_var(2));

        let partial = x + y;
        let total = z + partial;

        // `term + sum` appends the term to the existing sum.
        let terms = total.terms();
        assert_eq!(terms.len(), 3);
        assert_eq!(terms[0], x);
        assert_eq!(terms[1], y);
        assert_eq!(terms[2], z);
    }

    #[test]
    fn test_sum_plus_sum() {
        let (x, y) = (scalar_var(0), scalar_var(1));
        let (z, w) = (scalar_var(2), scalar_var(3));

        let left = x + y;
        let right = z + w;
        let total = left + right;

        let terms = total.terms();
        assert_eq!(terms.len(), 4);
        assert_eq!(terms[0], x);
        assert_eq!(terms[1], y);
        assert_eq!(terms[2], z);
        assert_eq!(terms[3], w);
    }

    #[test]
    fn test_sum_negation() {
        let (x, y) = (scalar_var(0), scalar_var(1));

        let negated = -(x + y);

        let terms = negated.terms();
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0].term, x);
        assert_eq!(terms[0].weight, -Scalar::ONE);
        assert_eq!(terms[1].term, y);
        assert_eq!(terms[1].weight, -Scalar::ONE);
    }

    #[test]
    fn test_weighted_addition() {
        let (x, y) = (scalar_var(0), scalar_var(1));
        let (three, five) = (scalar(3), scalar(5));

        let left = x * three;
        let right = y * five;
        let total = left + right;

        let terms = total.terms();
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0].term, x);
        assert_eq!(terms[0].weight, three);
        assert_eq!(terms[1].term, y);
        assert_eq!(terms[1].weight, five);
    }

    #[test]
    fn test_weighted_plus_term() {
        let (x, y) = (scalar_var(0), scalar_var(1));
        let two = scalar(2);

        let scaled = x * two;
        let total = scaled + y;

        let terms = total.terms();
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0].term, x);
        assert_eq!(terms[0].weight, two);
        assert_eq!(terms[1].term, y);
        assert_eq!(terms[1].weight, Scalar::ONE);
    }

    #[test]
    fn test_weighted_scalar_multiplication() {
        let x = scalar_var(0);

        let scaled_once = x * scalar(2);
        let scaled_twice = scaled_once * scalar(3);

        assert_eq!(scaled_twice.term, x);
        assert_eq!(scaled_twice.weight, scalar(6));
    }

    #[test]
    fn test_weighted_group_var_times_scalar_var() {
        let x = scalar_var(0);
        let g = group_var(0);
        let five = scalar(5);

        let scaled_group = g * five;
        let product = x * scaled_group;

        assert_eq!(product.term.scalar, x.into());
        assert_eq!(product.term.elem, g);
        assert_eq!(product.weight, five);
    }

    #[test]
    fn test_weighted_scalar_var_times_group_var() {
        let x = scalar_var(0);
        let g = group_var(0);
        let three = scalar(3);

        let scaled_scalar = x * three;
        let product = scaled_scalar * g;

        assert_eq!(product.term.scalar, x.into());
        assert_eq!(product.term.elem, g);
        assert_eq!(product.weight, three);
    }

    #[test]
    fn test_sum_scalar_multiplication_distributive() {
        let (x, y) = (scalar_var(0), scalar_var(1));
        let two = scalar(2);

        let total = x + y;
        let scaled = total * two;

        let terms = scaled.terms();
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0].term, x);
        assert_eq!(terms[0].weight, two);
        assert_eq!(terms[1].term, y);
        assert_eq!(terms[1].weight, two);
    }

    #[test]
    fn test_sum_subtraction_distributive() {
        let (x, y, z) = (scalar_var(0), scalar_var(1), scalar_var(2));

        let total = x + y;
        let difference = total - z;

        let terms = difference.terms();
        assert_eq!(terms.len(), 3);
        assert_eq!(terms[0].term, x);
        assert_eq!(terms[0].weight, Scalar::ONE);
        assert_eq!(terms[1].term, y);
        assert_eq!(terms[1].weight, Scalar::ONE);
        assert_eq!(terms[2].term, z);
        assert_eq!(terms[2].weight, -Scalar::ONE);
    }

    #[test]
    fn test_weighted_sum_scalar_multiplication() {
        let (x, y) = (scalar_var(0), scalar_var(1));

        let left = x * scalar(2);
        let right = y * scalar(3);
        let total = left + right;
        let scaled = total * scalar(4);

        let terms = scaled.terms();
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0].term, x);
        assert_eq!(terms[0].weight, scalar(8));
        assert_eq!(terms[1].term, y);
        assert_eq!(terms[1].weight, scalar(12));
    }

    #[test]
    fn test_pedersen_commitment_expression() {
        let (x, r) = (scalar_var(0), scalar_var(1));
        let (g, h) = (group_var(0), group_var(1));

        let commitment = x * g + r * h;

        let terms = commitment.terms();
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0].scalar, x.into());
        assert_eq!(terms[0].elem, g);
        assert_eq!(terms[1].scalar, r.into());
        assert_eq!(terms[1].elem, h);
    }

    #[test]
    fn test_weighted_pedersen_commitment() {
        let (x, r) = (scalar_var(0), scalar_var(1));
        let (g, h) = (group_var(0), group_var(1));
        let (three, two) = (scalar(3), scalar(2));

        let commitment = x * g * three + r * h * two;

        let terms = commitment.terms();
        assert_eq!(terms.len(), 2);
        assert_eq!(terms[0].term.scalar, x.into());
        assert_eq!(terms[0].term.elem, g);
        assert_eq!(terms[0].weight, three);
        assert_eq!(terms[1].term.scalar, r.into());
        assert_eq!(terms[1].term.elem, h);
        assert_eq!(terms[1].weight, two);
    }

    #[test]
    fn test_complex_multi_term_expression() {
        let scalars = [scalar_var(0), scalar_var(1), scalar_var(2), scalar_var(3)];
        let groups = [group_var(0), group_var(1), group_var(2), group_var(3)];

        let expr = scalars[0] * groups[0] + scalars[1] * groups[1] + scalars[2] * groups[2]
            - scalars[3] * groups[3];

        let terms = expr.terms();
        assert_eq!(terms.len(), 4);

        // The three added terms keep their order and carry weight one.
        let added = scalars.iter().zip(groups.iter()).enumerate().take(3);
        for (idx, (&scalar_var, &group_var)) in added {
            assert_eq!(terms[idx].term.scalar, scalar_var.into());
            assert_eq!(terms[idx].term.elem, group_var);
            assert_eq!(terms[idx].weight, Scalar::ONE);
        }

        // The subtracted term comes last and carries weight minus one.
        let subtracted = &terms[3];
        assert_eq!(subtracted.term.scalar, scalars[3].into());
        assert_eq!(subtracted.term.elem, groups[3]);
        assert_eq!(subtracted.weight, -Scalar::ONE);
    }

    #[test]
    fn test_chained_addition_with_coefficients() {
        let (x, y, z) = (scalar_var(0), scalar_var(1), scalar_var(2));
        let (g, h, k) = (group_var(0), group_var(1), group_var(2));

        let expr = x * g * scalar(2) + y * h * scalar(3) + z * k * scalar(5);

        let terms = expr.terms();
        assert_eq!(terms.len(), 3);

        // One row per expected entry: (scalar variable, group variable, coefficient).
        let expected = [(x, g, 2u64), (y, h, 3u64), (z, k, 5u64)];
        for (idx, &(scalar_var, group_var, coeff)) in expected.iter().enumerate() {
            assert_eq!(terms[idx].term.scalar, scalar_var.into());
            assert_eq!(terms[idx].term.elem, group_var);
            assert_eq!(terms[idx].weight, Scalar::from(coeff));
        }
    }

    #[test]
    fn test_mixing_sum_term_and_sum_weighted() {
        let (x, y, z) = (scalar_var(0), scalar_var(1), scalar_var(2));
        let (g, h, k) = (group_var(0), group_var(1), group_var(2));
        let three = scalar(3);

        let plain_sum = x * g + y * h; // Sum<Term>
        let scaled_term = z * k * three; // Weighted<Term>
        let mixed = plain_sum + scaled_term;

        let terms = mixed.terms();
        assert_eq!(terms.len(), 3);
        assert_eq!(terms[0].term.scalar, x.into());
        assert_eq!(terms[0].term.elem, g);
        assert_eq!(terms[0].weight, Scalar::ONE);
        assert_eq!(terms[1].term.scalar, y.into());
        assert_eq!(terms[1].term.elem, h);
        assert_eq!(terms[1].weight, Scalar::ONE);
        assert_eq!(terms[2].term.scalar, z.into());
        assert_eq!(terms[2].term.elem, k);
        assert_eq!(terms[2].weight, three);
    }
}
1018}