@@ -12,19 +12,13 @@
 from torch.utils.data.dataloader import default_collate
 from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
-
 try:
     from torchvision.prototype import models as PM
 except ImportError:
     PM = None
 
 
-def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
+def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -34,16 +28,19 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi
     for video, target in metric_logger.log_every(data_loader, print_freq, header):
         start_time = time.time()
         video, target = video.to(device), target.to(device)
-        output = model(video)
-        loss = criterion(output, target)
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            output = model(video)
+            loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if apex:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
         else:
             loss.backward()
-        optimizer.step()
+            optimizer.step()
 
         acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
         batch_size = video.shape[0]
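For context, the native torch.cuda.amp pattern this hunk adopts looks like the following in isolation. This is a minimal sketch with a stand-in model and random data, not code from the script itself:

    import torch

    model = torch.nn.Linear(16, 2).cuda()  # hypothetical stand-in for the video model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    criterion = torch.nn.CrossEntropyLoss()
    scaler = torch.cuda.amp.GradScaler()  # created once, before the training loop

    video = torch.randn(4, 16).cuda()            # fake batch
    target = torch.randint(2, (4,)).cuda()       # fake labels

    # The forward pass runs under autocast so eligible ops execute in fp16.
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        output = model(video)
        loss = criterion(output, target)

    optimizer.zero_grad()
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients; skips the step on inf/NaN gradients
    scaler.update()                # adjusts the scale factor for the next iteration

Note that optimizer.step() moves into the non-AMP branch because scaler.step(optimizer) performs the step itself on the AMP path.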
@@ -101,11 +98,6 @@ def collate_fn(batch):
 def main(args):
     if args.weights and PM is None:
         raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
-    if args.apex and amp is None:
-        raise RuntimeError(
-            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
-            "to enable mixed-precision training."
-        )
 
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -224,9 +216,7 @@ def main(args):
 
     lr = args.lr * args.world_size
     optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
-
-    if args.apex:
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
 
     # convert scheduler to be per iteration, not per epoch, for warmup that lasts
     # between different epochs
@@ -267,6 +257,8 @@ def main(args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])
 
     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
@@ -277,9 +269,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(
-            model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex
-        )
+        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
         evaluate(model, criterion, data_loader_test, device=device)
         if args.output_dir:
             checkpoint = {
@@ -289,6 +279,8 @@ def main(args):
                 "epoch": epoch,
                 "args": args,
             }
+            if args.amp:
+                checkpoint["scaler"] = scaler.state_dict()
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
 
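The scaler's state (current scale factor and growth tracking) round-trips through the checkpoint like any other state_dict. A minimal sketch of the save/resume cycle, with an illustrative file name:

    import torch

    scaler = torch.cuda.amp.GradScaler()

    # Save: serialize the scaler alongside the other training state.
    torch.save({"scaler": scaler.state_dict()}, "checkpoint.pth")  # illustrative path

    # Resume: restore the scaler before training continues, as the resume hunk above does.
    checkpoint = torch.load("checkpoint.pth")
    scaler.load_state_dict(checkpoint["scaler"])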
@@ -363,24 +355,16 @@ def parse_args():
         action="store_true",
     )
 
-    # Mixed precision training parameters
-    parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training")
-    parser.add_argument(
-        "--apex-opt-level",
-        default="O1",
-        type=str,
-        help="For apex mixed precision training"
-        "O0 for FP32 training, O1 for mixed precision training."
-        "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
-    )
-
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
     parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
 
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
     args = parser.parse_args()
 
     return args
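With these changes, mixed precision is opt-in via the new flag. Assuming the usual torchvision reference-script invocation (the dataset path below is a placeholder), AMP training would look something like:

    python train.py --data-path /path/to/kinetics400 --amp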