hacspec_lib/
hacspec_helper.rs

1use std::convert::TryInto;
2
3/// This has to come from the lib.
4pub use natmod::nat_mod;
5
6pub trait NatMod<const LEN: usize> {
7    const MODULUS: [u8; LEN];
8    const MODULUS_STR: &'static str;
9    const ZERO: [u8; LEN];
10
11    fn new(value: [u8; LEN]) -> Self;
12    fn value(&self) -> &[u8];
13
14    /// Sub self with `rhs` and return the result `self - rhs % MODULUS`.
15    fn fsub(self, rhs: Self) -> Self
16    where
17        Self: Sized,
18    {
19        let lhs = num_bigint::BigUint::from_bytes_be(self.value());
20        let rhs = num_bigint::BigUint::from_bytes_be(rhs.value());
21        let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
22        let res = if lhs < rhs {
23            modulus.clone() + lhs - rhs
24        } else {
25            lhs - rhs
26        };
27        let res = res % modulus;
28        Self::from_bigint(res)
29    }
30
31    /// Add self with `rhs` and return the result `self + rhs % MODULUS`.
32    fn fadd(self, rhs: Self) -> Self
33    where
34        Self: Sized,
35    {
36        let lhs = num_bigint::BigUint::from_bytes_be(self.value());
37        let rhs = num_bigint::BigUint::from_bytes_be(rhs.value());
38        let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
39        let res = (lhs + rhs) % modulus;
40        Self::from_bigint(res)
41    }
42
43    /// Multiply self with `rhs` and return the result `self * rhs % MODULUS`.
44    fn fmul(self, rhs: Self) -> Self
45    where
46        Self: Sized,
47    {
48        let lhs = num_bigint::BigUint::from_bytes_be(self.value());
49        let rhs = num_bigint::BigUint::from_bytes_be(rhs.value());
50        let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
51        let res = (lhs * rhs) % modulus;
52        Self::from_bigint(res)
53    }
54
55    /// `self ^ rhs % MODULUS`.
56    fn pow(self, rhs: u128) -> Self
57    where
58        Self: Sized,
59    {
60        let lhs = num_bigint::BigUint::from_bytes_be(self.value());
61        let rhs = num_bigint::BigUint::from(rhs);
62        let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
63        let res = lhs.modpow(&rhs, &modulus);
64        Self::from_bigint(res)
65    }
66
67    /// `self ^ rhs % MODULUS`.
68    fn pow_felem(self, rhs: &Self) -> Self
69    where
70        Self: Sized,
71    {
72        let lhs = num_bigint::BigUint::from_bytes_be(self.value());
73        let rhs = num_bigint::BigUint::from_bytes_be(rhs.value());
74        let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
75        let res = lhs.modpow(&rhs, &modulus);
76        Self::from_bigint(res)
77    }
78
79    /// Invert self and return the result `self ^ -1 % MODULUS`.
80    fn inv(self) -> Self
81    where
82        Self: Sized,
83    {
84        let val = num_bigint::BigUint::from_bytes_be(self.value());
85        let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
86        let m = &modulus - num_bigint::BigUint::from(2u32);
87        Self::from_bigint(val.modpow(&m, &modulus))
88    }
89
90    fn inv0(self) -> Self
91    where
92        Self: Sized,
93    {
94        if self.value() == Self::zero().value() {
95            Self::zero()
96        } else {
97            self.inv()
98        }
99    }
100
101    /// Zero element
102    fn zero() -> Self
103    where
104        Self: Sized,
105    {
106        Self::new(Self::ZERO)
107    }
108
109    /// One element
110    fn one() -> Self
111    where
112        Self: Sized,
113    {
114        let out = Self::new(Self::ZERO);
115        out.fadd(Self::from_u128(1))
116    }
117
118    /// One element
119    fn two() -> Self
120    where
121        Self: Sized,
122    {
123        let out = Self::new(Self::ZERO);
124        out.fadd(Self::from_u128(2))
125    }
126
127    fn bit(&self, bit: u128) -> bool {
128        let val = num_bigint::BigUint::from_bytes_be(self.value());
129        val.bit(bit.try_into().unwrap())
130    }
131
132    /// Returns 2 to the power of the argument
133    fn pow2(x: usize) -> Self
134    where
135        Self: Sized,
136    {
137        let res = num_bigint::BigUint::from(1u32) << x;
138        Self::from_bigint(res)
139    }
140
141    fn neg(self) -> Self
142    where
143        Self: Sized,
144    {
145        Self::zero().fsub(self)
146    }
147
148    /// Create a new [`#ident`] from a `u128` literal.
149    fn from_u128(literal: u128) -> Self
150    where
151        Self: Sized,
152    {
153        Self::from_bigint(num_bigint::BigUint::from(literal))
154    }
155
156    /// Create a new [`#ident`] from a little endian byte slice.
157    ///
158    /// This computes bytes % MODULUS
159    fn from_le_bytes(bytes: &[u8]) -> Self
160    where
161        Self: Sized,
162    {
163        let value = num_bigint::BigUint::from_bytes_le(bytes);
164        let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
165        Self::from_bigint(value % modulus)
166    }
167
168    /// Create a new [`#ident`] from a little endian byte slice.
169    ///
170    /// This computes bytes % MODULUS
171    fn from_be_bytes(bytes: &[u8]) -> Self
172    where
173        Self: Sized,
174    {
175        let value = num_bigint::BigUint::from_bytes_be(bytes);
176        let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
177        Self::from_bigint(value % modulus)
178    }
179
180    fn to_le_bytes(self) -> [u8; LEN]
181    where
182        Self: Sized,
183    {
184        Self::pad(&num_bigint::BigUint::from_bytes_be(self.value()).to_bytes_le())
185    }
186
187    fn to_be_bytes(self) -> [u8; LEN]
188    where
189        Self: Sized,
190    {
191        self.value().try_into().unwrap()
192    }
193
194    /// Get hex string representation of this.
195    fn to_hex(&self) -> String {
196        let strs: Vec<String> = self.value().iter().map(|b| format!("{:02x}", b)).collect();
197        strs.join("")
198    }
199
200    /// New from hex string
201    fn from_hex(hex: &str) -> Self
202    where
203        Self: Sized,
204    {
205        debug_assert!(hex.len() % 2 == 0);
206        let l = hex.len() / 2;
207        let mut value = vec![0u8; l];
208        for i in 0..l {
209            value[i] = u8::from_str_radix(&hex[2 * i..2 * i + 2], 16)
210                .expect("An unexpected error occurred.");
211        }
212
213        Self::from_be_bytes(&value)
214    }
215
216    fn pad(bytes: &[u8]) -> [u8; LEN] {
217        let mut value = [0u8; LEN];
218        let upper = value.len();
219        let lower = upper - bytes.len();
220        value[lower..upper].copy_from_slice(bytes);
221        value
222    }
223
224    fn from_bigint(x: num_bigint::BigUint) -> Self
225    where
226        Self: Sized,
227    {
228        let max_value = Self::MODULUS;
229        debug_assert!(
230            x <= num_bigint::BigUint::from_bytes_be(&max_value),
231            "{} is too large for type {}!",
232            x,
233            stringify!($ident)
234        );
235        let repr = x.to_bytes_be();
236        if repr.len() > LEN {
237            panic!("{} is too large for this type", x)
238        }
239
240        Self::new(Self::pad(&repr))
241    }
242}
243
// === Secret Integers

