@@ -7,36 +7,10 @@ use proc_macro::TokenStream;
77use proc_macro2:: { TokenStream as TokenStream2 , TokenTree } ;
88use quote:: { quote, ToTokens , TokenStreamExt } ;
99use syn:: {
10- parse:: { Parse , ParseStream } ,
11- parse_macro_input, parse_quote,
12- spanned:: Spanned ,
13- DeriveInput , Error , FnArg , Generics , Ident , ItemFn , ItemType , LitStr , Pat , Visibility ,
10+ parse_macro_input, parse_quote, spanned:: Spanned , Error , Fields , FnArg , Ident , ItemFn ,
11+ ItemStruct , LitStr , Pat , Visibility ,
1412} ;
1513
16- /// Parses a type definition, extracts its identifier and generic parameters
17- struct TypeDefinition {
18- ident : Ident ,
19- generics : Generics ,
20- }
21-
22- impl Parse for TypeDefinition {
23- fn parse ( input : ParseStream ) -> syn:: Result < Self > {
24- if let Ok ( d) = DeriveInput :: parse ( input) {
25- Ok ( Self {
26- ident : d. ident ,
27- generics : d. generics ,
28- } )
29- } else if let Ok ( t) = ItemType :: parse ( input) {
30- Ok ( Self {
31- ident : t. ident ,
32- generics : t. generics ,
33- } )
34- } else {
35- Err ( input. error ( "Input is not an alias, enum, struct or union definition" ) )
36- }
37- }
38- }
39-
4014macro_rules! err {
4115 ( $span: expr, $message: expr $( , ) ?) => {
4216 Error :: new( $span. span( ) , $message) . to_compile_error( )
@@ -46,28 +20,70 @@ macro_rules! err {
4620 } ;
4721}
4822
49- /// `unsafe_guid` attribute macro, implements the `Identify` trait for any type
50- /// (mostly works like a custom derive, but also supports type aliases)
23+ /// Attribute macro for marking structs as UEFI protocols.
24+ ///
25+ /// The macro takes one argument, a GUID string.
26+ ///
27+ /// The macro can only be applied to a struct, and the struct must have
28+ /// named fields (i.e. not a unit or tuple struct). It implements the
29+ /// [`Protocol`] trait and the `unsafe` [`Identify`] trait for the
30+ /// struct. It also adds a hidden field that causes the struct to be
31+ /// marked as [`!Send` and `!Sync`][send-and-sync].
32+ ///
33+ /// # Safety
34+ ///
35+ /// The caller must ensure that the correct GUID is attached to the
36+ /// type. An incorrect GUID could lead to invalid casts and other
37+ /// unsound behavior.
38+ ///
39+ /// # Example
40+ ///
41+ /// ```
42+ /// use uefi::{Identify, guid};
43+ /// use uefi::proto::unsafe_protocol;
44+ ///
45+ /// #[unsafe_protocol("12345678-9abc-def0-1234-56789abcdef0")]
46+ /// struct ExampleProtocol {}
47+ ///
48+ /// assert_eq!(ExampleProtocol::GUID, guid!("12345678-9abc-def0-1234-56789abcdef0"));
49+ /// ```
50+ ///
51+ /// [`Identify`]: https://docs.rs/uefi/latest/uefi/trait.Identify.html
52+ /// [`Protocol`]: https://docs.rs/uefi/latest/uefi/proto/trait.Protocol.html
53+ /// [send-and-sync]: https://doc.rust-lang.org/nomicon/send-and-sync.html
5154#[ proc_macro_attribute]
52- pub fn unsafe_guid ( args : TokenStream , input : TokenStream ) -> TokenStream {
53- // Parse the arguments and input using Syn
55+ pub fn unsafe_protocol ( args : TokenStream , input : TokenStream ) -> TokenStream {
56+ // Parse `args` as a GUID string.
5457 let ( time_low, time_mid, time_high_and_version, clock_seq_and_variant, node) =
5558 match parse_guid ( parse_macro_input ! ( args as LitStr ) ) {
5659 Ok ( data) => data,
5760 Err ( tokens) => return tokens. into ( ) ,
5861 } ;
5962
60- let mut result : TokenStream2 = input . clone ( ) . into ( ) ;
63+ let item_struct = parse_macro_input ! ( input as ItemStruct ) ;
6164
62- let type_definition = parse_macro_input ! ( input as TypeDefinition ) ;
63-
64- // At this point, we know everything we need to implement Identify
65- let ident = & type_definition. ident ;
66- let ( impl_generics, ty_generics, where_clause) = type_definition. generics . split_for_impl ( ) ;
65+ let ident = & item_struct. ident ;
66+ let struct_attrs = & item_struct. attrs ;
67+ let struct_vis = & item_struct. vis ;
68+ let struct_fields = if let Fields :: Named ( struct_fields) = & item_struct. fields {
69+ & struct_fields. named
70+ } else {
71+ return err ! ( item_struct, "Protocol struct must used named fields" ) . into ( ) ;
72+ } ;
73+ let struct_generics = & item_struct. generics ;
74+ let ( impl_generics, ty_generics, where_clause) = item_struct. generics . split_for_impl ( ) ;
75+
76+ quote ! {
77+ #( #struct_attrs) *
78+ #struct_vis struct #ident #struct_generics {
79+ // Add a hidden field with `PhantomData` of a raw
80+ // pointer. This has the implicit side effect of making the
81+ // struct !Send and !Sync.
82+ _no_send_or_sync: :: core:: marker:: PhantomData <* const u8 >,
83+ #struct_fields
84+ }
6785
68- result. append_all ( quote ! {
6986 unsafe impl #impl_generics :: uefi:: Identify for #ident #ty_generics #where_clause {
70- #[ doc( hidden) ]
7187 const GUID : :: uefi:: Guid = :: uefi:: Guid :: from_values(
7288 #time_low,
7389 #time_mid,
@@ -76,8 +92,10 @@ pub fn unsafe_guid(args: TokenStream, input: TokenStream) -> TokenStream {
7692 #node,
7793 ) ;
7894 }
79- } ) ;
80- result. into ( )
95+
96+ impl #impl_generics :: uefi:: proto:: Protocol for #ident #ty_generics #where_clause { }
97+ }
98+ . into ( )
8199}
82100
83101/// Create a `Guid` at compile time.
@@ -164,28 +182,6 @@ fn parse_guid(guid_lit: LitStr) -> Result<(u32, u16, u16, u16, u64), TokenStream
164182 ) )
165183}
166184
167- /// Custom derive for the `Protocol` trait
168- #[ proc_macro_derive( Protocol ) ]
169- pub fn derive_protocol ( item : TokenStream ) -> TokenStream {
170- // Parse the input using Syn
171- let item = parse_macro_input ! ( item as DeriveInput ) ;
172-
173- // Then implement Protocol
174- let ident = item. ident . clone ( ) ;
175- let ( impl_generics, ty_generics, where_clause) = item. generics . split_for_impl ( ) ;
176- let result = quote ! {
177- // Mark this as a `Protocol` implementation
178- impl #impl_generics :: uefi:: proto:: Protocol for #ident #ty_generics #where_clause { }
179-
180- // Most UEFI functions expect to be called on the bootstrap processor.
181- impl #impl_generics !Send for #ident #ty_generics #where_clause { }
182-
183- // Most UEFI functions do not support multithreaded access.
184- impl #impl_generics !Sync for #ident #ty_generics #where_clause { }
185- } ;
186- result. into ( )
187- }
188-
189185/// Get the name of a function's argument at `arg_index`.
190186fn get_function_arg_name ( f : & ItemFn , arg_index : usize , errors : & mut TokenStream2 ) -> Option < Ident > {
191187 if let Some ( FnArg :: Typed ( arg) ) = f. sig . inputs . iter ( ) . nth ( arg_index) {
0 commit comments