#![warn(missing_docs)]
use elgamal::Ciphertext;
use hacspec_lib::{hacspec_helper::NatMod, Randomness};
use super::coprf_setup::{BlindingPublicKey, CoPRFKey};
use crate::{coprf::coprf_setup::CoPRFReceiverContext, p256_sha256, Error};
use p256::P256Point;
/// A CoPRF input: a borrowed byte string.
pub type Input<'a> = &'a [u8];
/// A CoPRF output: a point on the P-256 curve.
pub type Output = P256Point;
/// A blinded CoPRF input, represented as an ElGamal ciphertext.
pub type BlindInput = Ciphertext;
/// A blinded CoPRF output, represented as an ElGamal ciphertext.
pub type BlindOutput = Ciphertext;
/// Blinds `input` so that it can be sent off for blind evaluation.
///
/// The input is hashed to a P-256 group element (with `context_string`
/// passed to the hash-to-group step) and the resulting element is
/// ElGamal-encrypted under the blinding public key `bpk`, drawing from
/// `randomness`.
///
/// # Errors
///
/// * [`Error::InvalidInputError`] if `input` hashes to the group identity.
/// * Any error raised by hashing to the group or by the ElGamal encryption,
///   converted into [`Error`].
pub fn blind(
    bpk: BlindingPublicKey,
    input: Input,
    context_string: Vec<u8>,
    randomness: &mut Randomness,
) -> Result<BlindInput, Error> {
    let input_element = p256_sha256::hash_to_group(input, &context_string)?;

    // The identity element is rejected as an input.
    if input_element == p256_sha256::identity() {
        return Err(Error::InvalidInputError);
    }

    let blind_input = elgamal::encrypt(bpk, input_element, randomness)?;
    Ok(blind_input)
}
/// Evaluates the CoPRF on a blinded input without unblinding it.
///
/// The incoming ciphertext is first rerandomized under `bpk` using
/// `randomness`, then multiplied by the evaluation key `key`.
///
/// # Errors
///
/// Propagates any failure of the ElGamal rerandomization or of the scalar
/// multiplication, converted into [`Error`].
pub fn blind_evaluate(
    key: CoPRFKey,
    bpk: BlindingPublicKey,
    blind_input: BlindInput,
    randomness: &mut Randomness,
) -> Result<BlindOutput, Error> {
    let fresh_ciphertext = elgamal::rerandomize(bpk, blind_input, randomness)?;
    let evaluated = elgamal::scalar_mul_ciphertext(key, fresh_ciphertext)?;
    Ok(evaluated)
}
/// Unblinds a blinded CoPRF output.
///
/// Decrypts `blind_output` with the blinding secret key (`bsk`) stored in
/// the receiver `context`.
///
/// # Errors
///
/// Propagates any failure of the ElGamal decryption, converted into
/// [`Error`].
pub fn finalize(
    context: &CoPRFReceiverContext,
    blind_output: BlindOutput,
) -> Result<Output, Error> {
    let output = elgamal::decrypt(context.bsk, blind_output)?;
    Ok(output)
}
pub fn prepare_blind_convert(
bpk: BlindingPublicKey,
y: Output,
randomness: &mut Randomness,
) -> Result<BlindInput, Error> {
elgamal::encrypt(bpk, y, randomness).map_err(|e| e.into())
}
/// Converts a blinded CoPRF output from `key_from` to `key_to` without
/// unblinding it.
///
/// The ciphertext is rerandomized under `bpk` using `randomness` and then
/// multiplied by the scalar `key_to * key_from^-1`.
///
/// # Errors
///
/// Propagates any failure of the ElGamal rerandomization or of the scalar
/// multiplication, converted into [`Error`].
pub fn blind_convert(
    bpk: BlindingPublicKey,
    key_from: CoPRFKey,
    key_to: CoPRFKey,
    blind_input: BlindInput,
    randomness: &mut Randomness,
) -> Result<BlindOutput, Error> {
    // Multiplying by this factor cancels `key_from` and applies `key_to`.
    let conversion_factor = key_to * key_from.inv();

    let rerandomized = elgamal::rerandomize(bpk, blind_input, randomness)?;
    let converted = elgamal::scalar_mul_ciphertext(conversion_factor, rerandomized)?;
    Ok(converted)
}
#[cfg(test)]
mod tests {
    use crate::coprf::coprf_setup::{derive_key, CoPRFEvaluatorContext};

