5656except  ImportError  as  e :
5757    has_functorch  =  False 
5858
59- try :
60-     import  torch ._dynamo 
61-     has_dynamo  =  True 
62- except  ImportError :
63-     has_dynamo  =  False 
64-     pass 
65- 
59+ has_compile  =  hasattr (torch , 'compile' )
6660
6761if  torch .cuda .is_available ():
6862    torch .backends .cuda .matmul .allow_tf32  =  True 
8175                    help = 'Provide train fwd/bwd/opt breakdown detail if True. Defaults to False' )
8276parser .add_argument ('--no-retry' , action = 'store_true' , default = False ,
8377                    help = 'Do not decay batch size and retry on error.' )
84- parser .add_argument ('--results-file' , default = '' , type = str ,  metavar = 'FILENAME' , 
78+ parser .add_argument ('--results-file' , default = '' , type = str ,
8579                    help = 'Output csv file for validation results (summary)' )
80+ parser .add_argument ('--results-format' , default = 'csv' , type = str ,
81+                     help = 'Format for results file one of (csv, json) (default: csv).' )
8682parser .add_argument ('--num-warm-iter' , default = 10 , type = int ,
8783                    metavar = 'N' , help = 'Number of warmup iterations (default: 10)' )
8884parser .add_argument ('--num-bench-iter' , default = 40 , type = int ,
113109                    help = 'Numeric precision. One of (amp, float32, float16, bfloat16, tf32)' )
114110parser .add_argument ('--fuser' , default = '' , type = str ,
115111                    help = "Select jit fuser. One of ('', 'te', 'old', 'nvfuser')" )
116- parser .add_argument ('--dynamo-backend' , default = None , type = str ,
117-                     help = "Select dynamo backend. Default: None" )
118112parser .add_argument ('--fast-norm' , default = False , action = 'store_true' ,
119113                    help = 'enable experimental fast-norm' )
120114
121115# codegen (model compilation) options 
122116scripting_group  =  parser .add_mutually_exclusive_group ()
123117scripting_group .add_argument ('--torchscript' , dest = 'torchscript' , action = 'store_true' ,
124118                             help = 'convert model torchscript for inference' )
119+ scripting_group .add_argument ('--torchcompile' , nargs = '?' , type = str , default = None , const = 'inductor' ,
120+                              help = "Enable compilation w/ specified backend (default: inductor)." )
125121scripting_group .add_argument ('--aot-autograd' , default = False , action = 'store_true' ,
126122                             help = "Enable AOT Autograd optimization." )
127- scripting_group .add_argument ('--dynamo' , default = False , action = 'store_true' ,
128-                              help = "Enable Dynamo optimization." )
123+ 
129124
130125# train optimizer parameters 
131126parser .add_argument ('--opt' , default = 'sgd' , type = str , metavar = 'OPTIMIZER' ,
@@ -218,9 +213,8 @@ def __init__(
218213            detail = False ,
219214            device = 'cuda' ,
220215            torchscript = False ,
216+             torchcompile = None ,
221217            aot_autograd = False ,
222-             dynamo = False ,
223-             dynamo_backend = None ,
224218            precision = 'float32' ,
225219            fuser = '' ,
226220            num_warm_iter = 10 ,
@@ -259,20 +253,19 @@ def __init__(
259253        self .input_size  =  data_config ['input_size' ]
260254        self .batch_size  =  kwargs .pop ('batch_size' , 256 )
261255
262-         self .scripted  =  False 
256+         self .compiled  =  False 
263257        if  torchscript :
264258            self .model  =  torch .jit .script (self .model )
265-             self .scripted  =  True 
266-         elif  dynamo :
267-             assert  has_dynamo ,  " torch._dynamo  is needed for --dynamo" 
259+             self .compiled  =  True 
260+         elif  torchcompile :
261+             assert  has_compile ,  'A version of  torch w/ torch.compile()  is required, possibly a nightly.' 
268262            torch ._dynamo .reset ()
269-             if  dynamo_backend  is  not   None :
270-                 self .model  =  torch ._dynamo .optimize (dynamo_backend )(self .model )
271-             else :
272-                 self .model  =  torch ._dynamo .optimize ()(self .model )
263+             self .model  =  torch .compile (self .model , backend = torchcompile )
264+             self .compiled  =  True 
273265        elif  aot_autograd :
274266            assert  has_functorch , "functorch is needed for --aot-autograd" 
275267            self .model  =  memory_efficient_fusion (self .model )
268+             self .compiled  =  True 
276269
277270        self .example_inputs  =  None 
278271        self .num_warm_iter  =  num_warm_iter 
@@ -344,7 +337,7 @@ def _step():
344337            param_count = round (self .param_count  /  1e6 , 2 ),
345338        )
346339
347-         retries  =  0  if  self .scripted  else  2   # skip profiling if model is scripted 
340+         retries  =  0  if  self .compiled  else  2   # skip profiling if model is scripted 
348341        while  retries :
349342            retries  -=  1 
350343            try :
@@ -642,7 +635,6 @@ def main():
642635        model_cfgs  =  [(n , None ) for  n  in  model_names ]
643636
644637    if  len (model_cfgs ):
645-         results_file  =  args .results_file  or  './benchmark.csv' 
646638        _logger .info ('Running bulk validation on these pretrained models: {}' .format (', ' .join (model_names )))
647639        results  =  []
648640        try :
@@ -663,22 +655,30 @@ def main():
663655            sort_key  =  'infer_gmacs' 
664656        results  =  filter (lambda  x : sort_key  in  x , results )
665657        results  =  sorted (results , key = lambda  x : x [sort_key ], reverse = True )
666-         if  len (results ):
667-             write_results (results_file , results )
668658    else :
669659        results  =  benchmark (args )
670660
661+     if  args .results_file :
662+         write_results (args .results_file , results , format = args .results_format )
663+ 
671664    # output results in JSON to stdout w/ delimiter for runner script 
672665    print (f'--result\n { json .dumps (results , indent = 4 )}  ' )
673666
674667
675- def  write_results (results_file , results ):
668+ def  write_results (results_file , results ,  format = 'csv' ):
676669    with  open (results_file , mode = 'w' ) as  cf :
677-         dw  =  csv .DictWriter (cf , fieldnames = results [0 ].keys ())
678-         dw .writeheader ()
679-         for  r  in  results :
680-             dw .writerow (r )
681-         cf .flush ()
670+         if  format  ==  'json' :
671+             json .dump (results , cf , indent = 4 )
672+         else :
673+             if  not  isinstance (results , (list , tuple )):
674+                 results  =  [results ]
675+             if  not  results :
676+                 return 
677+             dw  =  csv .DictWriter (cf , fieldnames = results [0 ].keys ())
678+             dw .writeheader ()
679+             for  r  in  results :
680+                 dw .writerow (r )
681+             cf .flush ()
682682
683683
684684if  __name__  ==  '__main__' :
0 commit comments