Skip to content

Commit 220e0ff

Browse files
authored
Add multi-type support on get_weight() (#4967)
* Add multi-type support on get_weight.
* Fix bug on method call.
* Adding logging suffix for QAT.
1 parent ebc4ca7 commit 220e0ff

File tree

3 files changed: +10 −4 lines changed (10 additions, 4 deletions)

references/classification/train.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -30,12 +30,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
3030
for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
3131
start_time = time.time()
3232
image, target = image.to(device), target.to(device)
33-
with torch.cuda.amp.autocast(enabled=args.amp):
33+
with torch.cuda.amp.autocast(enabled=scaler is not None):
3434
output = model(image)
3535
loss = criterion(output, target)
3636

3737
optimizer.zero_grad()
38-
if args.amp:
38+
if scaler is not None:
3939
scaler.scale(loss).backward()
4040
if args.clip_grad_norm is not None:
4141
# we should unscale the gradients of optimizer's assigned params if do gradient clipping

references/classification/train_quantization.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,7 @@ def main(args):
121121
if args.distributed:
122122
train_sampler.set_epoch(epoch)
123123
print("Starting training for epoch", epoch)
124-
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
124+
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
125125
lr_scheduler.step()
126126
with torch.inference_mode():
127127
if epoch >= args.num_observer_update_epochs:
@@ -132,7 +132,7 @@ def main(args):
132132
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
133133
print("Evaluate QAT model")
134134

135-
evaluate(model, criterion, data_loader_test, device=device)
135+
evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
136136
quantized_eval_model = copy.deepcopy(model_without_ddp)
137137
quantized_eval_model.eval()
138138
quantized_eval_model.to(torch.device("cpu"))
@@ -261,6 +261,7 @@ def get_args_parser(add_help=True):
261261
parser.add_argument(
262262
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
263263
)
264+
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
264265

265266
# Prototype models only
266267
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

torchvision/prototype/models/_api.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -101,6 +101,11 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
101101
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
102102
for t in ann.__args__: # type: ignore[union-attr]
103103
if isinstance(t, type) and issubclass(t, Weights):
104+
# ensure the name exists. handles builders with multiple types of weights like in quantization
105+
try:
106+
t.from_str(weight_name)
107+
except ValueError:
108+
continue
104109
weights_class = t
105110
break
106111

0 commit comments

Comments
 (0)