hash_to_curve/
p256_hash.rs

//! This module implements Hash-to-Curve (RFC 9380) for NIST P-256.
2
3use crate::Error;
4use hacspec_lib::{hacspec_helper::NatMod, i2osp, FunctionalVec};
5use p256::{is_square, sgn0, sqrt, P256FieldElement, P256Point, P256Scalar};
6use sha256::hash;
7
/// # 8.2 Suites for NIST P-256
///
/// Marker type for the hash-to-curve suite `P256_XMD:SHA-256_SSWU_RO_`: `expand_message_xmd`
/// with SHA-256, followed by the simplified SWU map, using the `hash_to_curve` encoding
/// (see [`hash_to_curve`] in this module).
///
/// [`P256_XMD:SHA-256_SSWU_NU_`](P256_XMD_SHA256_SSWU_NU) is identical to `P256_XMD:SHA-256_SSWU_RO_`,
/// except that the encoding type is encode_to_curve (Section 3).
// The type name mirrors the suite identifier above (with `:` and `-` removed), hence the
// non-camel-case name.
#[allow(non_camel_case_types)]
pub struct P256_XMD_SHA256_SSWU_RO {}
16
17/// bytes to generate per field element in `expand_message`
18const L: usize = 48;
19/// Output size of H = SHA-256 in bytes
20const B_IN_BYTES: usize = sha256::HASH_SIZE;
21/// Input block size of H = SHA-256
22const S_IN_BYTES: usize = 64;
23
24#[allow(non_snake_case)]
25fn expand_message(msg: &[u8], dst: &[u8], len_in_bytes: usize) -> Result<Vec<u8>, Error> {
26    let ell = (len_in_bytes + B_IN_BYTES - 1) / B_IN_BYTES;
27    if ell > 255 || len_in_bytes > u16::MAX.into() || dst.len() > 255 {
28        return Err(Error::InvalidEll);
29    }
30
31    let dst_prime = dst.concat_byte(dst.len() as u8);
32    let z_pad = vec![0u8; S_IN_BYTES];
33    let l_i_b_str = i2osp(len_in_bytes, 2);
34
35    // msg_prime = Z_pad || msg || l_i_b_str || 0 || dst_prime
36    let msg_prime = z_pad
37        .concat(msg)
38        .concat(&l_i_b_str)
39        .concat(&[0u8; 1])
40        .concat(&dst_prime);
41
42    let b_0 = hash(&msg_prime).to_vec(); // H(msg_prime)
43
44    let payload_1 = b_0.concat_byte(1).concat(&dst_prime);
45    let mut b_i = hash(&payload_1).to_vec(); // H(b_0 || 1 || dst_prime)
46
47    let mut uniform_bytes = b_i.clone();
48    for i in 2..=ell {
49        // i < 256 is checked before
50        let payload_i = strxor(&b_0, &b_i).concat_byte(i as u8).concat(&dst_prime);
51        // H((b_0 ^ b_(i-1)) || 1 || dst_prime)
52        b_i = hash(&payload_i).to_vec();
53        uniform_bytes.extend_from_slice(&b_i);
54    }
55    uniform_bytes.truncate(len_in_bytes);
56    Ok(uniform_bytes)
57}
58
/// Byte-wise XOR of two byte strings (`strxor` in RFC 9380).
///
/// Callers are expected to pass equal-length inputs; debug builds assert this. Without the
/// assertion, the output is as long as the shorter input.
fn strxor(a: &[u8], b: &[u8]) -> Vec<u8> {
    debug_assert_eq!(a.len(), b.len());
    let mut xored = Vec::with_capacity(a.len().min(b.len()));
    for (&left, &right) in a.iter().zip(b) {
        xored.push(left ^ right);
    }
    xored
}
63
64/// Hash a message to a P256 field element.
65pub fn hash_to_field(msg: &[u8], dst: &[u8], count: usize) -> Result<Vec<P256FieldElement>, Error> {
66    let len_in_bytes = count * L;
67    let uniform_bytes = expand_message(msg, dst, len_in_bytes)?;
68    let mut u = Vec::with_capacity(count);
69    for i in 0..count {
70        let elm_offset = L * i;
71        let tv = &uniform_bytes[elm_offset..L * (i + 1)];
72        let tv = P256FieldElement::from_be_bytes(tv);
73        u.push(tv);
74    }
75    Ok(u)
76}
77
78/// Hash a message to a P256 scalar.
79pub fn hash_to_scalar(msg: &[u8], dst: &[u8], count: usize) -> Result<Vec<P256Scalar>, Error> {
80    let len_in_bytes = count * L;
81    let uniform_bytes = expand_message(msg, dst, len_in_bytes)?;
82    let mut u = Vec::with_capacity(count);
83    for i in 0..count {
84        let elm_offset = L * i;
85        let tv = &uniform_bytes[elm_offset..L * (i + 1)];
86        let tv = P256Scalar::from_be_bytes(tv);
87        u.push(tv);
88    }
89    Ok(u)
90}
91
92/// SSWU
93fn map_to_curve(u: P256FieldElement) -> P256Point {
94    let a: &P256FieldElement = &P256FieldElement::from_u128(3u128).neg();
95    let b = &P256FieldElement::from_hex(
96        "5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b",
97    );
98    let z = P256FieldElement::from_u128(10u128).neg();
99    let tv1 = (z.pow(2) * u.pow(4) + z * u.pow(2)).inv0();
100    let x1 = if tv1 == P256FieldElement::zero() {
101        *b * (z * *a).inv()
102    } else {
103        (b.neg() * a.inv()) * (tv1 + P256FieldElement::from_u128(1u128))
104    };
105
106    let gx1 = x1.pow(3) + (*a) * x1 + (*b);
107    let x2 = z * u.pow(2) * x1;
108    let gx2 = x2.pow(3) + *a * x2 + *b;
109
110    let mut output = if is_square(&gx1) {
111        (x1, sqrt(&gx1))
112    } else {
113        (x2, sqrt(&gx2))
114    };
115
116    if sgn0(&u) != sgn0(&output.1) {
117        output.1 = output.1.neg();
118    }
119
120    output.into()
121}
122
123/// Hash a message to a point in P256.
124pub fn hash_to_curve(msg: &[u8], dst: &[u8]) -> Result<P256Point, Error> {
125    let u: Vec<P256FieldElement> = hash_to_field(msg, dst, 2)?;
126    let q0 = map_to_curve(u[0]);
127    let q1 = map_to_curve(u[1]);
128    let r = p256::point_add(q0, q1)?;
129    Ok(r)
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;
    use std::fs::read_to_string;

    /// Suite identifier; the test vectors live in `vectors/<ID>.json`.
    const ID: &str = "P256_XMD_SHA-256_SSWU_RO_";

    /// Read and parse a JSON test-vector file, panicking if it is missing or malformed.
    pub fn load_vectors(path: &std::path::Path) -> Value {
        serde_json::from_str(&read_to_string(path).expect("File not found.")).unwrap()
    }

    /// Load this suite's vector file, `vectors/<ID>.json`.
    fn load_suite_vectors() -> Value {
        let mut vector_path = std::path::Path::new("vectors").join(ID);
        vector_path.set_extension("json");
        eprintln!(" Reading {}", vector_path.display());
        load_vectors(vector_path.as_path())
    }

    /// Decode a `0x`-prefixed big-endian hex string from the vectors into a field element.
    fn field_element(value: &Value) -> P256FieldElement {
        let hex_str = value.as_str().unwrap().trim_start_matches("0x");
        P256FieldElement::from_be_bytes(&hex::decode(hex_str).unwrap())
    }

    // `assert_eq!` rather than `debug_assert_eq!` is used throughout so the checks still run
    // under `cargo test --release`.

    #[test]
    fn p256_xmd_sha256_sswu_ro_hash_to_field() {
        let tests = load_suite_vectors();
        let dst = tests["dst"].as_str().unwrap().as_bytes();

        for test_case in tests["vectors"].as_array().unwrap() {
            let msg_str = test_case["msg"].as_str().unwrap();
            let msg = msg_str.as_bytes();

            let u_expected: Vec<_> = test_case["u"]
                .as_array()
                .unwrap()
                .iter()
                .map(field_element)
                .collect();

            let u_real = hash_to_field(msg, dst, 2).unwrap();
            assert_eq!(u_real.len(), u_expected.len());
            for (i, (u_real, u_expected)) in u_real.iter().zip(u_expected.iter()).enumerate() {
                assert_eq!(
                    u_expected.as_ref(),
                    u_real.as_ref(),
                    "u{i} did not match for {msg_str}",
                );
            }
        }
    }

    #[test]
    fn p256_xmd_sha256_sswu_ro_map_to_curve() {
        let vectors = load_suite_vectors();

        for test_case in vectors["vectors"].as_array().unwrap() {
            let u = test_case["u"].as_array().unwrap();
            // u[0] maps to Q0 and u[1] maps to Q1.
            for (i, q_key) in ["Q0", "Q1"].iter().enumerate() {
                let (q_x, q_y) = map_to_curve(field_element(&u[i])).into();
                let q_expected = &test_case[*q_key];
                assert_eq!(field_element(&q_expected["x"]), q_x, "x{i} incorrect");
                assert_eq!(field_element(&q_expected["y"]), q_y, "y{i} incorrect");
            }
        }
    }

    #[test]
    fn p256_xmd_sha256_sswu_ro_hash_to_curve() {
        let vectors = load_suite_vectors();
        let dst = vectors["dst"].as_str().unwrap().as_bytes();

        for test_case in vectors["vectors"].as_array().unwrap() {
            let msg = test_case["msg"].as_str().unwrap().as_bytes();

            let p_expected = &test_case["P"];
            let p_x_expected = field_element(&p_expected["x"]);
            let p_y_expected = field_element(&p_expected["y"]);

            let (x, y) = hash_to_curve(msg, dst).unwrap().into();

            assert_eq!(p_x_expected.as_ref(), x.as_ref(), "x-coordinate incorrect");
            assert_eq!(p_y_expected.as_ref(), y.as_ref(), "y-coordinate incorrect");
        }
    }
}