    use super::*;

    // NOTE: the checks below use `assert_eq!` rather than `debug_assert_eq!`.
    // The debug variants are compiled out when debug assertions are disabled
    // (e.g. `cargo test --release`), which would make every test pass
    // without checking anything.

    /// Reference (non-blind) evaluation: hashes `input` to a group element
    /// and multiplies it by `key`.
    ///
    /// Fails with `Error::InvalidInputError` if the input hashes to the
    /// point at infinity.
    pub fn evaluate(context_string: &[u8], key: CoPRFKey, input: Input) -> Result<Output, Error> {
        let input_element = p256_sha256::hash_to_group(input, context_string)?;

        if input_element == P256Point::AtInfinity {
            return Err(Error::InvalidInputError);
        }

        let evaluated_element = p256::p256_point_mul(key, input_element)?;
        Ok(evaluated_element)
    }

    /// Reference (non-blind) conversion: multiplies `y` by
    /// `key_destination * key_origin^-1`.
    pub fn convert(
        key_origin: CoPRFKey,
        key_destination: CoPRFKey,
        y: Output,
    ) -> Result<Output, Error> {
        let delta = key_destination * key_origin.inv();
        let result = p256::p256_point_mul(delta, y)?;
        Ok(result)
    }

    /// Produces a `Randomness` source backed by 1,000,000 random bytes.
    fn generate_randomness() -> Randomness {
        use rand::prelude::*;

        // Allocate on the heap: this avoids a 1 MB array on the test
        // thread's stack and the extra copy made by `to_vec()`.
        let mut bytes = vec![0u8; 1_000_000];
        rand::thread_rng().fill_bytes(&mut bytes);
        Randomness::new(bytes)
    }

    /// Converting outputs from two different origin keys to the same
    /// destination key gives the same value as evaluating directly under
    /// the destination key.
    #[test]
    fn self_test_eval_convert() {
        let mut randomness = generate_randomness();
        let test_context = b"Test";
        let test_input = b"TestInput";
        let evaluator_context = CoPRFEvaluatorContext::new(&mut randomness).unwrap();

        let key_origin1 = derive_key(&evaluator_context, b"1").unwrap();
        let key_origin2 = derive_key(&evaluator_context, b"2").unwrap();
        let key_destination = derive_key(&evaluator_context, b"3").unwrap();

        let y_under_origin1 = evaluate(test_context, key_origin1, test_input).unwrap();
        let y_under_origin2 = evaluate(test_context, key_origin2, test_input).unwrap();
        let y_under_destination = evaluate(test_context, key_destination, test_input).unwrap();

        let converted_y_from_1 = convert(key_origin1, key_destination, y_under_origin1).unwrap();
        let converted_y_from_2 = convert(key_origin2, key_destination, y_under_origin2).unwrap();

        assert_eq!(converted_y_from_1, converted_y_from_2);
        assert_eq!(converted_y_from_1, y_under_destination);
    }

    /// blind -> blind_evaluate -> finalize matches the reference evaluation.
    #[test]
    fn test_blind_evaluate() {
        let mut randomness = generate_randomness();
        let test_context = b"Test";
        let test_input = b"TestInput";
        let evaluator_context = CoPRFEvaluatorContext::new(&mut randomness).unwrap();
        let receiver_context = CoPRFReceiverContext::new(&mut randomness);

        let blind_input = blind(
            receiver_context.get_bpk(),
            test_input,
            test_context.to_vec(),
            &mut randomness,
        )
        .unwrap();

        let evaluation_key = derive_key(&evaluator_context, b"TestKey").unwrap();
        let blind_result = blind_evaluate(
            evaluation_key,
            receiver_context.get_bpk(),
            blind_input,
            &mut randomness,
        )
        .unwrap();

        let unblinded_result = finalize(&receiver_context, blind_result).unwrap();
        let expected_result = evaluate(test_context, evaluation_key, test_input).unwrap();

        assert_eq!(unblinded_result, expected_result);
    }

