Commit c187c2b

vinhngx authored and fmassa committed
Fix apex distributed training (#1124)
* adding mixed precision training with Apex
* fix APEX default optimization level
* adding python version check for apex
* fix LINT errors and raise exceptions if apex not available
* fixing apex distributed training
* fix throughput calculation: include forward pass
* remove torch.cuda.set_device(args.gpu) as it's already called in init_distributed_mode
* fix linter: new line
* move Apex initialization code back to the beginning of main
* move apex initialization to before lr_scheduler - for peace of mind. Though, doing apex initialization after lr_scheduler seems to work fine as well
1 parent 5d1372c commit c187c2b
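
The ordering the commit message describes - run amp.initialize on the raw model and optimizer first, build the LR scheduler afterwards, and only then wrap the model in DistributedDataParallel - is sketched below. This is a minimal illustration, not a verbatim excerpt of train.py: it assumes NVIDIA Apex is installed, that args carries the same fields the script uses (apex, apex_opt_level, lr, momentum, weight_decay, lr_step_size, lr_gamma, distributed, gpu), and the helper name build_training_objects is invented for the example.

# Minimal sketch (not the exact train.py code): Apex first, scheduler next, DDP last.
import torch

try:
    from apex import amp
except ImportError:  # Apex is optional; only needed when args.apex is set
    amp = None


def build_training_objects(model, args):
    optimizer = torch.optim.SGD(
        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    # 1) Let Apex patch model/optimizer for mixed precision before anything else uses them.
    if args.apex:
        if amp is None:
            raise RuntimeError("args.apex is set but NVIDIA Apex is not installed")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)

    # 2) Build the LR scheduler from the (possibly patched) optimizer.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)

    # 3) Wrap with DistributedDataParallel last, keeping a handle to the bare
    #    module for checkpoint save/load.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    return model, model_without_ddp, optimizer, lr_scheduler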

File tree

1 file changed (+8, -8 lines changed)


references/classification/train.py

Lines changed: 8 additions & 8 deletions
@@ -26,11 +26,11 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):

     header = 'Epoch: [{}]'.format(epoch)
     for image, target in metric_logger.log_every(data_loader, print_freq, header):
+        start_time = time.time()
         image, target = image.to(device), target.to(device)
         output = model(image)
         loss = criterion(output, target)

-        start_time = time.time()
         optimizer.zero_grad()
         if apex:
             with amp.scale_loss(loss, optimizer) as scaled_loss:
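
The hunk above moves start_time to the top of the loop body so the per-batch timing covers the forward pass as well as the backward pass and optimizer step. A rough sketch of how an images/sec figure can be derived from that timestamp follows; it is illustrative only - the actual meter bookkeeping lives elsewhere in train.py and is not part of this diff, timed_step is an invented helper, and the plain loss.backward() path stands in for the Apex branch.

import time


def timed_step(model, criterion, optimizer, image, target):
    # Timestamp taken before the forward pass, as in this commit.
    start_time = time.time()
    output = model(image)                    # forward
    loss = criterion(output, target)
    optimizer.zero_grad()
    loss.backward()                          # backward
    optimizer.step()
    # Throughput now includes forward + backward + step.
    images_per_sec = image.shape[0] / (time.time() - start_time)
    return loss.item(), images_per_sec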
@@ -170,23 +170,23 @@ def main(args):
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

-    model_without_ddp = model
-    if args.distributed:
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
-        model_without_ddp = model.module
-
     criterion = nn.CrossEntropyLoss()

     optimizer = torch.optim.SGD(
         model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

-    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
-
     if args.apex:
         model, optimizer = amp.initialize(model, optimizer,
                                           opt_level=args.apex_opt_level
                                           )

+    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
+
+    model_without_ddp = model
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+        model_without_ddp = model.module
+
     if args.resume:
         checkpoint = torch.load(args.resume, map_location='cpu')
         model_without_ddp.load_state_dict(checkpoint['model'])
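
For context, the per-iteration pattern selected by the apex flag in train_one_epoch looks roughly like the sketch below. It is not a verbatim excerpt: training_step is an invented helper, and amp.scale_loss is the standard Apex idiom for scaling the loss before backpropagation under mixed precision.

def training_step(model, criterion, optimizer, image, target, apex=False):
    output = model(image)
    loss = criterion(output, target)

    optimizer.zero_grad()
    if apex:
        from apex import amp
        # Scale the loss so FP16 gradients don't underflow; Apex restores
        # unscaled gradients before the (patched) optimizer step runs.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()
    return loss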
