bolt_attribute_bolt_system_input/
lib.rs

1use proc_macro::TokenStream;
2
3use quote::quote;
4use syn::{parse_macro_input, Fields, ItemStruct, Lit, Meta, NestedMeta};
5
6/// This macro attribute is used to define a BOLT system input.
7///
8/// The input can be defined as a struct and will be transformed into an Anchor context.
9///
10///
11/// # Example
12/// ```ignore
13///#[system_input]
14///pub struct Components {
15///    pub position: Position,
16///}
17///
18/// ```
19#[proc_macro_attribute]
20pub fn system_input(_attr: TokenStream, item: TokenStream) -> TokenStream {
21    // Parse the input TokenStream (the struct) into a Rust data structure
22    let input = parse_macro_input!(item as ItemStruct);
23
24    // Ensure the struct has named fields
25    let fields = match &input.fields {
26        Fields::Named(fields) => &fields.named,
27        _ => panic!("system_input macro only supports structs with named fields"),
28    };
29    let name = &input.ident;
30
31    // Collect imports for components
32    let components_imports: Vec<_> = fields
33        .iter()
34        .filter_map(|field| {
35            field.attrs.iter().find_map(|attr| {
36                if let Ok(Meta::List(meta_list)) = attr.parse_meta() {
37                    if meta_list.path.is_ident("component_id") {
38                        meta_list.nested.first().and_then(|nested_meta| {
39                            if let NestedMeta::Lit(Lit::Str(lit_str)) = nested_meta {
40                                let component_type =
41                                    format!("bolt_types::Component{}", lit_str.value());
42                                if let Ok(parsed_component_type) =
43                                    syn::parse_str::<syn::Type>(&component_type)
44                                {
45                                    let field_type = &field.ty;
46                                    let component_import = quote! {
47                                        use #parsed_component_type as #field_type;
48                                    };
49                                    return Some(component_import);
50                                }
51                            }
52                            None
53                        })
54                    } else {
55                        None
56                    }
57                } else {
58                    None
59                }
60            })
61        })
62        .collect();
63
64    // Transform fields for the struct definition
65    let transformed_fields = fields.iter().map(|f| {
66        let field_name = &f.ident;
67        let field_type = &f.ty;
68        quote! {
69            #[account()]
70            pub #field_name: Account<'info, #field_type>,
71        }
72    });
73
74    // Generate the new struct with the Accounts derive and transformed fields
75    let output_struct = quote! {
76        #[derive(Accounts)]
77        pub struct #name<'info> {
78            #(#transformed_fields)*
79            /// CHECK: Authority check
80            #[account()]
81            pub authority: AccountInfo<'info>,
82        }
83    };
84
85    // Generate the try_to_vec method
86    let try_to_vec_fields = fields.iter().map(|f| {
87        let field_name = &f.ident;
88        quote! {
89            self.#field_name.try_to_vec()?
90        }
91    });
92
93    let try_from_fields = fields.iter().enumerate().map(|(i, f)| {
94        let field_name = &f.ident;
95        quote! {
96            #field_name: Account::try_from(context.remaining_accounts.as_ref().get(#i).ok_or_else(|| ErrorCode::ConstraintAccountIsNone)?)?,
97        }
98    });
99
100    let number_of_components = fields.len();
101
102    let output_trait = quote! {
103        pub trait NumberOfComponents<'a, 'b, 'c, 'info, T> {
104            const NUMBER_OF_COMPONENTS: usize;
105        }
106    };
107
108    let output_trait_implementation = quote! {
109        impl<'a, 'b, 'c, 'info, T: bolt_lang::Bumps> NumberOfComponents<'a, 'b, 'c, 'info, T> for Context<'a, 'b, 'c, 'info, T> {
110            const NUMBER_OF_COMPONENTS: usize = #number_of_components;
111        }
112    };
113
114    // Generate the implementation of try_to_vec for the struct
115    let output_impl = quote! {
116        impl<'info> #name<'info> {
117            pub fn try_to_vec(&self) -> Result<Vec<Vec<u8>>> {
118                Ok(vec![#(#try_to_vec_fields,)*])
119            }
120
121            fn try_from<'a, 'b>(context: &Context<'a, 'b, 'info, 'info, VariadicBoltComponents<'info>>) -> Result<Self> {
122                Ok(Self {
123                    authority: context.accounts.authority.clone(),
124                    #(#try_from_fields)*
125                })
126            }
127        }
128    };
129
130    // Combine the struct definition and its implementation into the final TokenStream
131    let output = quote! {
132        #output_struct
133        #output_impl
134        #output_trait
135        #output_trait_implementation
136        #(#components_imports)*
137
138        #[derive(Accounts)]
139        pub struct VariadicBoltComponents<'info> {
140            /// CHECK: Authority check
141            #[account()]
142            pub authority: AccountInfo<'info>,
143        }
144    };
145
146    TokenStream::from(output)
147}