1use core::ops::{Add, Mul, Neg, Sub};
2use ff::Field;
3use group::Group;
4
5use super::{GroupVar, ScalarTerm, ScalarVar, Sum, Term, Weighted};
6
7mod add {
8 use super::*;
9
10 macro_rules! impl_add_term {
11 ($($type:ty),+) => {
12 $(
13 impl<G> Add<$type> for $type {
14 type Output = Sum<$type>;
15
16 fn add(self, rhs: $type) -> Self::Output {
17 Sum(vec![self, rhs])
18 }
19 }
20 )+
21 };
22 }
23
24 impl_add_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
25
26 impl<T> Add<T> for Sum<T> {
27 type Output = Sum<T>;
28
29 fn add(mut self, rhs: T) -> Self::Output {
30 self.0.push(rhs);
31 self
32 }
33 }
34
35 macro_rules! impl_add_sum_term {
36 ($($type:ty),+) => {
37 $(
38 impl<G> Add<Sum<$type>> for $type {
39 type Output = Sum<$type>;
40
41 fn add(self, rhs: Sum<$type>) -> Self::Output {
42 rhs + self
43 }
44 }
45 )+
46 };
47 }
48
49 impl_add_sum_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
50
51 impl<T> Add<Sum<T>> for Sum<T> {
52 type Output = Sum<T>;
53
54 fn add(mut self, rhs: Sum<T>) -> Self::Output {
55 self.0.extend(rhs.0);
56 self
57 }
58 }
59
60 impl<T, F> Add<Weighted<T, F>> for Weighted<T, F> {
61 type Output = Sum<Weighted<T, F>>;
62
63 fn add(self, rhs: Weighted<T, F>) -> Self::Output {
64 Sum(vec![self, rhs])
65 }
66 }
67
68 impl<T, F: Field> Add<T> for Weighted<T, F> {
69 type Output = Sum<Weighted<T, F>>;
70
71 fn add(self, rhs: T) -> Self::Output {
72 Sum(vec![self, rhs.into()])
73 }
74 }
75
76 macro_rules! impl_add_weighted_term {
77 ($($type:ty),+) => {
78 $(
79 impl<G: Group> Add<Weighted<$type, G::Scalar>> for $type {
80 type Output = Sum<Weighted<$type, G::Scalar>>;
81
82 fn add(self, rhs: Weighted<$type, G::Scalar>) -> Self::Output {
83 rhs + self
84 }
85 }
86 )+
87 };
88 }
89
90 impl_add_weighted_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
91
92 impl<T, F: Field> Add<T> for Sum<Weighted<T, F>> {
93 type Output = Sum<Weighted<T, F>>;
94
95 fn add(mut self, rhs: T) -> Self::Output {
96 self.0.push(rhs.into());
97 self
98 }
99 }
100
101 macro_rules! impl_add_weighted_sum_term {
102 ($($type:ty),+) => {
103 $(
104 impl<G: Group> Add<Sum<Weighted<$type, G::Scalar>>> for $type {
105 type Output = Sum<Weighted<$type, G::Scalar>>;
106
107 fn add(self, rhs: Sum<Weighted<$type, G::Scalar>>) -> Self::Output {
108 rhs + self
109 }
110 }
111 )+
112 };
113 }
114
115 impl_add_weighted_sum_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
116
117 impl<T, F: Field> Add<Sum<T>> for Sum<Weighted<T, F>> {
118 type Output = Sum<Weighted<T, F>>;
119
120 fn add(self, rhs: Sum<T>) -> Self::Output {
121 self + Self::from(rhs)
122 }
123 }
124
125 impl<T, F: Field> Add<Sum<Weighted<T, F>>> for Sum<T> {
126 type Output = Sum<Weighted<T, F>>;
127
128 fn add(self, rhs: Sum<Weighted<T, F>>) -> Self::Output {
129 rhs + self
130 }
131 }
132
133 impl<T, F: Field> Add<Weighted<T, F>> for Sum<T> {
134 type Output = Sum<Weighted<T, F>>;
135
136 fn add(self, rhs: Weighted<T, F>) -> Self::Output {
137 Self::Output::from(self) + rhs
138 }
139 }
140
141 impl<T, F: Field> Add<Sum<T>> for Weighted<T, F> {
142 type Output = Sum<Weighted<T, F>>;
143
144 fn add(self, rhs: Sum<T>) -> Self::Output {
145 rhs + self
146 }
147 }
148
149 impl<G> Add<ScalarVar<G>> for ScalarTerm<G> {
150 type Output = Sum<ScalarTerm<G>>;
151
152 fn add(self, rhs: ScalarVar<G>) -> Self::Output {
153 self + ScalarTerm::from(rhs)
154 }
155 }
156
157 impl<G> Add<ScalarTerm<G>> for ScalarVar<G> {
158 type Output = Sum<ScalarTerm<G>>;
159
160 fn add(self, rhs: ScalarTerm<G>) -> Self::Output {
161 rhs + self
162 }
163 }
164
165 impl<T: Field + Into<G::Scalar>, G: Group> Add<T> for Weighted<ScalarTerm<G>, G::Scalar> {
166 type Output = Sum<Weighted<ScalarTerm<G>, G::Scalar>>;
167
168 fn add(self, rhs: T) -> Self::Output {
169 self + Self::from(rhs.into())
170 }
171 }
172
173 impl<T: Field + Into<G::Scalar>, G: Group> Add<T> for Weighted<ScalarVar<G>, G::Scalar> {
174 type Output = Sum<Weighted<ScalarTerm<G>, G::Scalar>>;
175
176 fn add(self, rhs: T) -> Self::Output {
177 <Weighted<ScalarTerm<G>, G::Scalar>>::from(self) + rhs.into()
178 }
179 }
180
181 impl<T: Field + Into<G::Scalar>, G: Group> Add<T> for ScalarVar<G> {
182 type Output = Sum<Weighted<ScalarTerm<G>, G::Scalar>>;
183
184 fn add(self, rhs: T) -> Self::Output {
185 Weighted::from(ScalarTerm::from(self)) + rhs.into()
186 }
187 }
188
189 impl<G: Group> Add<GroupVar<G>> for Sum<Weighted<Term<G>, G::Scalar>> {
190 type Output = Sum<Weighted<Term<G>, G::Scalar>>;
191
192 fn add(self, rhs: GroupVar<G>) -> Self::Output {
193 self + Self::from(rhs)
194 }
195 }
196
197 impl<G: Group> Add<GroupVar<G>> for Sum<Term<G>> {
198 type Output = Sum<Term<G>>;
199
200 fn add(self, rhs: GroupVar<G>) -> Self::Output {
201 self + Self::from(rhs)
202 }
203 }
204
205 impl<G: Group> Add<GroupVar<G>> for Weighted<Term<G>, G::Scalar> {
206 type Output = Sum<Weighted<Term<G>, G::Scalar>>;
207
208 fn add(self, rhs: GroupVar<G>) -> Self::Output {
209 self + Self::from(rhs)
210 }
211 }
212
213 impl<G: Group> Add<GroupVar<G>> for Term<G> {
214 type Output = Sum<Term<G>>;
215
216 fn add(self, rhs: GroupVar<G>) -> Self::Output {
217 self + Self::from(rhs)
218 }
219 }
220}
221
222mod mul {
223 use super::*;
224
225 impl<G> Mul<ScalarVar<G>> for GroupVar<G> {
226 type Output = Term<G>;
227
228 fn mul(self, rhs: ScalarVar<G>) -> Term<G> {
230 Term {
231 elem: self,
232 scalar: rhs.into(),
233 }
234 }
235 }
236
237 impl<G> Mul<GroupVar<G>> for ScalarVar<G> {
238 type Output = Term<G>;
239
240 fn mul(self, rhs: GroupVar<G>) -> Term<G> {
242 rhs * self
243 }
244 }
245
246 impl<G> Mul<ScalarTerm<G>> for GroupVar<G> {
247 type Output = Term<G>;
248
249 fn mul(self, rhs: ScalarTerm<G>) -> Term<G> {
250 Term {
251 elem: self,
252 scalar: rhs,
253 }
254 }
255 }
256
257 impl<G> Mul<GroupVar<G>> for ScalarTerm<G> {
258 type Output = Term<G>;
259
260 fn mul(self, rhs: GroupVar<G>) -> Term<G> {
261 rhs * self
262 }
263 }
264
265 impl<Rhs: Clone, Lhs: Mul<Rhs>> Mul<Rhs> for Sum<Lhs> {
266 type Output = Sum<<Lhs as Mul<Rhs>>::Output>;
267
268 fn mul(self, rhs: Rhs) -> Self::Output {
270 Sum(self.0.into_iter().map(|x| x * rhs.clone()).collect())
271 }
272 }
273
274 macro_rules! impl_scalar_mul_term {
279 ($($type:ty),+) => {
280 $(
281 impl<F: Field + Into<G::Scalar>, G: Group> Mul<F> for $type {
283 type Output = Weighted<$type, G::Scalar>;
284
285 fn mul(self, rhs: F) -> Self::Output {
286 Weighted {
287 term: self,
288 weight: rhs.into(),
289 }
290 }
291 }
292 )+
293 };
294 }
295
296 impl_scalar_mul_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
297
298 impl<T, F: Field> Mul<F> for Weighted<T, F> {
299 type Output = Weighted<T, F>;
300
301 fn mul(self, rhs: F) -> Self::Output {
302 Weighted {
303 term: self.term,
304 weight: self.weight * rhs,
305 }
306 }
307 }
308
309 impl<G: Group> Mul<ScalarVar<G>> for Weighted<GroupVar<G>, G::Scalar> {
310 type Output = Weighted<Term<G>, G::Scalar>;
311
312 fn mul(self, rhs: ScalarVar<G>) -> Self::Output {
313 Weighted {
314 term: self.term * rhs,
315 weight: self.weight,
316 }
317 }
318 }
319
320 impl<G: Group> Mul<Weighted<GroupVar<G>, G::Scalar>> for ScalarVar<G> {
321 type Output = Weighted<Term<G>, G::Scalar>;
322
323 fn mul(self, rhs: Weighted<GroupVar<G>, G::Scalar>) -> Self::Output {
324 rhs * self
325 }
326 }
327
328 impl<G: Group> Mul<GroupVar<G>> for Weighted<ScalarVar<G>, G::Scalar> {
329 type Output = Weighted<Term<G>, G::Scalar>;
330
331 fn mul(self, rhs: GroupVar<G>) -> Self::Output {
332 Weighted {
333 term: self.term * rhs,
334 weight: self.weight,
335 }
336 }
337 }
338
339 impl<G: Group> Mul<Weighted<ScalarVar<G>, G::Scalar>> for GroupVar<G> {
340 type Output = Weighted<Term<G>, G::Scalar>;
341
342 fn mul(self, rhs: Weighted<ScalarVar<G>, G::Scalar>) -> Self::Output {
343 rhs * self
344 }
345 }
346
347 impl<G: Group> Mul<ScalarTerm<G>> for Weighted<GroupVar<G>, G::Scalar> {
348 type Output = Weighted<Term<G>, G::Scalar>;
349
350 fn mul(self, rhs: ScalarTerm<G>) -> Self::Output {
351 Weighted {
352 term: self.term * rhs,
353 weight: self.weight,
354 }
355 }
356 }
357
358 impl<G: Group> Mul<Weighted<GroupVar<G>, G::Scalar>> for ScalarTerm<G> {
359 type Output = Weighted<Term<G>, G::Scalar>;
360
361 fn mul(self, rhs: Weighted<GroupVar<G>, G::Scalar>) -> Self::Output {
362 rhs * self
363 }
364 }
365
366 impl<G: Group> Mul<GroupVar<G>> for Weighted<ScalarTerm<G>, G::Scalar> {
367 type Output = Weighted<Term<G>, G::Scalar>;
368
369 fn mul(self, rhs: GroupVar<G>) -> Self::Output {
370 Weighted {
371 term: self.term * rhs,
372 weight: self.weight,
373 }
374 }
375 }
376
377 impl<G: Group> Mul<Weighted<ScalarTerm<G>, G::Scalar>> for GroupVar<G> {
378 type Output = Weighted<Term<G>, G::Scalar>;
379
380 fn mul(self, rhs: Weighted<ScalarTerm<G>, G::Scalar>) -> Self::Output {
381 rhs * self
382 }
383 }
384}
385
386mod neg {
387 use super::*;
388
389 impl<T: Neg> Neg for Sum<T> {
390 type Output = Sum<<T as Neg>::Output>;
391
392 fn neg(self) -> Self::Output {
394 Sum(self.0.into_iter().map(|x| x.neg()).collect())
395 }
396 }
397
398 impl<T, F: Field> Neg for Weighted<T, F> {
399 type Output = Weighted<T, F>;
400
401 fn neg(self) -> Self::Output {
403 Weighted {
404 term: self.term,
405 weight: -self.weight,
406 }
407 }
408 }
409
410 macro_rules! impl_neg_term {
411 ($($type:ty),+) => {
412 $(
413 impl<G: Group> Neg for $type {
414 type Output = Weighted<$type, G::Scalar>;
415
416 fn neg(self) -> Self::Output {
417 Weighted {
418 term: self,
419 weight: -G::Scalar::ONE,
420 }
421 }
422 }
423 )+
424 };
425 }
426
427 impl_neg_term!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
428}
429
430mod sub {
431 use super::*;
432
433 impl<T, Rhs> Sub<Rhs> for Sum<T>
434 where
435 Rhs: Neg,
436 <Rhs as Neg>::Output: Add<Self>,
437 {
438 type Output = <<Rhs as Neg>::Output as Add<Self>>::Output;
439
440 #[allow(clippy::suspicious_arithmetic_impl)]
441 fn sub(self, rhs: Rhs) -> Self::Output {
442 rhs.neg() + self
443 }
444 }
445
446 impl<T, F, Rhs> Sub<Rhs> for Weighted<T, F>
447 where
448 Rhs: Neg,
449 <Rhs as Neg>::Output: Add<Self>,
450 {
451 type Output = <<Rhs as Neg>::Output as Add<Self>>::Output;
452
453 #[allow(clippy::suspicious_arithmetic_impl)]
454 fn sub(self, rhs: Rhs) -> Self::Output {
455 rhs.neg() + self
456 }
457 }
458
459 macro_rules! impl_sub_as_neg_add {
460 ($($type:ty),+) => {
461 $(
462 impl<G, Rhs> Sub<Rhs> for $type
463 where
464 Rhs: Neg,
465 <Rhs as Neg>::Output: Add<Self>,
466 {
467 type Output = <<Rhs as Neg>::Output as Add<Self>>::Output;
468
469 #[allow(clippy::suspicious_arithmetic_impl)]
470 fn sub(self, rhs: Rhs) -> Self::Output {
471 rhs.neg() + self
472 }
473 }
474 )+
475 };
476 }
477
478 impl_sub_as_neg_add!(ScalarVar<G>, ScalarTerm<G>, GroupVar<G>, Term<G>);
479}
480
/// Unit tests for the arithmetic operators, using Ristretto as the concrete group.
#[cfg(test)]
mod tests {
    use crate::linear_relation::{GroupVar, ScalarTerm, ScalarVar, Term};
    use curve25519_dalek::RistrettoPoint as G;
    use curve25519_dalek::Scalar;
    use std::marker::PhantomData;

