diff --git a/references/video_classification/train.py b/references/video_classification/train.py
index 1f363f57dad..0cd88e8022f 100644
--- a/references/video_classification/train.py
+++ b/references/video_classification/train.py
@@ -12,19 +12,13 @@
 from torch.utils.data.dataloader import default_collate
 from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
-
 try:
     from torchvision.prototype import models as PM
 except ImportError:
     PM = None
 
 
-def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
+def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -34,16 +28,19 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi
     for video, target in metric_logger.log_every(data_loader, print_freq, header):
         start_time = time.time()
         video, target = video.to(device), target.to(device)
-        output = model(video)
-        loss = criterion(output, target)
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            output = model(video)
+            loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if apex:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
         else:
             loss.backward()
-        optimizer.step()
+            optimizer.step()
 
         acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
         batch_size = video.shape[0]
@@ -101,11 +98,6 @@ def collate_fn(batch):
 def main(args):
     if args.weights and PM is None:
         raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
-    if args.apex and amp is None:
-        raise RuntimeError(
-            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
-            "to enable mixed-precision training."
-        )
 
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -224,9 +216,7 @@ def main(args):
 
     lr = args.lr * args.world_size
     optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
-
-    if args.apex:
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
 
     # convert scheduler to be per iteration, not per epoch, for warmup that lasts
     # between different epochs
@@ -267,6 +257,8 @@ def main(args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])
 
     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
@@ -277,9 +269,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(
-            model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex
-        )
+        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
         evaluate(model, criterion, data_loader_test, device=device)
         if args.output_dir:
             checkpoint = {
@@ -289,6 +279,8 @@ def main(args):
                 "epoch": epoch,
                 "args": args,
             }
+            if args.amp:
+                checkpoint["scaler"] = scaler.state_dict()
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
 
@@ -363,17 +355,6 @@ def parse_args():
         action="store_true",
     )
 
-    # Mixed precision training parameters
-    parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training")
-    parser.add_argument(
-        "--apex-opt-level",
-        default="O1",
-        type=str,
-        help="For apex mixed precision training"
-        "O0 for FP32 training, O1 for mixed precision training."
-        "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
-    )
-
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
@@ -381,6 +362,9 @@ def parse_args():
 
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
     args = parser.parse_args()
     return args
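
For context, the native AMP recipe this diff switches to is the standard `torch.cuda.amp` pattern: run the forward pass under `autocast`, scale the loss before `backward()`, and let `GradScaler` unscale the gradients, step the optimizer, and adjust the loss scale. Below is a minimal self-contained sketch of that pattern; the toy model, data, and hyperparameters are placeholders for illustration, not the video classification setup from `train.py`:

```python
import torch

# Toy setup; any model/optimizer works the same way. These are
# placeholders, not the video model used in the reference script.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(16, 4).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

use_amp = device.type == "cuda"  # autocast/GradScaler only take effect on CUDA
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for _ in range(10):
    inputs = torch.randn(8, 16, device=device)
    targets = torch.randint(0, 4, (8,), device=device)

    # Forward pass runs in mixed precision when AMP is enabled.
    with torch.cuda.amp.autocast(enabled=use_amp):
        output = model(inputs)
        loss = criterion(output, targets)

    optimizer.zero_grad()
    # scale() multiplies the loss to avoid fp16 gradient underflow;
    # step() unscales the gradients and skips the update if they contain
    # infs/NaNs; update() adjusts the scale factor for the next iteration.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

# The scaler carries state (the current loss scale), so checkpoints
# should include it, as the diff does with checkpoint["scaler"].
checkpoint = {"model": model.state_dict(), "scaler": scaler.state_dict()}
```

Note that `GradScaler(enabled=False)` and `autocast(enabled=False)` degrade to no-ops, which is why the diff can pass `scaler=None` and gate everything on `scaler is not None` to keep a single code path for both fp32 and mixed-precision training.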