1use hex::FromHex;
13use proc_macro::TokenStream;
14use quote::quote;
15use syn::{parse::Parse, parse_macro_input, DeriveInput, Ident, LitInt, LitStr, Result, Token};
16
/// Arguments parsed from `#[nat_mod("<hex modulus>", <size>)]`.
#[derive(Debug, Clone)]
struct NatModAttr {
    /// The modulus exactly as written in the attribute's string literal.
    mod_str: String,
    /// `mod_str` decoded from hexadecimal into raw bytes.
    mod_bytes: Vec<u8>,
    /// Size of the generated type's byte array, taken from the integer literal.
    int_size: usize,
}
25
26impl Parse for NatModAttr {
27 fn parse(input: syn::parse::ParseStream) -> Result<Self> {
28 let mod_str = input.parse::<LitStr>()?.value();
29 let mod_bytes = Vec::<u8>::from_hex(&mod_str).expect("Invalid hex String");
30 input.parse::<Token![,]>()?;
31 let int_size = input.parse::<LitInt>()?.base10_parse::<usize>()?;
32 debug_assert!(input.is_empty(), "Left over tokens in attribute {input:?}");
33 Ok(NatModAttr {
34 mod_str,
35 mod_bytes,
36 int_size,
37 })
38 }
39}
40
41#[proc_macro_attribute]
42pub fn nat_mod(attr: TokenStream, item: TokenStream) -> TokenStream {
43 let item_ast = parse_macro_input!(item as DeriveInput);
44 let ident = item_ast.ident.clone();
45 let args = parse_macro_input!(attr as NatModAttr);
46
47 let num_bytes = args.int_size;
48 let modulus = args.mod_bytes;
49 let modulus_string = args.mod_str;
50
51 let mut padded_modulus = vec![0u8; num_bytes - modulus.len()];
52 padded_modulus.append(&mut modulus.clone());
53 let mod_iter1 = padded_modulus.iter();
54 let mod_iter2 = padded_modulus.iter();
55 let const_name = Ident::new(
56 &format!("{}_MODULUS", ident.to_string().to_uppercase()),
57 ident.span(),
58 );
59 let static_name = Ident::new(
60 &format!("{}_MODULUS_STR", ident.to_string().to_uppercase()),
61 ident.span(),
62 );
63 let mod_name = Ident::new(
64 &format!("{}_mod", ident.to_string().to_uppercase()),
65 ident.span(),
66 );
67
68 let out_struct = quote! {
69 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
70 pub struct #ident {
71 value: [u8; #num_bytes],
72 }
73
74 #[allow(non_snake_case)]
76 mod #mod_name {
77 use super::*;
78
79 const #const_name: [u8; #num_bytes] = [#(#mod_iter1),*];
80 static #static_name: &str = #modulus_string;
81
82 impl NatMod<#num_bytes> for #ident {
83 const MODULUS: [u8; #num_bytes] = [#(#mod_iter2),*];
84 const MODULUS_STR: &'static str = #modulus_string;
85 const ZERO: [u8; #num_bytes] = [0u8; #num_bytes];
86
87
88 fn new(value: [u8; #num_bytes]) -> Self {
89 Self {
90 value
91 }
92 }
93 fn value(&self) -> &[u8] {
94 &self.value
95 }
96 }
97
98 impl core::convert::AsRef<[u8]> for #ident {
99 fn as_ref(&self) -> &[u8] {
100 &self.value
101 }
102 }
103
104 impl core::fmt::Display for #ident {
105 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
106 write!(f, "{}", self.to_hex())
107 }
108 }
109
110
111 impl Into<[u8; #num_bytes]> for #ident {
112 fn into(self) -> [u8; #num_bytes] {
113 self.value
114 }
115 }
116
117 impl core::ops::Add for #ident {
118 type Output = Self;
119
120 fn add(self, rhs: Self) -> Self::Output {
121 self.fadd(rhs)
122 }
123 }
124
125 impl core::ops::Mul for #ident {
126 type Output = Self;
127
128 fn mul(self, rhs: Self) -> Self::Output {
129 self.fmul(rhs)
130 }
131 }
132
133 impl core::ops::Sub for #ident {
134 type Output = Self;
135
136 fn sub(self, rhs: Self) -> Self::Output {
137 self.fsub(rhs)
138 }
139 }
140 }
141 };
142
143 out_struct.into()
144}