Skip to content

Commit 9a34c0c

Browse files
authored
Adding multiweight support to Quantized ResNet (#4827)
* Adding multi-weight support to Quantized ResNet. * Update references script to support testing quantized models with the new API. * Handle quantized models correctly in ref script. * Fixing references for quantization.
1 parent 6a60b9b commit 9a34c0c

File tree

5 files changed

+126
-12
lines changed

5 files changed

+126
-12
lines changed

references/classification/README.md

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,9 @@ torchrun --nproc_per_node=8 train.py\
151151

152152
## Quantized
153153

154-
### Parameters used for generating quantized models:
154+
### Post training quantized models
155155

156-
For all post training quantized models (All quantized models except mobilenet-v2), the settings are:
156+
For all post training quantized models, the settings are:
157157

158158
1. num_calibration_batches: 32
159159
2. num_workers: 16
@@ -162,8 +162,11 @@ For all post training quantized models (All quantized models except mobilenet-v2
162162
5. backend: 'fbgemm'
163163

164164
```
165-
python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='<model_name>'
165+
python train_quantization.py --device='cpu' --post-training-quantize --backend='fbgemm' --model='$MODEL'
166166
```
167+
Here `$MODEL` is one of `googlenet`, `inception_v3`, `resnet18`, `resnet50`, `resnext101_32x8d` and `shufflenet_v2_x1_0`.
168+
169+
### QAT MobileNetV2
167170

168171
For Mobilenet-v2, the model was trained with quantization aware training, the settings used are:
169172
1. num_workers: 16
@@ -185,6 +188,8 @@ torchrun --nproc_per_node=8 train_quantization.py --model='mobilenet_v2'
185188

186189
Training converges at about 10 epochs.
187190

191+
### QAT MobileNetV3
192+
188193
For Mobilenet-v3 Large, the model was trained with quantization aware training, the settings used are:
189194
1. num_workers: 16
190195
2. batch_size: 32

references/classification/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def load_data(traindir, valdir, args):
153153
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
154154
)
155155
else:
156-
fn = PM.__dict__[args.model]
156+
fn = PM.quantization.__dict__[args.model] if hasattr(args, "backend") else PM.__dict__[args.model]
157157
weights = PM._api.get_weight(fn, args.weights)
158158
preprocessing = weights.transforms()
159159

references/classification/train_quantization.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@
1212
from train import train_one_epoch, evaluate, load_data
1313

1414

15+
try:
16+
from torchvision.prototype import models as PM
17+
except ImportError:
18+
PM = None
19+
20+
1521
def main(args):
1622
if args.output_dir:
1723
utils.mkdir(args.output_dir)
@@ -46,7 +52,12 @@ def main(args):
4652

4753
print("Creating model", args.model)
4854
# when training quantized models, we always start from a pre-trained fp32 reference model
49-
model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
55+
if not args.weights:
56+
model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
57+
else:
58+
if PM is None:
59+
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
60+
model = PM.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
5061
model.to(device)
5162

5263
if not (args.test_only or args.post_training_quantize):
@@ -251,6 +262,9 @@ def get_args_parser(add_help=True):
251262
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
252263
)
253264

265+
# Prototype models only
266+
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
267+
254268
return parser
255269

256270

test/test_prototype_models.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,15 @@ def get_models_with_module_names(module):
3030
return [(fn, module_name) for fn in TM.get_models_from_module(module)]
3131

3232

33-
def test_get_weight():
34-
fn = models.resnet50
35-
weight_name = "ImageNet1K_RefV2"
36-
assert models._api.get_weight(fn, weight_name) == models.ResNet50Weights.ImageNet1K_RefV2
33+
@pytest.mark.parametrize(
34+
"model_fn, weight",
35+
[
36+
(models.resnet50, models.ResNet50Weights.ImageNet1K_RefV2),
37+
(models.quantization.resnet50, models.quantization.QuantizedResNet50Weights.ImageNet1K_FBGEMM_RefV1),
38+
],
39+
)
40+
def test_get_weight(model_fn, weight):
41+
assert models._api.get_weight(model_fn, weight.name) == weight
3742

3843

3944
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models))
@@ -43,6 +48,12 @@ def test_classification_model(model_fn, dev):
4348
TM.test_classification_model(model_fn, dev)
4449

4550

51+
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization))
52+
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
53+
def test_quantized_classification_model(model_fn):
54+
TM.test_quantized_classification_model(model_fn)
55+
56+
4657
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation))
4758
@pytest.mark.parametrize("dev", cpu_and_gpu())
4859
@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled")
@@ -60,6 +71,7 @@ def test_video_model(model_fn, dev):
6071
@pytest.mark.parametrize(
6172
"model_fn, module_name",
6273
get_models_with_module_names(models)
74+
+ get_models_with_module_names(models.quantization)
6375
+ get_models_with_module_names(models.segmentation)
6476
+ get_models_with_module_names(models.video),
6577
)
@@ -70,6 +82,9 @@ def test_old_vs_new_factory(model_fn, module_name, dev):
7082
"models": {
7183
"input_shape": (1, 3, 224, 224),
7284
},
85+
"quantization": {
86+
"input_shape": (1, 3, 224, 224),
87+
},
7388
"segmentation": {
7489
"input_shape": (1, 3, 520, 520),
7590
},

