diff --git a/references/detection/engine.py b/references/detection/engine.py
index 88fb2fe8e37..0e5d55f189d 100644
--- a/references/detection/engine.py
+++ b/references/detection/engine.py
@@ -9,7 +9,7 @@
 from coco_utils import get_coco_api_from_dataset
 
 
-def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
@@ -27,10 +27,9 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
     for images, targets in metric_logger.log_every(data_loader, print_freq, header):
         images = list(image.to(device) for image in images)
         targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
-
-        loss_dict = model(images, targets)
-
-        losses = sum(loss for loss in loss_dict.values())
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            losses = sum(loss for loss in loss_dict.values())
 
         # reduce losses over all GPUs for logging purposes
         loss_dict_reduced = utils.reduce_dict(loss_dict)
@@ -44,8 +43,13 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
             sys.exit(1)
 
         optimizer.zero_grad()
-        losses.backward()
-        optimizer.step()
+        if scaler is not None:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()
 
         if lr_scheduler is not None:
             lr_scheduler.step()
diff --git a/references/detection/train.py b/references/detection/train.py
index 09bc6334718..ae13a32bd22 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -144,6 +144,9 @@ def get_args_parser(add_help=True):
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
     return parser
 
 
@@ -209,6 +212,8 @@
     params = [p for p in model.parameters() if p.requires_grad]
     optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
 
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
     args.lr_scheduler = args.lr_scheduler.lower()
     if args.lr_scheduler == "multisteplr":
         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
@@ -225,6 +230,8 @@
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])
 
     if args.test_only:
         evaluate(model, data_loader_test, device=device)
@@ -235,7 +242,7 @@
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
+        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq, scaler)
         lr_scheduler.step()
         if args.output_dir:
             checkpoint = {
@@ -245,6 +252,8 @@
                 "args": args,
                 "epoch": epoch,
             }
+            if args.amp:
+                checkpoint["scaler"] = scaler.state_dict()
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
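
For context, the pattern this patch wires into `train_one_epoch` is the standard `torch.cuda.amp` recipe: run the forward pass under `autocast` and, when a `GradScaler` is in use, scale the loss before `backward()` so float16 gradients do not underflow. Below is a minimal standalone sketch of that recipe; the toy model, optimizer, and data are placeholders for illustration only (not part of the reference scripts), and a CUDA device is assumed.

```python
import torch
import torch.nn.functional as F

# Toy setup for illustration only (not from the reference scripts).
device = torch.device("cuda")
model = torch.nn.Linear(16, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
batches = [(torch.randn(8, 16), torch.randint(0, 2, (8,))) for _ in range(10)]

use_amp = True  # mirrors the new --amp flag
scaler = torch.cuda.amp.GradScaler() if use_amp else None

for inputs, labels in batches:
    inputs, labels = inputs.to(device), labels.to(device)
    optimizer.zero_grad()
    # Forward pass runs in reduced precision where safe when AMP is enabled.
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        loss = F.cross_entropy(model(inputs), labels)
    if scaler is not None:
        # Scale the loss so small float16 gradients don't flush to zero;
        # scaler.step() unscales before the optimizer update, and
        # scaler.update() adjusts the scale factor for the next iteration.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()
```

With the patch applied, the reference training script enables this path via the new flag, e.g. `python train.py --amp` (remaining arguments as usual); the scaler state is also saved to and restored from checkpoints, so resumed runs keep their loss-scale history.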