hash_to_curve/
p256_hash.rs1use crate::Error;
4use hacspec_lib::{hacspec_helper::NatMod, i2osp, FunctionalVec};
5use p256::{is_square, sgn0, sqrt, P256FieldElement, P256Point, P256Scalar};
6use sha256::hash;
7
/// Empty type named after the `P256_XMD:SHA-256_SSWU_RO_` hash-to-curve suite
/// implemented by the free functions in this module: P-256, `expand_message_xmd`
/// with SHA-256, the simplified SWU map, random-oracle encoding.
// The identifier mirrors the suite name verbatim, which is why it deliberately
// breaks the UpperCamelCase convention.
#[allow(non_camel_case_types)]
pub struct P256_XMD_SHA256_SSWU_RO {}
16
17const L: usize = 48;
19const B_IN_BYTES: usize = sha256::HASH_SIZE;
21const S_IN_BYTES: usize = 64;
23
24#[allow(non_snake_case)]
25fn expand_message(msg: &[u8], dst: &[u8], len_in_bytes: usize) -> Result<Vec<u8>, Error> {
26 let ell = (len_in_bytes + B_IN_BYTES - 1) / B_IN_BYTES;
27 if ell > 255 || len_in_bytes > u16::MAX.into() || dst.len() > 255 {
28 return Err(Error::InvalidEll);
29 }
30
31 let dst_prime = dst.concat_byte(dst.len() as u8);
32 let z_pad = vec![0u8; S_IN_BYTES];
33 let l_i_b_str = i2osp(len_in_bytes, 2);
34
35 let msg_prime = z_pad
37 .concat(msg)
38 .concat(&l_i_b_str)
39 .concat(&[0u8; 1])
40 .concat(&dst_prime);
41
42 let b_0 = hash(&msg_prime).to_vec(); let payload_1 = b_0.concat_byte(1).concat(&dst_prime);
45 let mut b_i = hash(&payload_1).to_vec(); let mut uniform_bytes = b_i.clone();
48 for i in 2..=ell {
49 let payload_i = strxor(&b_0, &b_i).concat_byte(i as u8).concat(&dst_prime);
51 b_i = hash(&payload_i).to_vec();
53 uniform_bytes.extend_from_slice(&b_i);
54 }
55 uniform_bytes.truncate(len_in_bytes);
56 Ok(uniform_bytes)
57}
58
/// Byte-wise XOR of two slices (`strxor` in RFC 9380).
///
/// The inputs are expected to have equal length. This is checked only in debug
/// builds; otherwise the result is as long as the shorter input.
fn strxor(a: &[u8], b: &[u8]) -> Vec<u8> {
    debug_assert_eq!(a.len(), b.len());
    let mut xored = Vec::new();
    for (lhs, rhs) in a.iter().zip(b) {
        xored.push(lhs ^ rhs);
    }
    xored
}
63
64pub fn hash_to_field(msg: &[u8], dst: &[u8], count: usize) -> Result<Vec<P256FieldElement>, Error> {
66 let len_in_bytes = count * L;
67 let uniform_bytes = expand_message(msg, dst, len_in_bytes)?;
68 let mut u = Vec::with_capacity(count);
69 for i in 0..count {
70 let elm_offset = L * i;
71 let tv = &uniform_bytes[elm_offset..L * (i + 1)];
72 let tv = P256FieldElement::from_be_bytes(tv);
73 u.push(tv);
74 }
75 Ok(u)
76}
77
78pub fn hash_to_scalar(msg: &[u8], dst: &[u8], count: usize) -> Result<Vec<P256Scalar>, Error> {
80 let len_in_bytes = count * L;
81 let uniform_bytes = expand_message(msg, dst, len_in_bytes)?;
82 let mut u = Vec::with_capacity(count);
83 for i in 0..count {
84 let elm_offset = L * i;
85 let tv = &uniform_bytes[elm_offset..L * (i + 1)];
86 let tv = P256Scalar::from_be_bytes(tv);
87 u.push(tv);
88 }
89 Ok(u)
90}
91
92fn map_to_curve(u: P256FieldElement) -> P256Point {
94 let a: &P256FieldElement = &P256FieldElement::from_u128(3u128).neg();
95 let b = &P256FieldElement::from_hex(
96 "5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b",
97 );
98 let z = P256FieldElement::from_u128(10u128).neg();
99 let tv1 = (z.pow(2) * u.pow(4) + z * u.pow(2)).inv0();
100 let x1 = if tv1 == P256FieldElement::zero() {
101 *b * (z * *a).inv()
102 } else {
103 (b.neg() * a.inv()) * (tv1 + P256FieldElement::from_u128(1u128))
104 };
105
106 let gx1 = x1.pow(3) + (*a) * x1 + (*b);
107 let x2 = z * u.pow(2) * x1;
108 let gx2 = x2.pow(3) + *a * x2 + *b;
109
110 let mut output = if is_square(&gx1) {
111 (x1, sqrt(&gx1))
112 } else {
113 (x2, sqrt(&gx2))
114 };
115
116 if sgn0(&u) != sgn0(&output.1) {
117 output.1 = output.1.neg();
118 }
119
120 output.into()
121}
122
123pub fn hash_to_curve(msg: &[u8], dst: &[u8]) -> Result<P256Point, Error> {
125 let u: Vec<P256FieldElement> = hash_to_field(msg, dst, 2)?;
126 let q0 = map_to_curve(u[0]);
127 let q1 = map_to_curve(u[1]);
128 let r = p256::point_add(q0, q1)?;
129 Ok(r)
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;
    use std::fs::read_to_string;
    use std::path::{Path, PathBuf};

    const ID: &str = "P256_XMD_SHA-256_SSWU_RO_";

    /// Reads and parses a JSON test-vector file, panicking if it is missing or malformed.
    pub fn load_vectors(path: &Path) -> Value {
        serde_json::from_str(&read_to_string(path).expect("File not found.")).unwrap()
    }

    /// Location of this suite's vector file: `vectors/<ID>.json`.
    fn vector_path() -> PathBuf {
        let mut path = Path::new("vectors").join(ID);
        path.set_extension("json");
        path
    }

    /// Decodes a `0x`-prefixed big-endian hex string from a vector into a field element.
    fn field_element(value: &Value) -> P256FieldElement {
        let hex_str = value.as_str().unwrap().trim_start_matches("0x");
        P256FieldElement::from_be_bytes(&hex::decode(hex_str).unwrap())
    }

    // All checks use `assert_eq!` rather than `debug_assert_eq!`: the latter is
    // compiled out under `cargo test --release`, which would make these tests
    // pass without checking anything.

    #[test]
    fn p256_xmd_sha256_sswu_ro_hash_to_field() {
        let vector_path = vector_path();
        eprintln!(" Reading {}", vector_path.display());

        let tests = load_vectors(&vector_path);
        let dst = tests["dst"].as_str().unwrap().as_bytes();

        for test_case in tests["vectors"].as_array().unwrap() {
            let msg_str = test_case["msg"].as_str().unwrap();
            let u_expected: Vec<_> = test_case["u"]
                .as_array()
                .unwrap()
                .iter()
                .map(field_element)
                .collect();

            let u_real = hash_to_field(msg_str.as_bytes(), dst, 2).unwrap();
            assert_eq!(u_real.len(), u_expected.len());
            for (i, (u_real, u_expected)) in u_real.iter().zip(&u_expected).enumerate() {
                assert_eq!(
                    u_expected.as_ref(),
                    u_real.as_ref(),
                    "u{i} did not match for {msg_str}",
                );
            }
        }
    }

    #[test]
    fn p256_xmd_sha256_sswu_ro_map_to_curve() {
        let vectors = load_vectors(&vector_path());

        for test_case in vectors["vectors"].as_array().unwrap() {
            let u = test_case["u"].as_array().unwrap();
            // Q0 = map_to_curve(u[0]), Q1 = map_to_curve(u[1]).
            for (i, q_name) in ["Q0", "Q1"].iter().enumerate() {
                let (x, y) = map_to_curve(field_element(&u[i])).into();
                let q_expected = &test_case[*q_name];
                assert_eq!(field_element(&q_expected["x"]), x, "x{i} incorrect");
                assert_eq!(field_element(&q_expected["y"]), y, "y{i} incorrect");
            }
        }
    }

    #[test]
    fn p256_xmd_sha256_sswu_ro_hash_to_curve() {
        let vectors = load_vectors(&vector_path());
        let dst = vectors["dst"].as_str().unwrap().as_bytes();

        for test_case in vectors["vectors"].as_array().unwrap() {
            let msg = test_case["msg"].as_str().unwrap().as_bytes();

            let p_expected = &test_case["P"];
            let p_x_expected = field_element(&p_expected["x"]);
            let p_y_expected = field_element(&p_expected["y"]);

            let (x, y) = hash_to_curve(msg, dst).unwrap().into();

            assert_eq!(p_x_expected.as_ref(), x.as_ref(), "x-coordinate incorrect");
            assert_eq!(p_y_expected.as_ref(), y.as_ref(), "y-coordinate incorrect");
        }
    }
}