
Commit d03a0ee

Fix bug on method call.
1 parent 5865ea7 commit d03a0ee

File tree

2 files changed, +10 -7 lines changed

references/classification/train.py

Lines changed: 9 additions & 6 deletions
@@ -26,27 +26,30 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
     metric_logger.add_meter("img/s", utils.SmoothedValue(window_size=10, fmt="{value}"))
 
+    amp = getattr(args, "amp", False)
+    clip_grad_norm = getattr(args, "clip_grad_norm", None)
+
     header = f"Epoch: [{epoch}]"
     for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
-        with torch.cuda.amp.autocast(enabled=args.amp):
+        with torch.cuda.amp.autocast(enabled=amp):
             output = model(image)
             loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if args.amp:
+        if amp:
             scaler.scale(loss).backward()
-            if args.clip_grad_norm is not None:
+            if clip_grad_norm is not None:
                 # we should unscale the gradients of optimizer's assigned params if do gradient clipping
                 scaler.unscale_(optimizer)
-                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
+                nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
             scaler.step(optimizer)
             scaler.update()
         else:
             loss.backward()
-            if args.clip_grad_norm is not None:
-                nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
+            if clip_grad_norm is not None:
+                nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
             optimizer.step()
 
         if model_ema and i % args.model_ema_steps == 0:
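For context, the hunk above follows the standard PyTorch AMP recipe: when gradient clipping is combined with torch.cuda.amp.GradScaler, the gradients must be unscaled with scaler.unscale_(optimizer) before nn.utils.clip_grad_norm_ runs, otherwise the norm is computed on the scaled gradients. Below is a minimal, self-contained sketch of that pattern, not the repository's training script; the tiny model, optimizer, and random batch are illustrative only, and it assumes a CUDA device is available.

import torch
import torch.nn as nn

# Illustrative model and data; not part of the repository.
model = nn.Linear(16, 2).cuda()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()
clip_grad_norm = 1.0

image = torch.randn(8, 16, device="cuda")
target = torch.randint(0, 2, (8,), device="cuda")

with torch.cuda.amp.autocast(enabled=True):
    output = model(image)
    loss = criterion(output, target)

optimizer.zero_grad()
scaler.scale(loss).backward()           # gradients are scaled by the loss scale
scaler.unscale_(optimizer)              # bring gradients back to their true magnitude
nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)  # clip on unscaled grads
scaler.step(optimizer)                  # skips the step if any grad is inf/NaN
scaler.update()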

references/classification/train_quantization.py

Lines changed: 1 addition & 1 deletion
@@ -121,7 +121,7 @@ def main(args):
         if args.distributed:
             train_sampler.set_epoch(epoch)
         print("Starting training for epoch", epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
         lr_scheduler.step()
         with torch.inference_mode():
             if epoch >= args.num_observer_update_epochs:
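The call-site fix works together with the getattr fallbacks added in train.py: train_one_epoch reads args.print_freq (and other options) from the namespace itself, and the quantization script's argument parser does not define --amp or --clip-grad-norm, so those attributes now default to False and None instead of raising AttributeError. A minimal sketch of the defaulting pattern, where train_one_epoch_stub and the parser setup are illustrative stand-ins rather than the repository's code:

import argparse

def train_one_epoch_stub(args):
    # Reads everything it needs from the namespace; missing attributes fall back safely.
    print_freq = args.print_freq
    amp = getattr(args, "amp", False)                        # False when --amp was never defined
    clip_grad_norm = getattr(args, "clip_grad_norm", None)   # None when the flag is absent
    return print_freq, amp, clip_grad_norm

# A parser like the quantization script's: it defines --print-freq but not --amp / --clip-grad-norm.
parser = argparse.ArgumentParser()
parser.add_argument("--print-freq", default=10, type=int)
args = parser.parse_args([])

# Passing the whole namespace (as in the fixed call) works even without the AMP flags,
# whereas the old call passed args.print_freq, a plain int with no .print_freq attribute.
print(train_one_epoch_stub(args))  # -> (10, False, None)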
