@@ -12,19 +12,13 @@
 from torch.utils.data.dataloader import default_collate
 from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
-
 try:
     from torchvision.prototype import models as PM
 except ImportError:
     PM = None
 
 
-def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
+def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -34,16 +28,24 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi
     for video, target in metric_logger.log_every(data_loader, print_freq, header):
         start_time = time.time()
         video, target = video.to(device), target.to(device)
-        output = model(video)
-        loss = criterion(output, target)
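+        # Run the forward pass under autocast only when a GradScaler was passed in;
+        # with enabled=False the context is a no-op and the step runs in plain FP32.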
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            output = model(video)
+            loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if apex:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+
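+        # scaler.scale() multiplies the loss so small FP16 gradients do not
+        # underflow; scaler.step() unscales them before stepping (skipping the
+        # step on inf/NaN gradients) and scaler.update() adjusts the scale factor.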
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
         else:
             loss.backward()
-        optimizer.step()
+            optimizer.step()
 
         acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
         batch_size = video.shape[0]
@@ -101,11 +98,6 @@ def collate_fn(batch):
 def main(args):
     if args.weights and PM is None:
         raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
-    if args.apex and amp is None:
-        raise RuntimeError(
-            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
-            "to enable mixed-precision training."
-        )
 
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -224,9 +216,9 @@ def main(args):
 
     lr = args.lr * args.world_size
     optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
-
-    if args.apex:
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
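+    # torch.cuda.amp.GradScaler replaces apex's loss scaling; keeping it None
+    # when --amp is off leaves the FP32 path in train_one_epoch untouched.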
 
     # convert scheduler to be per iteration, not per epoch, for warmup that lasts
     # between different epochs
@@ -267,6 +257,8 @@ def main(args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])
 
     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
@@ -277,9 +269,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(
-            model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex
-        )
+        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
         evaluate(model, criterion, data_loader_test, device=device)
         if args.output_dir:
             checkpoint = {
@@ -289,6 +279,10 @@ def main(args):
                 "epoch": epoch,
                 "args": args,
             }
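+            # Persist the loss-scale state so a resumed --amp run does not
+            # have to re-calibrate GradScaler from scratch.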
+            if args.amp:
+                checkpoint["scaler"] = scaler.state_dict()
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
 
@@ -363,24 +355,16 @@ def parse_args():
         action="store_true",
     )
 
-    # Mixed precision training parameters
-    parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training")
-    parser.add_argument(
-        "--apex-opt-level",
-        default="O1",
-        type=str,
-        help="For apex mixed precision training"
-        "O0 for FP32 training, O1 for mixed precision training."
-        "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
-    )
-
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
     parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
 
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
     args = parser.parse_args()
 
     return args