    /// Builds the scalar variable with index `i`.
    fn scalar_var(i: usize) -> ScalarVar<G> {
        ScalarVar(i, PhantomData)
    }

    /// Builds the group variable with index `i`.
    fn group_var(i: usize) -> GroupVar<G> {
        GroupVar(i, PhantomData)
    }

    #[test]
    fn test_scalar_var_addition() {
        let x = scalar_var(0);
        let y = scalar_var(1);

        let sum = x + y;
        assert_eq!(sum.terms().len(), 2);
        assert_eq!(sum.terms()[0], x);
        assert_eq!(sum.terms()[1], y);
    }

    #[test]
    fn test_scalar_var_scalar_addition() {
        let x = scalar_var(0);

        let sum = x + Scalar::from(5u64);
        assert_eq!(sum.terms().len(), 2);
        assert_eq!(sum.terms()[0].term, x.into());
        assert_eq!(sum.terms()[0].weight, Scalar::ONE);
        assert_eq!(sum.terms()[1].term, ScalarTerm::Unit);
        assert_eq!(sum.terms()[1].weight, Scalar::from(5u64));
    }

    #[test]
    fn test_scalar_var_scalar_addition_mul_group() {
        let x = scalar_var(0);
        let g = group_var(0);

        let res = (x + Scalar::from(5u64)) * g;

        assert_eq!(res.terms().len(), 2);
        assert_eq!(
            res.terms()[0].term,
            Term {
                scalar: x.into(),
                elem: g
            }
        );
        assert_eq!(res.terms()[0].weight, Scalar::ONE);
        assert_eq!(
            res.terms()[1].term,
            Term {
                scalar: ScalarTerm::Unit,
                elem: g
            }
        );
        assert_eq!(res.terms()[1].weight, Scalar::from(5u64));
    }

