natmod/
lib.rs

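//! Procedural macro `#[nat_mod("<modulus as hex>", <number of bytes>)]` for
//! defining natural-number types modulo a fixed modulus.
//!
//! The first argument is the modulus written as a hex string; the second is the
//! number of bytes used to store a value (the modulus is left-padded with zero
//! bytes to that width). The generated code implements a `NatMod<N>` trait that
//! lives in the library and must be nameable where the macro is invoked.
//!
//! ```ignore
//! // This trait lives in the library (only the items the macro implements are shown).
//! pub trait NatMod<const LEN: usize> {
//!     const MODULUS: [u8; LEN];
//!     const MODULUS_STR: &'static str;
//!     const ZERO: [u8; LEN];
//!     fn new(value: [u8; LEN]) -> Self;
//!     fn value(&self) -> &[u8];
//! }
//!
//! // Modulus 0x123456, stored in 10 bytes.
//! #[nat_mod("123456", 10)]
//! struct MyNatMod {}
//! ```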
11
12use hex::FromHex;
13use proc_macro::TokenStream;
14use quote::quote;
15use syn::{parse::Parse, parse_macro_input, DeriveInput, Ident, LitInt, LitStr, Result, Token};
16
/// Parsed arguments of `#[nat_mod("<modulus hex>", <int size>)]`.
#[derive(Debug, Clone)]
struct NatModAttr {
    /// The modulus exactly as written in the attribute (a hex string).
    mod_str: String,
    /// The modulus decoded from `mod_str`, one byte per two hex digits, in the
    /// order they appear in the string (not yet padded to `int_size`).
    mod_bytes: Vec<u8>,
    /// Number of bytes to use for the integer.
    int_size: usize,
}
25
26impl Parse for NatModAttr {
27    fn parse(input: syn::parse::ParseStream) -> Result<Self> {
28        let mod_str = input.parse::<LitStr>()?.value();
29        let mod_bytes = Vec::<u8>::from_hex(&mod_str).expect("Invalid hex String");
30        input.parse::<Token![,]>()?;
31        let int_size = input.parse::<LitInt>()?.base10_parse::<usize>()?;
32        debug_assert!(input.is_empty(), "Left over tokens in attribute {input:?}");
33        Ok(NatModAttr {
34            mod_str,
35            mod_bytes,
36            int_size,
37        })
38    }
39}
40
41#[proc_macro_attribute]
42pub fn nat_mod(attr: TokenStream, item: TokenStream) -> TokenStream {
43    let item_ast = parse_macro_input!(item as DeriveInput);
44    let ident = item_ast.ident.clone();
45    let args = parse_macro_input!(attr as NatModAttr);
46
47    let num_bytes = args.int_size;
48    let modulus = args.mod_bytes;
49    let modulus_string = args.mod_str;
50
51    let mut padded_modulus = vec![0u8; num_bytes - modulus.len()];
52    padded_modulus.append(&mut modulus.clone());
53    let mod_iter1 = padded_modulus.iter();
54    let mod_iter2 = padded_modulus.iter();
55    let const_name = Ident::new(
56        &format!("{}_MODULUS", ident.to_string().to_uppercase()),
57        ident.span(),
58    );
59    let static_name = Ident::new(
60        &format!("{}_MODULUS_STR", ident.to_string().to_uppercase()),
61        ident.span(),
62    );
63    let mod_name = Ident::new(
64        &format!("{}_mod", ident.to_string().to_uppercase()),
65        ident.span(),
66    );
67
68    let out_struct = quote! {
69        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
70        pub struct #ident {
71            value: [u8; #num_bytes],
72        }
73
74        //#[not_hax]
75        #[allow(non_snake_case)]
76        mod #mod_name {
77            use super::*;
78
79            const #const_name: [u8; #num_bytes] = [#(#mod_iter1),*];
80            static #static_name: &str = #modulus_string;
81
82            impl NatMod<#num_bytes> for #ident {
83                const MODULUS: [u8; #num_bytes] = [#(#mod_iter2),*];
84                const MODULUS_STR: &'static str = #modulus_string;
85                const ZERO: [u8; #num_bytes] = [0u8; #num_bytes];
86
87
88                fn new(value: [u8; #num_bytes]) -> Self {
89                    Self {
90                        value
91                    }
92                }
93                fn value(&self) -> &[u8] {
94                    &self.value
95                }
96            }
97
98            impl core::convert::AsRef<[u8]> for #ident {
99                fn as_ref(&self) -> &[u8] {
100                    &self.value
101                }
102            }
103
104            impl core::fmt::Display for #ident {
105                fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
106                    write!(f, "{}", self.to_hex())
107                }
108            }
109
110
111            impl Into<[u8; #num_bytes]> for #ident {
112                fn into(self) -> [u8; #num_bytes] {
113                    self.value
114                }
115            }
116
117            impl core::ops::Add for #ident {
118                type Output = Self;
119
120                fn add(self, rhs: Self) -> Self::Output {
121                    self.fadd(rhs)
122                }
123            }
124
125            impl core::ops::Mul for #ident {
126                type Output = Self;
127
128                fn mul(self, rhs: Self) -> Self::Output {
129                    self.fmul(rhs)
130                }
131            }
132
133            impl core::ops::Sub for #ident {
134                type Output = Self;
135
136                fn sub(self, rhs: Self) -> Self::Output {
137                    self.fsub(rhs)
138                }
139            }
140        }
141    };
142
143    out_struct.into()
144}