 from torch.utils.data.dataloader import default_collate
 from torchvision.transforms.functional import InterpolationMode

-try:
-    from apex import amp
-except ImportError:
-    amp = None

-
-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False, model_ema=None):
+def train_one_epoch(
+    model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
+):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -29,13 +26,16 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False, model_ema=None):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
         output = model(image)
-        loss = criterion(output, target)

         optimizer.zero_grad()
-        if apex:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+        if amp:
+            with torch.cuda.amp.autocast():
+                loss = criterion(output, target)
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
         else:
+            loss = criterion(output, target)
             loss.backward()
             optimizer.step()

@@ -156,12 +156,6 @@ def load_data(traindir, valdir, args):


 def main(args):
-    if args.apex and amp is None:
-        raise RuntimeError(
-            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
-            "to enable mixed-precision training."
-        )
-
     if args.output_dir:
         utils.mkdir(args.output_dir)

@@ -228,8 +222,7 @@ def main(args):
     else:
         raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))

-    if args.apex:
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None

     args.lr_scheduler = args.lr_scheduler.lower()
     if args.lr_scheduler == "steplr":
@@ -292,7 +285,9 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
+        train_one_epoch(
+            model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler
+        )
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
         if model_ema:
@@ -385,15 +380,7 @@ def get_args_parser(add_help=True):
     parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

     # Mixed precision training parameters
-    parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training")
-    parser.add_argument(
-        "--apex-opt-level",
-        default="O1",
-        type=str,
-        help="For apex mixed precision training"
-        "O0 for FP32 training, O1 for mixed precision training."
-        "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
-    )
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")

     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
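
For anyone adapting their own training loop, the sketch below shows the torch.cuda.amp recipe this diff adopts as a minimal, self-contained example. The toy Linear model, synthetic batch, and hyperparameters are illustrative only and not part of this change:

import torch

# Toy setup; any model/criterion/optimizer combination works the same way.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(10, 2).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# GradScaler multiplies the loss by a scale factor so small fp16 gradients
# do not underflow; enabled=False turns it into a pass-through on CPU.
scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")

image = torch.randn(8, 10, device=device)
target = torch.randint(0, 2, (8,), device=device)

optimizer.zero_grad()
# autocast runs eligible ops in reduced precision inside the block.
with torch.cuda.amp.autocast(enabled=device.type == "cuda"):
    output = model(image)
    loss = criterion(output, target)
scaler.scale(loss).backward()  # backward through the scaled loss
scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
scaler.update()                # adjusts the scale factor for the next iteration

One difference worth noting: the diff wraps only the loss computation in autocast, keeping the forward pass in full precision, while the sketch above also runs the forward pass under autocast so the model's ops execute in mixed precision as well.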