    #[test]
    fn test_group_var_addition() {
        let g = group_var(0);
        let h = group_var(1);

        let sum = g + h;
        assert_eq!(sum.terms().len(), 2);
        assert_eq!(sum.terms()[0], g);
        assert_eq!(sum.terms()[1], h);
    }

    #[test]
    fn test_term_addition() {
        let x = scalar_var(0);
        let g = group_var(0);
        let y = scalar_var(1);
        let h = group_var(1);

        let term1 = Term {
            scalar: x.into(),
            elem: g,
        };
        let term2 = Term {
            scalar: y.into(),
            elem: h,
        };

        let sum = term1 + term2;
        assert_eq!(sum.terms().len(), 2);
        assert_eq!(sum.terms()[0], term1);
        assert_eq!(sum.terms()[1], term2);
    }

    #[test]
    fn test_term_group_var_addition() {
        let x = scalar_var(0);
        let g = group_var(0);

        let res = (x * g) + g;

        assert_eq!(res.terms().len(), 2);
        assert_eq!(
            res.terms()[0],
            Term {
                scalar: x.into(),
                elem: g
            }
        );
        assert_eq!(
            res.terms()[1],
            Term {
                scalar: ScalarTerm::Unit,
                elem: g
            }
        );
    }

    #[test]
    fn test_scalar_group_multiplication() {
        let x = scalar_var(0);
        let g = group_var(0);

        let term1 = x * g;
        let term2 = g * x;

        assert_eq!(term1.scalar, x.into());
        assert_eq!(term1.elem, g);
        assert_eq!(term2.scalar, x.into());
        assert_eq!(term2.elem, g);
    }

