
Commit 403bded

Updated classification reference script to use torch.cuda.amp (#4547)
* Updated classification reference script to use torch.cuda.amp
* Assigned scaler to None if amp is False
* Fixed linter errors
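For context, the torch.cuda.amp recipe this commit adopts pairs autocast (which runs eligible ops in reduced precision) with GradScaler (which scales the loss so small FP16 gradients don't underflow). A minimal runnable sketch of the pattern follows; the model, criterion, optimizer, and data are illustrative stand-ins, not the script's real objects. Note that, mirroring this diff, only the loss computation sits inside autocast:

import torch

# Illustrative stand-ins (assumption); the real script builds these from its CLI args.
model = torch.nn.Linear(10, 2).cuda()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_loader = [(torch.randn(4, 10, device="cuda"), torch.tensor([0, 1, 0, 1], device="cuda"))]

scaler = torch.cuda.amp.GradScaler()  # tracks and updates the loss scale across steps

for image, target in data_loader:
    optimizer.zero_grad()
    output = model(image)
    with torch.cuda.amp.autocast():  # compute the loss under autocast, as in this diff
        loss = criterion(output, target)
    scaler.scale(loss).backward()  # backward on the scaled loss to avoid FP16 underflow
    scaler.step(optimizer)         # unscales gradients; skips the step if they contain inf/NaN
    scaler.update()                # adjusts the scale factor for the next iteration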
Parent: 261cbf7 · Commit: 403bded


references/classification/train.py

Lines changed: 15 additions & 28 deletions
@@ -12,13 +12,10 @@
 from torch.utils.data.dataloader import default_collate
 from torchvision.transforms.functional import InterpolationMode
 
-try:
-    from apex import amp
-except ImportError:
-    amp = None
 
-
-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False, model_ema=None):
+def train_one_epoch(
+    model, criterion, optimizer, data_loader, device, epoch, print_freq, amp=False, model_ema=None, scaler=None
+):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -29,13 +26,16 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri
         start_time = time.time()
         image, target = image.to(device), target.to(device)
         output = model(image)
-        loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if apex:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+        if amp:
+            with torch.cuda.amp.autocast():
+                loss = criterion(output, target)
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
         else:
+            loss = criterion(output, target)
             loss.backward()
             optimizer.step()
 
@@ -156,12 +156,6 @@ def load_data(traindir, valdir, args):
 
 
 def main(args):
-    if args.apex and amp is None:
-        raise RuntimeError(
-            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
-            "to enable mixed-precision training."
-        )
-
     if args.output_dir:
         utils.mkdir(args.output_dir)
 
@@ -228,8 +222,7 @@ def main(args):
     else:
         raise RuntimeError("Invalid optimizer {}. Only SGD and RMSprop are supported.".format(args.opt))
 
-    if args.apex:
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
 
     args.lr_scheduler = args.lr_scheduler.lower()
     if args.lr_scheduler == "steplr":
@@ -292,7 +285,9 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex, model_ema)
+        train_one_epoch(
+            model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.amp, model_ema, scaler
+        )
         lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
         if model_ema:
@@ -385,15 +380,7 @@ def get_args_parser(add_help=True):
     parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
 
     # Mixed precision training parameters
-    parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training")
-    parser.add_argument(
-        "--apex-opt-level",
-        default="O1",
-        type=str,
-        help="For apex mixed precision training"
-        "O0 for FP32 training, O1 for mixed precision training."
-        "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
-    )
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
 
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
