// hacspec_lib/hacspec_helper.rs
use std::convert::TryInto;

pub use natmod::nat_mod;

6pub trait NatMod<const LEN: usize> {
7 const MODULUS: [u8; LEN];
8 const MODULUS_STR: &'static str;
9 const ZERO: [u8; LEN];
10
11 fn new(value: [u8; LEN]) -> Self;
12 fn value(&self) -> &[u8];
13
14 fn fsub(self, rhs: Self) -> Self
16 where
17 Self: Sized,
18 {
19 let lhs = num_bigint::BigUint::from_bytes_be(self.value());
20 let rhs = num_bigint::BigUint::from_bytes_be(rhs.value());
21 let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
22 let res = if lhs < rhs {
23 modulus.clone() + lhs - rhs
24 } else {
25 lhs - rhs
26 };
27 let res = res % modulus;
28 Self::from_bigint(res)
29 }
30
31 fn fadd(self, rhs: Self) -> Self
33 where
34 Self: Sized,
35 {
36 let lhs = num_bigint::BigUint::from_bytes_be(self.value());
37 let rhs = num_bigint::BigUint::from_bytes_be(rhs.value());
38 let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
39 let res = (lhs + rhs) % modulus;
40 Self::from_bigint(res)
41 }
42
43 fn fmul(self, rhs: Self) -> Self
45 where
46 Self: Sized,
47 {
48 let lhs = num_bigint::BigUint::from_bytes_be(self.value());
49 let rhs = num_bigint::BigUint::from_bytes_be(rhs.value());
50 let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
51 let res = (lhs * rhs) % modulus;
52 Self::from_bigint(res)
53 }
54
55 fn pow(self, rhs: u128) -> Self
57 where
58 Self: Sized,
59 {
60 let lhs = num_bigint::BigUint::from_bytes_be(self.value());
61 let rhs = num_bigint::BigUint::from(rhs);
62 let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
63 let res = lhs.modpow(&rhs, &modulus);
64 Self::from_bigint(res)
65 }
66
67 fn pow_felem(self, rhs: &Self) -> Self
69 where
70 Self: Sized,
71 {
72 let lhs = num_bigint::BigUint::from_bytes_be(self.value());
73 let rhs = num_bigint::BigUint::from_bytes_be(rhs.value());
74 let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
75 let res = lhs.modpow(&rhs, &modulus);
76 Self::from_bigint(res)
77 }
78
79 fn inv(self) -> Self
81 where
82 Self: Sized,
83 {
84 let val = num_bigint::BigUint::from_bytes_be(self.value());
85 let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
86 let m = &modulus - num_bigint::BigUint::from(2u32);
87 Self::from_bigint(val.modpow(&m, &modulus))
88 }
89
90 fn inv0(self) -> Self
91 where
92 Self: Sized,
93 {
94 if self.value() == Self::zero().value() {
95 Self::zero()
96 } else {
97 self.inv()
98 }
99 }
100
101 fn zero() -> Self
103 where
104 Self: Sized,
105 {
106 Self::new(Self::ZERO)
107 }
108
109 fn one() -> Self
111 where
112 Self: Sized,
113 {
114 let out = Self::new(Self::ZERO);
115 out.fadd(Self::from_u128(1))
116 }
117
118 fn two() -> Self
120 where
121 Self: Sized,
122 {
123 let out = Self::new(Self::ZERO);
124 out.fadd(Self::from_u128(2))
125 }
126
127 fn bit(&self, bit: u128) -> bool {
128 let val = num_bigint::BigUint::from_bytes_be(self.value());
129 val.bit(bit.try_into().unwrap())
130 }
131
132 fn pow2(x: usize) -> Self
134 where
135 Self: Sized,
136 {
137 let res = num_bigint::BigUint::from(1u32) << x;
138 Self::from_bigint(res)
139 }
140
141 fn neg(self) -> Self
142 where
143 Self: Sized,
144 {
145 Self::zero().fsub(self)
146 }
147
148 fn from_u128(literal: u128) -> Self
150 where
151 Self: Sized,
152 {
153 Self::from_bigint(num_bigint::BigUint::from(literal))
154 }
155
156 fn from_le_bytes(bytes: &[u8]) -> Self
160 where
161 Self: Sized,
162 {
163 let value = num_bigint::BigUint::from_bytes_le(bytes);
164 let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
165 Self::from_bigint(value % modulus)
166 }
167
168 fn from_be_bytes(bytes: &[u8]) -> Self
172 where
173 Self: Sized,
174 {
175 let value = num_bigint::BigUint::from_bytes_be(bytes);
176 let modulus = num_bigint::BigUint::from_bytes_be(&Self::MODULUS);
177 Self::from_bigint(value % modulus)
178 }
179
180 fn to_le_bytes(self) -> [u8; LEN]
181 where
182 Self: Sized,
183 {
184 Self::pad(&num_bigint::BigUint::from_bytes_be(self.value()).to_bytes_le())
185 }
186
187 fn to_be_bytes(self) -> [u8; LEN]
188 where
189 Self: Sized,
190 {
191 self.value().try_into().unwrap()
192 }
193
194 fn to_hex(&self) -> String {
196 let strs: Vec<String> = self.value().iter().map(|b| format!("{:02x}", b)).collect();
197 strs.join("")
198 }
199
200 fn from_hex(hex: &str) -> Self
202 where
203 Self: Sized,
204 {
205 debug_assert!(hex.len() % 2 == 0);
206 let l = hex.len() / 2;
207 let mut value = vec![0u8; l];
208 for i in 0..l {
209 value[i] = u8::from_str_radix(&hex[2 * i..2 * i + 2], 16)
210 .expect("An unexpected error occurred.");
211 }
212
213 Self::from_be_bytes(&value)
214 }
215
216 fn pad(bytes: &[u8]) -> [u8; LEN] {
217 let mut value = [0u8; LEN];
218 let upper = value.len();
219 let lower = upper - bytes.len();
220 value[lower..upper].copy_from_slice(bytes);
221 value
222 }
223
224 fn from_bigint(x: num_bigint::BigUint) -> Self
225 where
226 Self: Sized,
227 {
228 let max_value = Self::MODULUS;
229 debug_assert!(
230 x <= num_bigint::BigUint::from_bytes_be(&max_value),
231 "{} is too large for type {}!",
232 x,
233 stringify!($ident)
234 );
235 let repr = x.to_bytes_be();
236 if repr.len() > LEN {
237 panic!("{} is too large for this type", x)
238 }
239
240 Self::new(Self::pad(&repr))
241 }
242}
243
/// Alias so that code written against a `U8` integer type compiles with a
/// plain `u8`.
pub type U8 = u8;