    #[test]
    fn test_scalar_coefficient_multiplication() {
        let x = scalar_var(0);
        let weighted = x * Scalar::from(5u64);

        assert_eq!(weighted.term, x);
        assert_eq!(weighted.weight, Scalar::from(5u64));
    }

    #[test]
    fn test_group_coefficient_multiplication() {
        let g = group_var(0);
        let weighted = g * Scalar::from(3u64);

        assert_eq!(weighted.term, g);
        assert_eq!(weighted.weight, Scalar::from(3u64));
    }

    #[test]
    fn test_term_coefficient_multiplication() {
        let x = scalar_var(0);
        let g = group_var(0);
        let term = Term {
            scalar: x.into(),
            elem: g,
        };
        let weighted = term * Scalar::from(7u64);

        assert_eq!(weighted.term, term);
        assert_eq!(weighted.weight, Scalar::from(7u64));
    }

    #[test]
    fn test_scalar_var_negation() {
        let x = scalar_var(0);
        let neg_x = -x;

        assert_eq!(neg_x.term, x);
        assert_eq!(neg_x.weight, -Scalar::ONE);
    }

    #[test]
    fn test_group_var_negation() {
        let g = group_var(0);
        let neg_g = -g;

        assert_eq!(neg_g.term, g);
        assert_eq!(neg_g.weight, -Scalar::ONE);
    }

