Skip to content

Commit 7e5ccbf

Browse files
committed
Fix bug in method call.
1 parent 5865ea7 commit 7e5ccbf

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

references/classification/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
3030
for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
3131
start_time = time.time()
3232
image, target = image.to(device), target.to(device)
33-
with torch.cuda.amp.autocast(enabled=args.amp):
33+
with torch.cuda.amp.autocast(enabled=scaler is not None):
3434
output = model(image)
3535
loss = criterion(output, target)
3636

3737
optimizer.zero_grad()
38-
if args.amp:
38+
if scaler is not None:
3939
scaler.scale(loss).backward()
4040
if args.clip_grad_norm is not None:
4141
# we should unscale the gradients of optimizer's assigned params if do gradient clipping

references/classification/train_quantization.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def main(args):
121121
if args.distributed:
122122
train_sampler.set_epoch(epoch)
123123
print("Starting training for epoch", epoch)
124-
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
124+
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
125125
lr_scheduler.step()
126126
with torch.inference_mode():
127127
if epoch >= args.num_observer_update_epochs:
@@ -261,6 +261,7 @@ def get_args_parser(add_help=True):
261261
parser.add_argument(
262262
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
263263
)
264+
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
264265

265266
# Prototype models only
266267
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

0 commit comments

Comments (0)