@@ -8,9 +8,9 @@ use proc_macro2::{TokenStream as TokenStream2, TokenTree};
88use quote:: { quote, ToTokens , TokenStreamExt } ;
99use syn:: {
1010 parse:: { Parse , ParseStream } ,
11- parse_macro_input,
11+ parse_macro_input, parse_quote ,
1212 spanned:: Spanned ,
13- DeriveInput , Error , Generics , Ident , ItemFn , ItemType , LitStr , Visibility ,
13+ DeriveInput , Error , FnArg , Generics , Ident , ItemFn , ItemType , LitStr , Pat , Visibility ,
1414} ;
1515
1616/// Parses a type definition, extracts its identifier and generic parameters
@@ -157,6 +157,28 @@ pub fn derive_protocol(item: TokenStream) -> TokenStream {
157157 result. into ( )
158158}
159159
160+ /// Get the name of a function's argument at `arg_index`.
161+ fn get_function_arg_name ( f : & ItemFn , arg_index : usize , errors : & mut TokenStream2 ) -> Option < Ident > {
162+ if let Some ( FnArg :: Typed ( arg) ) = f. sig . inputs . iter ( ) . nth ( arg_index) {
163+ if let Pat :: Ident ( pat_ident) = & * arg. pat {
164+ // The argument has a valid name such as `handle` or `_handle`.
165+ Some ( pat_ident. ident . clone ( ) )
166+ } else {
167+ // The argument is unnamed, i.e. `_`.
168+ errors. append_all ( err ! ( arg. span( ) , "Entry method's arguments must be named" ) ) ;
169+ None
170+ }
171+ } else {
172+ // Either there are too few arguments, or it's the wrong kind of
173+ // argument (e.g. `self`).
174+ //
175+ // Don't append an error in this case. The error will be caught
176+ // by the typecheck later on, which will give a better error
177+ // message.
178+ None
179+ }
180+ }
181+
160182/// Custom attribute for a UEFI executable entrypoint
161183#[ proc_macro_attribute]
162184pub fn entry ( args : TokenStream , input : TokenStream ) -> TokenStream {
@@ -190,6 +212,9 @@ pub fn entry(args: TokenStream, input: TokenStream) -> TokenStream {
190212 ) ) ;
191213 }
192214
215+ let image_handle_ident = get_function_arg_name ( & f, 0 , & mut errors) ;
216+ let system_table_ident = get_function_arg_name ( & f, 1 , & mut errors) ;
217+
193218 // show most errors at once instead of one by one
194219 if !errors. is_empty ( ) {
195220 return errors. into ( ) ;
@@ -199,6 +224,18 @@ pub fn entry(args: TokenStream, input: TokenStream) -> TokenStream {
199224 let unsafety = f. sig . unsafety . take ( ) ;
200225 // strip any visibility modifiers
201226 f. vis = Visibility :: Inherited ;
227+ // Set the global image handle. If `image_handle_ident` is `None`
228+ // then the typecheck is going to fail anyway.
229+ if let Some ( image_handle_ident) = image_handle_ident {
230+ f. block . stmts . insert (
231+ 0 ,
232+ parse_quote ! {
233+ unsafe {
234+ #system_table_ident. boot_services( ) . set_image_handle( #image_handle_ident) ;
235+ }
236+ } ,
237+ ) ;
238+ }
202239
203240 let ident = & f. sig . ident ;
204241
0 commit comments