    #[test]
    fn test_term_negation() {
        let x = scalar_var(0);
        let g = group_var(0);
        let term = Term {
            scalar: x.into(),
            elem: g,
        };
        let neg_term = -term;

        assert_eq!(neg_term.term, term);
        assert_eq!(neg_term.weight, -Scalar::ONE);
    }

    #[test]
    fn test_weighted_negation() {
        let x = scalar_var(0);
        let weighted = x * Scalar::from(5u64);
        let neg_weighted = -weighted;

        assert_eq!(neg_weighted.term, x);
        assert_eq!(neg_weighted.weight, -Scalar::from(5u64));
    }

    #[test]
    fn test_scalar_var_subtraction() {
        let x = scalar_var(0);
        let y = scalar_var(1);

        let diff = x - y;
        assert_eq!(diff.terms().len(), 2);
        assert_eq!(diff.terms()[0].term, y);
        assert_eq!(diff.terms()[0].weight, -Scalar::ONE);
        assert_eq!(diff.terms()[1].term, x);
        assert_eq!(diff.terms()[1].weight, Scalar::ONE);
    }

    #[test]
    fn test_group_var_subtraction() {
        let g = group_var(0);
        let h = group_var(1);

        let diff = g - h;
        assert_eq!(diff.terms().len(), 2);
        assert_eq!(diff.terms()[0].term, h);
        assert_eq!(diff.terms()[0].weight, -Scalar::ONE);
        assert_eq!(diff.terms()[1].term, g);
        assert_eq!(diff.terms()[1].weight, Scalar::ONE);
    }

