@@ -26,27 +26,30 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
     metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
 
+    amp = getattr(args, "amp", False)
+    clip_grad_norm = getattr(args, "clip_grad_norm", None)
+
     header = f"Epoch: [{epoch}]"
     for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
-        with torch.cuda.amp.autocast(enabled=args.amp):
+        with torch.cuda.amp.autocast(enabled=amp):
             output = model(image)
             loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if args.amp:
+        if amp:
             scaler.scale(loss).backward()
-            if args.clip_grad_norm is not None:
+            if clip_grad_norm is not None:
                 # we should unscale the gradients of the optimizer's assigned params if we do gradient clipping
                 scaler.unscale_(optimizer)
-                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
+                nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
             scaler.step(optimizer)
             scaler.update()
         else:
             loss.backward()
-            if args.clip_grad_norm is not None:
-                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
+            if clip_grad_norm is not None:
+                nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
             optimizer.step()
 
         if model_ema and i % args.model_ema_steps == 0:
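
For context, here is a minimal standalone sketch (a hypothetical toy model and argparse namespace, not code from this repository) of why the getattr fallbacks matter: an args object that was built without --amp or --clip-grad-norm now defaults to AMP off and clipping disabled instead of raising AttributeError, while the AMP branch still unscales gradients before clipping.

import argparse

import torch
import torch.nn as nn

# Hypothetical args namespace that defines neither `amp` nor `clip_grad_norm`.
args = argparse.Namespace(print_freq=10)

amp = getattr(args, "amp", False)                       # falls back to False
clip_grad_norm = getattr(args, "clip_grad_norm", None)  # falls back to None (no clipping)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(8, 2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=amp)  # pass-through scaler when AMP is off

image = torch.randn(4, 8, device=device)
target = torch.randint(0, 2, (4,), device=device)

with torch.cuda.amp.autocast(enabled=amp):
    output = model(image)
    loss = criterion(output, target)

optimizer.zero_grad()
if amp:
    scaler.scale(loss).backward()
    if clip_grad_norm is not None:
        scaler.unscale_(optimizer)  # unscale before measuring gradient norms
        nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
    scaler.step(optimizer)
    scaler.update()
else:
    loss.backward()
    if clip_grad_norm is not None:
        nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
    optimizer.step()

With the getattr defaults, the same train_one_epoch can be driven by older launch scripts whose argument parser predates these two flags.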