/// Secret byte. Currently a plain alias for `u8`, so the type system does not
/// distinguish it from a public byte.
pub type U8 = u8;

/// Build a [`U8`] from a `u8`. With the current alias this is the identity.
// `non_snake_case`: the constructor deliberately shares the type's name.
// `dead_code`: presumably not every build that includes this file calls it.
#[allow(non_snake_case, dead_code)]
pub fn U8(x: u8) -> u8 {
    x
}
251
252// === Test vector helpers
253pub use std::io::Write;
/// Define a serde-(de)serializable test-vector struct plus JSON file helpers.
///
/// `create_test_vectors!(Name, field: Type, ...)` expands to a struct `Name`
/// deriving `serde::Serialize`, `serde::Deserialize`, `Debug` and `Clone`,
/// with an `impl Name` providing:
///
/// * `from_file::<T>(path)`: open `path` and deserialize its JSON into `T`.
/// * `write_file(&self, path)`: create/truncate `path` and write `self` as
///   pretty-printed JSON.
/// * `new_array(path)`: deserialize `path` as a JSON array of `Name`.
///
/// The expansion names `serde` and `serde_json` by path, so the invoking crate
/// must depend on both (with serde's derive support). Every helper panics on
/// I/O or (de)serialization failure.
#[macro_export]
macro_rules! create_test_vectors {
    ($struct_name: ident, $($element: ident: $ty: ty),+) => {
        #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
        #[allow(non_snake_case)]
        struct $struct_name { $($element: $ty),+ }
        impl $struct_name {
            #[cfg_attr(feature="use_attributes", not_hacspec)]
            pub fn from_file<T: serde::de::DeserializeOwned>(file: &'static str) -> T {
                let file = match std::fs::File::open(file) {
                    Ok(f) => f,
                    Err(_) => panic!("Couldn't open file {}.", file),
                };
                let reader = std::io::BufReader::new(file);
                match serde_json::from_reader(reader) {
                    Ok(r) => r,
                    Err(e) => {
                        println!("{:?}", e);
                        panic!("Error reading file.")
                    },
                }
            }
            #[cfg_attr(feature="use_attributes", not_hacspec)]
            pub fn write_file(&self, file: &'static str) {
                let mut file = match std::fs::File::create(file) {
                    Ok(f) => f,
                    Err(_) => panic!("Couldn't open file {}.", file),
                };
                let json = match serde_json::to_string_pretty(&self) {
                    Ok(j) => j,
                    Err(_) => panic!("Couldn't serialize this object."),
                };
                // Call through the trait path so the expansion does not depend
                // on `std::io::Write` being imported where the macro is invoked.
                match std::io::Write::write_all(&mut file, json.as_bytes()) {
                    Ok(_) => (),
                    Err(_) => panic!("Error writing to file."),
                }
            }
            #[cfg_attr(feature="use_attributes", not_hacspec)]
            pub fn new_array(file: &'static str) -> Vec<Self> {
                // Same open/parse logic as `from_file`, including the file
                // path in the "Couldn't open file" panic.
                Self::from_file(file)
            }
        }
    };
}