    #[test]
    fn test_term_subtraction() {
        let x = scalar_var(0);
        let g = group_var(0);
        let y = scalar_var(1);
        let h = group_var(1);

        let term1 = Term {
            scalar: x.into(),
            elem: g,
        };
        let term2 = Term {
            scalar: y.into(),
            elem: h,
        };

        let diff = term1 - term2;
        assert_eq!(diff.terms().len(), 2);
        assert_eq!(diff.terms()[0].term, term2);
        assert_eq!(diff.terms()[0].weight, -Scalar::ONE);
        assert_eq!(diff.terms()[1].term, term1);
        assert_eq!(diff.terms()[1].weight, Scalar::ONE);
    }

    #[test]
    fn test_sum_addition_chaining() {
        let x = scalar_var(0);
        let y = scalar_var(1);
        let z = scalar_var(2);

        let sum = x + y + z;
        assert_eq!(sum.terms().len(), 3);
        assert_eq!(sum.terms()[0], x);
        assert_eq!(sum.terms()[1], y);
        assert_eq!(sum.terms()[2], z);
    }

    #[test]
    fn test_sum_plus_scalar_var() {
        let x = scalar_var(0);
        let y = scalar_var(1);
        let z = scalar_var(2);

        let sum = x + y;
        let result = z + sum;
        assert_eq!(result.terms().len(), 3);
        assert_eq!(result.terms()[0], x);
        assert_eq!(result.terms()[1], y);
        assert_eq!(result.terms()[2], z);
    }

    #[test]
    fn test_sum_plus_sum() {
        let x = scalar_var(0);
        let y = scalar_var(1);
        let z = scalar_var(2);
        let w = scalar_var(3);

        let sum1 = x + y;
        let sum2 = z + w;
        let result = sum1 + sum2;

        assert_eq!(result.terms().len(), 4);
        assert_eq!(result.terms()[0], x);
        assert_eq!(result.terms()[1], y);
        assert_eq!(result.terms()[2], z);
        assert_eq!(result.terms()[3], w);
    }

    #[test]
    fn test_sum_negation() {
        let x = scalar_var(0);
        let y = scalar_var(1);

        let sum = x + y;
        let neg_sum = -sum;

        assert_eq!(neg_sum.terms().len(), 2);
        assert_eq!(neg_sum.terms()[0].term, x);
        assert_eq!(neg_sum.terms()[0].weight, -Scalar::ONE);
        assert_eq!(neg_sum.terms()[1].term, y);
        assert_eq!(neg_sum.terms()[1].weight, -Scalar::ONE);
    }

