88# 
99# ==-------------------------------------------------------------------------==# 
1010
11- 
1211import  yaml 
1312import  argparse 
14- 
1513from  pathlib  import  Path 
1614from  header  import  HeaderFile 
15+ from  gpu_headers  import  GpuHeaderFile  as  GpuHeader 
1716from  class_implementation .classes .macro  import  Macro 
1817from  class_implementation .classes .type  import  Type 
1918from  class_implementation .classes .function  import  Function 
2221from  class_implementation .classes .object  import  Object 
2322
2423
25- def  yaml_to_classes (yaml_data ):
24+ def  yaml_to_classes (yaml_data ,  header_class ,  entry_points = None ):
2625    """ 
2726    Convert YAML data to header classes. 
2827
2928    Args: 
3029        yaml_data: The YAML data containing header specifications. 
30+         header_class: The class to use for creating the header. 
31+         entry_points: A list of specific function names to include in the header. 
3132
3233    Returns: 
3334        HeaderFile: An instance of HeaderFile populated with the data. 
3435    """ 
3536    header_name  =  yaml_data .get ("header" )
36-     header  =  HeaderFile (header_name )
37+     header  =  header_class (header_name )
3738
3839    for  macro_data  in  yaml_data .get ("macros" , []):
3940        header .add_macro (Macro (macro_data ["macro_name" ], macro_data ["macro_value" ]))
@@ -49,12 +50,15 @@ def yaml_to_classes(yaml_data):
4950        )
5051
5152    functions  =  yaml_data .get ("functions" , [])
53+     if  entry_points :
54+         entry_points_set  =  set (entry_points )
55+         functions  =  [f  for  f  in  functions  if  f ["name" ] in  entry_points_set ]
5256    sorted_functions  =  sorted (functions , key = lambda  x : x ["name" ])
5357    guards  =  []
5458    guarded_function_dict  =  {}
5559    for  function_data  in  sorted_functions :
5660        guard  =  function_data .get ("guard" , None )
57-         if  guard  ==  None :
61+         if  guard  is  None :
5862            arguments  =  [arg ["type" ] for  arg  in  function_data ["arguments" ]]
5963            attributes  =  function_data .get ("attributes" , None )
6064            standards  =  function_data .get ("standards" , None )
@@ -105,19 +109,21 @@ def yaml_to_classes(yaml_data):
105109    return  header 
106110
107111
108- def  load_yaml_file (yaml_file ):
112+ def  load_yaml_file (yaml_file ,  header_class ,  entry_points ):
109113    """ 
110114    Load YAML file and convert it to header classes. 
111115
112116    Args: 
113-         yaml_file: The path to the YAML file. 
117+         yaml_file: Path to the YAML file. 
118+         header_class: The class to use for creating the header (HeaderFile or GpuHeader). 
119+         entry_points: A list of specific function names to include in the header. 
114120
115121    Returns: 
116-         HeaderFile: An instance of HeaderFile populated with the data from the YAML file . 
122+         HeaderFile: An instance of HeaderFile populated with the data. 
117123    """ 
118124    with  open (yaml_file , "r" ) as  f :
119125        yaml_data  =  yaml .safe_load (f )
120-     return  yaml_to_classes (yaml_data )
126+     return  yaml_to_classes (yaml_data ,  header_class ,  entry_points )
121127
122128
123129def  fill_public_api (header_str , h_def_content ):
@@ -207,7 +213,14 @@ def increase_indent(self, flow=False, indentless=False):
207213    print (f"Added function { new_function .name }   to { yaml_file }  " )
208214
209215
210- def  main (yaml_file , h_def_file , output_dir , add_function = None ):
216+ def  main (
217+     yaml_file ,
218+     output_dir = None ,
219+     h_def_file = None ,
220+     add_function = None ,
221+     entry_points = None ,
222+     export_decls = False ,
223+ ):
211224    """ 
212225    Main function to generate header files from YAML and .h.def templates. 
213226
@@ -216,41 +229,50 @@ def main(yaml_file, h_def_file, output_dir, add_function=None):
216229        h_def_file: Path to the .h.def template file. 
217230        output_dir: Directory to output the generated header file. 
218231        add_function: Details of the function to be added to the YAML file (if any). 
232+         entry_points: A list of specific function names to include in the header. 
233+         export_decls: Flag to use GpuHeader for exporting declarations. 
219234    """ 
220- 
221235    if  add_function :
222236        add_function_to_yaml (yaml_file , add_function )
223237
224-     header  =  load_yaml_file (yaml_file )
225- 
226-     with  open (h_def_file , "r" ) as  f :
227-         h_def_content  =  f .read ()
238+     header_class  =  GpuHeader  if  export_decls  else  HeaderFile 
239+     header  =  load_yaml_file (yaml_file , header_class , entry_points )
228240
229241    header_str  =  str (header )
230-     final_header_content  =  fill_public_api (header_str , h_def_content )
231242
232-     output_file_name  =  Path (h_def_file ).stem 
233-     output_file_path  =  Path (output_dir ) /  output_file_name 
234- 
235-     with  open (output_file_path , "w" ) as  f :
236-         f .write (final_header_content )
243+     if  output_dir :
244+         output_file_path  =  Path (output_dir )
245+         if  output_file_path .is_dir ():
246+             output_file_path  /=  f"{ Path (yaml_file ).stem }  .h" 
247+     else :
248+         output_file_path  =  Path (f"{ Path (yaml_file ).stem }  .h" )
249+ 
250+     if  not  export_decls  and  h_def_file :
251+         with  open (h_def_file , "r" ) as  f :
252+             h_def_content  =  f .read ()
253+         final_header_content  =  fill_public_api (header_str , h_def_content )
254+         with  open (output_file_path , "w" ) as  f :
255+             f .write (final_header_content )
256+     else :
257+         with  open (output_file_path , "w" ) as  f :
258+             f .write (header_str )
237259
238260    print (f"Generated header file: { output_file_path }  " )
239261
240262
241263if  __name__  ==  "__main__" :
242-     parser  =  argparse .ArgumentParser (
243-         description = "Generate header files from YAML and .h.def templates" 
244-     )
264+     parser  =  argparse .ArgumentParser (description = "Generate header files from YAML" )
245265    parser .add_argument (
246266        "yaml_file" , help = "Path to the YAML file containing header specification" 
247267    )
248-     parser .add_argument ("h_def_file" , help = "Path to the .h.def template file" )
249268    parser .add_argument (
250269        "--output_dir" ,
251-         default = "." ,
252270        help = "Directory to output the generated header file" ,
253271    )
272+     parser .add_argument (
273+         "--h_def_file" ,
274+         help = "Path to the .h.def template file (required if not using --export_decls)" ,
275+     )
254276    parser .add_argument (
255277        "--add_function" ,
256278        nargs = 6 ,
@@ -264,6 +286,21 @@ def main(yaml_file, h_def_file, output_dir, add_function=None):
264286        ),
265287        help = "Add a function to the YAML file" ,
266288    )
289+     parser .add_argument (
290+         "--e" , action = "append" , help = "Entry point to include" , dest = "entry_points" 
291+     )
292+     parser .add_argument (
293+         "--export-decls" ,
294+         action = "store_true" ,
295+         help = "Flag to use GpuHeader for exporting declarations" ,
296+     )
267297    args  =  parser .parse_args ()
268298
269-     main (args .yaml_file , args .h_def_file , args .output_dir , args .add_function )
299+     main (
300+         args .yaml_file ,
301+         args .output_dir ,
302+         args .h_def_file ,
303+         args .add_function ,
304+         args .entry_points ,
305+         args .export_decls ,
306+     )
0 commit comments