Skip to content

Commit 468d2d0

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Update references to use the new Model Registration API (#6369)
Summary: * Expose on Hub the public methods of the registration API * Limit methods and update docs. * Update references to use the new Model Registration API Reviewed By: datumbox Differential Revision: D38824242 fbshipit-source-id: 8898a56115b356ef70f03d347550412fc816e0e0
1 parent 9bf92aa commit 468d2d0

File tree

6 files changed

+16
-8
lines changed

6 files changed

+16
-8
lines changed

references/classification/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def main(args):
221221
)
222222

223223
print("Creating model")
224-
model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=num_classes)
224+
model = torchvision.models.get_model(args.model, weights=args.weights, num_classes=num_classes)
225225
model.to(device)
226226

227227
if args.distributed and args.sync_bn:

references/classification/train_quantization.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@ def main(args):
4646

4747
print("Creating model", args.model)
4848
# when training quantized models, we always start from a pre-trained fp32 reference model
49-
model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
49+
prefix = "quantized_"
50+
model_name = args.model
51+
if not model_name.startswith(prefix):
52+
model_name = prefix + model_name
53+
model = torchvision.models.get_model(model_name, weights=args.weights, quantize=args.test_only)
5054
model.to(device)
5155

5256
if not (args.test_only or args.post_training_quantize):

references/detection/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,8 +216,8 @@ def main(args):
216216
if "rcnn" in args.model:
217217
if args.rpn_score_thresh is not None:
218218
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
219-
model = torchvision.models.detection.__dict__[args.model](
220-
weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs
219+
model = torchvision.models.get_model(
220+
args.model, weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, **kwargs
221221
)
222222
model.to(device)
223223
if args.distributed and args.sync_bn:

references/optical_flow/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def main(args):
215215
else:
216216
torch.backends.cudnn.benchmark = True
217217

218-
model = torchvision.models.optical_flow.__dict__[args.model](weights=args.weights)
218+
model = torchvision.models.get_model(args.model, weights=args.weights)
219219

220220
if args.distributed:
221221
model = model.to(args.local_rank)

references/segmentation/train.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,12 @@ def main(args):
156156
dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
157157
)
158158

159-
model = torchvision.models.segmentation.__dict__[args.model](
160-
weights=args.weights, weights_backbone=args.weights_backbone, num_classes=num_classes, aux_loss=args.aux_loss
159+
model = torchvision.models.get_model(
160+
args.model,
161+
weights=args.weights,
162+
weights_backbone=args.weights_backbone,
163+
num_classes=num_classes,
164+
aux_loss=args.aux_loss,
161165
)
162166
model.to(device)
163167
if args.distributed:

references/video_classification/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def main(args):
246246
)
247247

248248
print("Creating model")
249-
model = torchvision.models.video.__dict__[args.model](weights=args.weights)
249+
model = torchvision.models.get_model(args.model, weights=args.weights)
250250
model.to(device)
251251
if args.distributed and args.sync_bn:
252252
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

0 commit comments

Comments
 (0)