    #[test]
    fn test_weighted_addition() {
        let x = scalar_var(0);
        let y = scalar_var(1);

        let weighted1 = x * Scalar::from(3u64);
        let weighted2 = y * Scalar::from(5u64);
        let sum = weighted1 + weighted2;

        assert_eq!(sum.terms().len(), 2);
        assert_eq!(sum.terms()[0].term, x);
        assert_eq!(sum.terms()[0].weight, Scalar::from(3u64));
        assert_eq!(sum.terms()[1].term, y);
        assert_eq!(sum.terms()[1].weight, Scalar::from(5u64));
    }

    #[test]
    fn test_weighted_plus_term() {
        let x = scalar_var(0);
        let y = scalar_var(1);

        let weighted = x * Scalar::from(2u64);
        let sum = weighted + y;

        assert_eq!(sum.terms().len(), 2);
        assert_eq!(sum.terms()[0].term, x);
        assert_eq!(sum.terms()[0].weight, Scalar::from(2u64));
        assert_eq!(sum.terms()[1].term, y);
        assert_eq!(sum.terms()[1].weight, Scalar::ONE);
    }

    #[test]
    fn test_weighted_scalar_multiplication() {
        let x = scalar_var(0);
        let weighted = x * Scalar::from(2u64);
        let result = weighted * Scalar::from(3u64);

        assert_eq!(result.term, x);
        assert_eq!(result.weight, Scalar::from(6u64));
    }

    #[test]
    fn test_weighted_group_var_times_scalar_var() {
        let x = scalar_var(0);
        let g = group_var(0);

        let weighted_g = g * Scalar::from(5u64);
        let result = x * weighted_g;

        assert_eq!(result.term.scalar, x.into());
        assert_eq!(result.term.elem, g);
        assert_eq!(result.weight, Scalar::from(5u64));
    }

    #[test]
    fn test_weighted_scalar_var_times_group_var() {
        let x = scalar_var(0);
        let g = group_var(0);

        let weighted_x = x * Scalar::from(3u64);
        let result = weighted_x * g;

        assert_eq!(result.term.scalar, x.into());
        assert_eq!(result.term.elem, g);
        assert_eq!(result.weight, Scalar::from(3u64));
    }

    #[test]
    fn test_sum_scalar_multiplication_distributive() {
        let x = scalar_var(0);
        let y = scalar_var(1);

        let sum = x + y;
        let result = sum * Scalar::from(2u64);

        assert_eq!(result.terms().len(), 2);
        assert_eq!(result.terms()[0].term, x);
        assert_eq!(result.terms()[0].weight, Scalar::from(2u64));
        assert_eq!(result.terms()[1].term, y);
        assert_eq!(result.terms()[1].weight, Scalar::from(2u64));
    }

    #[test]
    fn test_sum_subtraction_distributive() {
        let x = scalar_var(0);
        let y = scalar_var(1);
        let z = scalar_var(2);

        let sum1 = x + y;
        let result = sum1 - z;

        assert_eq!(result.terms().len(), 3);
        assert_eq!(result.terms()[0].term, x);
        assert_eq!(result.terms()[0].weight, Scalar::ONE);
        assert_eq!(result.terms()[1].term, y);
        assert_eq!(result.terms()[1].weight, Scalar::ONE);
        assert_eq!(result.terms()[2].term, z);
        assert_eq!(result.terms()[2].weight, -Scalar::ONE);
    }

    #[test]
    fn test_weighted_sum_scalar_multiplication() {
        let x = scalar_var(0);
        let y = scalar_var(1);

        let weighted1 = x * Scalar::from(2u64);
        let weighted2 = y * Scalar::from(3u64);
        let sum = weighted1 + weighted2;
        let result = sum * Scalar::from(4u64);

        assert_eq!(result.terms().len(), 2);
        assert_eq!(result.terms()[0].term, x);
        assert_eq!(result.terms()[0].weight, Scalar::from(8u64));
        assert_eq!(result.terms()[1].term, y);
        assert_eq!(result.terms()[1].weight, Scalar::from(12u64));
    }

