From 906a7d4a8e8536c07b9620a7a90fd75f19cdd5a0 Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Sun, 14 Nov 2021 22:35:08 +0800
Subject: [PATCH 1/6] support amp training

---
 references/detection/train.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/references/detection/train.py b/references/detection/train.py
index 09bc6334718..f4c08dcc904 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -143,6 +143,9 @@ def get_args_parser(add_help=True):
 
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+
+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
 
     return parser
 
@@ -208,7 +211,9 @@ def main(args):
 
     params = [p for p in model.parameters() if p.requires_grad]
     optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
-
+
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
     args.lr_scheduler = args.lr_scheduler.lower()
     if args.lr_scheduler == "multisteplr":
         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
@@ -225,6 +230,8 @@ def main(args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])
 
     if args.test_only:
         evaluate(model, data_loader_test, device=device)
@@ -235,7 +242,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq)
+        train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq, scaler)
         lr_scheduler.step()
         if args.output_dir:
             checkpoint = {
@@ -245,6 +252,8 @@ def main(args):
                 "args": args,
                 "epoch": epoch,
             }
+            if args.amp:
+                scaler.load_state_dict(checkpoint["scaler"])
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
 

From 87859b31f824a1331bef84f3750f56afc6afec5c Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Sun, 14 Nov 2021 22:40:01 +0800
Subject: [PATCH 2/6] support amp training

---
 references/detection/engine.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/references/detection/engine.py b/references/detection/engine.py
index 88fb2fe8e37..804f7914bac 100644
--- a/references/detection/engine.py
+++ b/references/detection/engine.py
@@ -9,7 +9,7 @@
 from coco_utils import get_coco_api_from_dataset
 
 
-def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
@@ -27,10 +27,9 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
     for images, targets in metric_logger.log_every(data_loader, print_freq, header):
         images = list(image.to(device) for image in images)
         targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
-
-        loss_dict = model(images, targets)
-
-        losses = sum(loss for loss in loss_dict.values())
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            loss_dict = model(images, targets)
+            losses = sum(loss for loss in loss_dict.values())
 
         # reduce losses over all GPUs for logging purposes
         loss_dict_reduced = utils.reduce_dict(loss_dict)
@@ -44,8 +43,13 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc
             sys.exit(1)
 
         optimizer.zero_grad()
-        losses.backward()
-        optimizer.step()
+        if scaler:
+            scaler.scale(losses).backward()
+            scaler.step(optimizer)
+            scaler.update()
+        else:
+            losses.backward()
+            optimizer.step()
 
         if lr_scheduler is not None:
             lr_scheduler.step()

From 610f0d72af9faf5d41b9aa905e578930e352bf0b Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Sun, 14 Nov 2021 22:53:08 +0800
Subject: [PATCH 3/6] support amp training

---
 references/detection/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/references/detection/train.py b/references/detection/train.py
index f4c08dcc904..2a9fbf6e320 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -253,7 +253,7 @@ def main(args):
                 "epoch": epoch,
             }
             if args.amp:
-                scaler.load_state_dict(checkpoint["scaler"])
+                checkpoint['scaler'] = scaler.state_dict()
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
 

From 3fbf9d4a7ed3161aeb4c2a3bcf3e32a341344df9 Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Mon, 15 Nov 2021 19:12:38 +0800
Subject: [PATCH 4/6] Update references/detection/train.py

Co-authored-by: Vasilis Vryniotis
---
 references/detection/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/references/detection/train.py b/references/detection/train.py
index 2a9fbf6e320..5c50dcfae4f 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -253,7 +253,7 @@ def main(args):
                 "epoch": epoch,
             }
             if args.amp:
-                checkpoint['scaler'] = scaler.state_dict()
+                checkpoint["scaler"] = scaler.state_dict()
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
 

From 1d6b6077935efd0d53f2810a70eb84e6a5251bd4 Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Mon, 15 Nov 2021 19:12:44 +0800
Subject: [PATCH 5/6] Update references/detection/engine.py

Co-authored-by: Vasilis Vryniotis
---
 references/detection/engine.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/references/detection/engine.py b/references/detection/engine.py
index 804f7914bac..0e5d55f189d 100644
--- a/references/detection/engine.py
+++ b/references/detection/engine.py
@@ -43,7 +43,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc
             sys.exit(1)
 
         optimizer.zero_grad()
-        if scaler:
+        if scaler is not None:
             scaler.scale(losses).backward()
             scaler.step(optimizer)
             scaler.update()
         else:
             losses.backward()

From bfc9225b158a692e28ee6517a6a508931e187b77 Mon Sep 17 00:00:00 2001
From: Hu Ye
Date: Mon, 15 Nov 2021 20:07:15 +0800
Subject: [PATCH 6/6] fix lint issues

---
 references/detection/train.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/references/detection/train.py b/references/detection/train.py
index 5c50dcfae4f..ae13a32bd22 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -143,7 +143,7 @@ def get_args_parser(add_help=True):
 
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
-
+
     # Mixed precision training parameters
     parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
 
@@ -211,9 +211,9 @@ def main(args):
 
     params = [p for p in model.parameters() if p.requires_grad]
     optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
-
+
     scaler = torch.cuda.amp.GradScaler() if args.amp else None
-
+
     args.lr_scheduler = args.lr_scheduler.lower()
     if args.lr_scheduler == "multisteplr":
         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
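Taken together, the six patches wire the standard torch.cuda.amp recipe into the detection references: an --amp flag, a GradScaler created only when the flag is set, an autocast context around the forward pass, the scaled backward/step/update sequence in place of the plain one, and the scaler state saved to and restored from checkpoints. The sketch below restates that recipe outside the reference scripts; it is illustrative only, and the train_step name and its arguments are not part of the patches.

    import torch

    def train_step(model, optimizer, images, targets, scaler=None):
        # Run the forward pass under autocast when AMP is enabled; detection
        # models return a dict of losses in training mode.
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

        optimizer.zero_grad()
        if scaler is not None:
            # Scale the loss so fp16 gradients do not underflow; scaler.step()
            # unscales before the optimizer update and skips the step if
            # inf/nan gradients are detected.
            scaler.scale(losses).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            losses.backward()
            optimizer.step()
        return losses

    # The scaler carries state (the running loss scale), so the train.py
    # changes also checkpoint and restore it alongside model and optimizer:
    #   checkpoint["scaler"] = scaler.state_dict()
    #   scaler.load_state_dict(checkpoint["scaler"])

Passing scaler=None recovers the unmodified fp32 path, which is why the patches can thread a single optional argument through train_one_epoch instead of branching at the call site.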