torchvision/prototype/models/quantization/resnet.py

Lines changed: 83 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,18 @@
1212
from ...transforms.presets import ImageNetEval
1313
from .._api import Weights, WeightEntry
1414
from .._meta import _IMAGENET_CATEGORIES
15-
from ..resnet import ResNet50Weights
15+
from ..resnet import ResNet18Weights, ResNet50Weights, ResNeXt101_32x8dWeights
1616

1717

18-
__all__ = ["QuantizableResNet", "QuantizedResNet50Weights", "resnet50"]
18+
__all__ = [
19+
"QuantizableResNet",
20+
"QuantizedResNet18Weights",
21+
"QuantizedResNet50Weights",
22+
"QuantizedResNeXt101_32x8dWeights",
23+
"resnet18",
24+
"resnet50",
25+
"resnext101_32x8d",
26+
]
1927

2028

2129
def _resnet(
@@ -47,22 +55,67 @@ def _resnet(
4755
"size": (224, 224),
4856
"categories": _IMAGENET_CATEGORIES,
4957
"backend": "fbgemm",
58+
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
5059
}
5160

5261

62+
class QuantizedResNet18Weights(Weights):
63+
ImageNet1K_FBGEMM_RefV1 = WeightEntry(
64+
url="https://download.pytorch.org/models/quantized/resnet18_fbgemm_16fa66dd.pth",
65+
transforms=partial(ImageNetEval, crop_size=224),
66+
meta={
67+
**_common_meta,
68+
"acc@1": 69.494,
69+
"acc@5": 88.882,
70+
},
71+
)
72+
73+
5374
class QuantizedResNet50Weights(Weights):
5475
ImageNet1K_FBGEMM_RefV1 = WeightEntry(
5576
url="https://download.pytorch.org/models/quantized/resnet50_fbgemm_bf931d71.pth",
5677
transforms=partial(ImageNetEval, crop_size=224),
5778
meta={
5879
**_common_meta,
59-
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#quantized",
6080
"acc@1": 75.920,
6181
"acc@5": 92.814,
6282
},
6383
)
6484

6585

86+
class QuantizedResNeXt101_32x8dWeights(Weights):
87+
ImageNet1K_FBGEMM_RefV1 = WeightEntry(
88+
url="https://download.pytorch.org/models/quantized/resnext101_32x8_fbgemm_09835ccf.pth",
89+
transforms=partial(ImageNetEval, crop_size=224),
90+
meta={
91+
**_common_meta,
92+
"acc@1": 78.986,
93+
"acc@5": 94.480,
94+
},
95+
)
96+
97+
98+
def resnet18(
99+
weights: Optional[Union[QuantizedResNet18Weights, ResNet18Weights]] = None,
100+
progress: bool = True,
101+
quantize: bool = False,
102+
**kwargs: Any,
103+
) -> QuantizableResNet:
104+
if "pretrained" in kwargs:
105+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
106+
if kwargs.pop("pretrained"):
107+
weights = QuantizedResNet18Weights.ImageNet1K_FBGEMM_RefV1 if quantize else ResNet18Weights.ImageNet1K_RefV1
108+
else:
109+
weights = None
110+
111+
if quantize:
112+
weights = QuantizedResNet18Weights.verify(weights)
113+
else:
114+
weights = ResNet18Weights.verify(weights)
115+
116+
return _resnet(QuantizableBasicBlock, [2, 2, 2, 2], weights, progress, quantize, **kwargs)
117+
118+
66119
def resnet50(
67120
weights: Optional[Union[QuantizedResNet50Weights, ResNet50Weights]] = None,
68121
progress: bool = True,
@@ -82,3 +135,30 @@ def resnet50(
82135
weights = ResNet50Weights.verify(weights)
83136

84137
return _resnet(QuantizableBottleneck, [3, 4, 6, 3], weights, progress, quantize, **kwargs)
138+
139+
140+
def resnext101_32x8d(
141+
weights: Optional[Union[QuantizedResNeXt101_32x8dWeights, ResNeXt101_32x8dWeights]] = None,
142+
progress: bool = True,
143+
quantize: bool = False,
144+
**kwargs: Any,
145+
) -> QuantizableResNet:
146+
if "pretrained" in kwargs:
147+
warnings.warn("The argument pretrained is deprecated, please use weights instead.")
148+
if kwargs.pop("pretrained"):
149+
weights = (
150+
QuantizedResNeXt101_32x8dWeights.ImageNet1K_FBGEMM_RefV1
151+
if quantize
152+
else ResNeXt101_32x8dWeights.ImageNet1K_RefV1
153+
)
154+
else:
155+
weights = None
156+
157+
if quantize:
158+
weights = QuantizedResNeXt101_32x8dWeights.verify(weights)
159+
else:
160+
weights = ResNeXt101_32x8dWeights.verify(weights)
161+
162+
kwargs["groups"] = 32
163+
kwargs["width_per_group"] = 8
164+
return _resnet(QuantizableBottleneck, [3, 4, 23, 3], weights, progress, quantize, **kwargs)

0 commit comments

Comments
 (0)