    #[test]
    fn test_pedersen_commitment_expression() {
        let x = scalar_var(0);
        let r = scalar_var(1);
        let g = group_var(0);
        let h = group_var(1);

        let commitment = x * g + r * h;
        assert_eq!(commitment.terms().len(), 2);
        assert_eq!(commitment.terms()[0].scalar, x.into());
        assert_eq!(commitment.terms()[0].elem, g);
        assert_eq!(commitment.terms()[1].scalar, r.into());
        assert_eq!(commitment.terms()[1].elem, h);
    }

    #[test]
    fn test_weighted_pedersen_commitment() {
        let x = scalar_var(0);
        let r = scalar_var(1);
        let g = group_var(0);
        let h = group_var(1);

        let commitment = x * g * Scalar::from(3u64) + r * h * Scalar::from(2u64);
        assert_eq!(commitment.terms().len(), 2);
        assert_eq!(commitment.terms()[0].term.scalar, x.into());
        assert_eq!(commitment.terms()[0].term.elem, g);
        assert_eq!(commitment.terms()[0].weight, Scalar::from(3u64));
        assert_eq!(commitment.terms()[1].term.scalar, r.into());
        assert_eq!(commitment.terms()[1].term.elem, h);
        assert_eq!(commitment.terms()[1].weight, Scalar::from(2u64));
    }

    #[test]
    fn test_complex_multi_term_expression() {
        let scalars = [scalar_var(0), scalar_var(1), scalar_var(2), scalar_var(3)];
        let groups = [group_var(0), group_var(1), group_var(2), group_var(3)];

        let expr = scalars[0] * groups[0] + scalars[1] * groups[1] + scalars[2] * groups[2]
            - scalars[3] * groups[3];

        assert_eq!(expr.terms().len(), 4);

        // The first three terms were added and keep unit weight; the last was subtracted.
        let (added, subtracted) = expr.terms().split_at(3);
        for ((entry, scalar), group) in added.iter().zip(scalars).zip(groups) {
            assert_eq!(entry.term.scalar, scalar.into());
            assert_eq!(entry.term.elem, group);
            assert_eq!(entry.weight, Scalar::ONE);
        }

        let last = &subtracted[0];
        assert_eq!(last.term.scalar, scalars[3].into());
        assert_eq!(last.term.elem, groups[3]);
        assert_eq!(last.weight, -Scalar::ONE);
    }

    #[test]
    fn test_chained_addition_with_coefficients() {
        let x = scalar_var(0);
        let y = scalar_var(1);
        let z = scalar_var(2);
        let g = group_var(0);
        let h = group_var(1);
        let k = group_var(2);

        let expr =
            x * g * Scalar::from(2u64) + y * h * Scalar::from(3u64) + z * k * Scalar::from(5u64);
        assert_eq!(expr.terms().len(), 3);

        let expected = [(x, g, 2u64), (y, h, 3u64), (z, k, 5u64)];
        for (entry, (scalar, group, coeff)) in expr.terms().iter().zip(expected) {
            assert_eq!(entry.term.scalar, scalar.into());
            assert_eq!(entry.term.elem, group);
            assert_eq!(entry.weight, Scalar::from(coeff));
        }
    }

    #[test]
    fn test_mixing_sum_term_and_sum_weighted() {
        let x = scalar_var(0);
        let y = scalar_var(1);
        let z = scalar_var(2);
        let g = group_var(0);
        let h = group_var(1);
        let k = group_var(2);

        // Sum<Term> + Weighted<Term> promotes the unweighted sum to unit weights.
        let basic_sum = x * g + y * h;
        let weighted_term = z * k * Scalar::from(3u64);
        let mixed = basic_sum + weighted_term;

        assert_eq!(mixed.terms().len(), 3);
        assert_eq!(mixed.terms()[0].term.scalar, x.into());
        assert_eq!(mixed.terms()[0].term.elem, g);
        assert_eq!(mixed.terms()[0].weight, Scalar::ONE);
        assert_eq!(mixed.terms()[1].term.scalar, y.into());
        assert_eq!(mixed.terms()[1].term.elem, h);
        assert_eq!(mixed.terms()[1].weight, Scalar::ONE);
        assert_eq!(mixed.terms()[2].term.scalar, z.into());
        assert_eq!(mixed.terms()[2].term.elem, k);
        assert_eq!(mixed.terms()[2].weight, Scalar::from(3u64));
    }
}