/// Identity "constructor" that pairs with the [`U8`] alias, so `U8(x)` call
/// syntax works on plain `u8` values.
///
/// Declared `const` so it can also be used in constant expressions.
// `non_snake_case`: the function intentionally shares the alias's name.
// `dead_code`: kept from the original; NOTE(review): a `pub` fn does not
// normally trigger this lint — confirm whether the allow is still needed.
#[allow(dead_code, non_snake_case)]
#[inline]
pub const fn U8(x: u8) -> u8 {
    x
}

/// Re-exported so that callers glob-importing this module get
/// `std::io::Write` in scope.
///
/// NOTE(review): `create_test_vectors!` below calls `Write::write_all` by
/// path and no longer needs this; it is kept for backward compatibility.
pub use std::io::Write;

/// Declares a serde-(de)serializable test-vector struct with JSON file
/// helpers.
///
/// `create_test_vectors!(Name, field: Type, ...)` expands to
/// `struct Name { field: Type, ... }`, deriving `Serialize`, `Deserialize`,
/// `Debug` and `Clone`. It also generates these associated functions:
///
/// * `from_file::<T>(path)` — parse the JSON file at `path` into any `T`;
/// * `write_file(&self, path)` — write `self` to `path` as pretty-printed JSON;
/// * `new_array(path)` — parse the JSON file at `path` as a `Vec<Name>`.
///
/// The expansion names `serde` and `serde_json` by path, so the calling crate
/// must depend on both. Every helper panics on I/O or (de)serialization
/// errors.
#[macro_export]
macro_rules! create_test_vectors {
    ($struct_name: ident, $($element: ident: $ty: ty),+) => {
        #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
        #[allow(non_snake_case)]
        struct $struct_name { $($element: $ty),+ }
        impl $struct_name {
            /// Parses the JSON file at `file` into a `T`.
            #[cfg_attr(feature="use_attributes", not_hacspec)]
            pub fn from_file<T: serde::de::DeserializeOwned>(file: &str) -> T {
                let reader = match std::fs::File::open(file) {
                    Ok(handle) => std::io::BufReader::new(handle),
                    Err(_) => panic!("Couldn't open file {}.", file),
                };
                match serde_json::from_reader(reader) {
                    Ok(r) => r,
                    Err(e) => {
                        println!("{:?}", e);
                        panic!("Error reading file.")
                    },
                }
            }

            /// Writes `self` to `file` as pretty-printed JSON,
            /// creating or truncating the file.
            #[cfg_attr(feature="use_attributes", not_hacspec)]
            pub fn write_file(&self, file: &str) {
                let mut handle = match std::fs::File::create(file) {
                    Ok(f) => f,
                    Err(_) => panic!("Couldn't open file {}.", file),
                };
                let json = match serde_json::to_string_pretty(self) {
                    Ok(j) => j,
                    Err(_) => panic!("Couldn't serialize this object."),
                };
                // Called by path: `macro_rules!` does not bring `Write` into
                // scope at the expansion site, so `handle.write_all(..)`
                // would only compile if the caller had imported the trait.
                match std::io::Write::write_all(&mut handle, json.as_bytes()) {
                    Ok(_) => (),
                    Err(_) => panic!("Error writing to file."),
                }
            }

            /// Parses the JSON file at `file` as a list of these vectors.
            #[cfg_attr(feature="use_attributes", not_hacspec)]
            pub fn new_array(file: &str) -> Vec<Self> {
                Self::from_file::<Vec<Self>>(file)
            }
        }
    };
}