    /// Blind conversion from two different origin keys to the same
    /// destination key yields the same unblinded value, equal to the
    /// reference evaluation under the destination key.
    #[test]
    fn blind_convergence() {
        let mut randomness = generate_randomness();
        let test_context = b"Test";
        let test_input = b"TestInput";
        let evaluator_context = CoPRFEvaluatorContext::new(&mut randomness).unwrap();
        let receiver_context = CoPRFReceiverContext::new(&mut randomness);

        let key_origin1 = derive_key(&evaluator_context, b"1").unwrap();
        let key_origin2 = derive_key(&evaluator_context, b"2").unwrap();
        let key_destination = derive_key(&evaluator_context, b"3").unwrap();

        let y_under_destination = evaluate(test_context, key_destination, test_input).unwrap();
        let y1 = evaluate(test_context, key_origin1, test_input).unwrap();
        let y2 = evaluate(test_context, key_origin2, test_input).unwrap();

        let blind1 =
            prepare_blind_convert(receiver_context.get_bpk(), y1, &mut randomness).unwrap();
        let blind2 =
            prepare_blind_convert(receiver_context.get_bpk(), y2, &mut randomness).unwrap();

        let blind_result_1 = blind_convert(
            receiver_context.get_bpk(),
            key_origin1,
            key_destination,
            blind1,
            &mut randomness,
        )
        .unwrap();
        let blind_result_2 = blind_convert(
            receiver_context.get_bpk(),
            key_origin2,
            key_destination,
            blind2,
            &mut randomness,
        )
        .unwrap();

        let res1 = finalize(&receiver_context, blind_result_1).unwrap();
        let res2 = finalize(&receiver_context, blind_result_2).unwrap();

        assert_eq!(res1, res2);
        assert_eq!(res1, y_under_destination);
    }

    /// Blind conversion matches the reference evaluation under the
    /// destination key, both when converting the blinded evaluation result
    /// directly and when converting a finalized-then-re-blinded result.
    #[test]
    fn test_blind_conversion() {
        let mut randomness = generate_randomness();
        let test_context = b"Test";
        let test_input = b"TestInput";
        let evaluator_context = CoPRFEvaluatorContext::new(&mut randomness).unwrap();
        let receiver_context = CoPRFReceiverContext::new(&mut randomness);

        let blind_input = blind(
            receiver_context.get_bpk(),
            test_input,
            test_context.to_vec(),
            &mut randomness,
        )
        .unwrap();

        let key_eval = derive_key(&evaluator_context, b"TestKey").unwrap();
        let key_destination = derive_key(&evaluator_context, b"DestinationKey").unwrap();

        let blind_result = blind_evaluate(
            key_eval,
            receiver_context.get_bpk(),
            blind_input,
            &mut randomness,
        )
        .unwrap();

        let expected_result = evaluate(test_context, key_destination, test_input).unwrap();

        // Path 1: convert the still-blinded evaluation result directly.
        let blind_converted_result = blind_convert(
            receiver_context.get_bpk(),
            key_eval,
            key_destination,
            blind_result,
            &mut randomness,
        )
        .unwrap();
        let unblinded_converted_result =
            finalize(&receiver_context, blind_converted_result).unwrap();
        assert_eq!(expected_result, unblinded_converted_result);

        // Path 2: unblind the intermediate result, re-blind it with
        // `prepare_blind_convert`, then convert.
        let unblinded_intermediate_result = finalize(&receiver_context, blind_result).unwrap();
        let prepped_input = prepare_blind_convert(
            receiver_context.get_bpk(),
            unblinded_intermediate_result,
            &mut randomness,
        )
        .unwrap();
        let blind_converted_result = blind_convert(
            receiver_context.get_bpk(),
            key_eval,
            key_destination,
            prepped_input,
            &mut randomness,
        )
        .unwrap();
        let unblinded_converted_result =
            finalize(&receiver_context, blind_converted_result).unwrap();
        assert_eq!(expected_result, unblinded_converted_result);
    }
}