Skip to content

Commit 829c5c1

Browse files
committed
porting shufflenetv2
1 parent 3db3e77 commit 829c5c1

File tree

2 files changed

+94
-172
lines changed

2 files changed

+94
-172
lines changed

torchvision/models/quantization/shufflenetv2.py

Lines changed: 94 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,28 @@
1-
from typing import Any, Optional
1+
from functools import partial
2+
from typing import Any, List, Optional, Union
23

34
import torch
45
import torch.nn as nn
56
from torch import Tensor
67
from torchvision.models import shufflenetv2
78

8-
from ..._internally_replaced_utils import load_state_dict_from_url
9+
from ...transforms import ImageClassificationEval, InterpolationMode
10+
from .._api import WeightsEnum, Weights
11+
from .._meta import _IMAGENET_CATEGORIES
12+
from .._utils import handle_legacy_interface, _ovewrite_named_param
13+
from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights, ShuffleNet_V2_X1_0_Weights
914
from .utils import _fuse_modules, _replace_relu, quantize_model
1015

16+
1117
__all__ = [
1218
"QuantizableShuffleNetV2",
19+
"ShuffleNet_V2_X0_5_QuantizedWeights",
20+
"ShuffleNet_V2_X1_0_QuantizedWeights",
1321
"shufflenet_v2_x0_5",
1422
"shufflenet_v2_x1_0",
1523
]
1624

1725

18-
model_urls = {
19-
"shufflenetv2_x0.5": "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth",
20-
"shufflenetv2_x1.0": "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth",
21-
}
22-
23-
24-
quant_model_urls = {
25-
"shufflenetv2_x0.5_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
26-
"shufflenetv2_x1.0_fbgemm": "https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
27-
}
28-
29-
3026
class QuantizableInvertedResidual(shufflenetv2.InvertedResidual):
3127
def __init__(self, *args: Any, **kwargs: Any) -> None:
3228
super().__init__(*args, **kwargs)
@@ -80,39 +76,86 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None:
8076

8177

8278
def _shufflenetv2(
83-
arch: str,
84-
pretrained: bool,
79+
stages_repeats: List[int],
80+
stages_out_channels: List[int],
81+
*,
82+
weights: Optional[WeightsEnum],
8583
progress: bool,
8684
quantize: bool,
87-
*args: Any,
8885
**kwargs: Any,
8986
) -> QuantizableShuffleNetV2:
87+
if weights is not None:
88+
_ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"]))
89+
if "backend" in weights.meta:
90+
_ovewrite_named_param(kwargs, "backend", weights.meta["backend"])
91+
backend = kwargs.pop("backend", "fbgemm")
9092

91-
model = QuantizableShuffleNetV2(*args, **kwargs)
93+
model = QuantizableShuffleNetV2(stages_repeats, stages_out_channels, **kwargs)
9294
_replace_relu(model)
93-
9495
if quantize:
95-
# TODO use pretrained as a string to specify the backend
96-
backend = "fbgemm"
9796
quantize_model(model, backend)
98-
else:
99-
assert pretrained in [True, False]
100-
101-
if pretrained:
102-
model_url: Optional[str] = None
103-
if quantize:
104-
model_url = quant_model_urls[arch + "_" + backend]
105-
else:
106-
model_url = model_urls[arch]
10797

108-
state_dict = load_state_dict_from_url(model_url, progress=progress)
98+
if weights is not None:
99+
model.load_state_dict(weights.get_state_dict(progress=progress))
109100

110-
model.load_state_dict(state_dict)
111101
return model
112102

113103

104+
_COMMON_META = {
105+
"task": "image_classification",
106+
"architecture": "ShuffleNetV2",
107+
"publication_year": 2018,
108+
"size": (224, 224),
109+
"min_size": (1, 1),
110+
"categories": _IMAGENET_CATEGORIES,
111+
"interpolation": InterpolationMode.BILINEAR,
112+
"backend": "fbgemm",
113+
"quantization": "ptq",
114+
"recipe": "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models",
115+
}
116+
117+
118+
class ShuffleNet_V2_X0_5_QuantizedWeights(WeightsEnum):
119+
IMAGENET1K_FBGEMM_V1 = Weights(
120+
url="https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth",
121+
transforms=partial(ImageClassificationEval, crop_size=224),
122+
meta={
123+
**_COMMON_META,
124+
"num_params": 1366792,
125+
"unquantized": ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1,
126+
"acc@1": 57.972,
127+
"acc@5": 79.780,
128+
},
129+
)
130+
DEFAULT = IMAGENET1K_FBGEMM_V1
131+
132+
133+
class ShuffleNet_V2_X1_0_QuantizedWeights(WeightsEnum):
134+
IMAGENET1K_FBGEMM_V1 = Weights(
135+
url="https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth",
136+
transforms=partial(ImageClassificationEval, crop_size=224),
137+
meta={
138+
**_COMMON_META,
139+
"num_params": 2278604,
140+
"unquantized": ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1,
141+
"acc@1": 68.360,
142+
"acc@5": 87.582,
143+
},
144+
)
145+
DEFAULT = IMAGENET1K_FBGEMM_V1
146+
147+
148+
@handle_legacy_interface(
149+
weights=(
150+
"pretrained",
151+
lambda kwargs: ShuffleNet_V2_X0_5_QuantizedWeights.IMAGENET1K_FBGEMM_V1
152+
if kwargs.get("quantize", False)
153+
else ShuffleNet_V2_X0_5_Weights.IMAGENET1K_V1,
154+
)
155+
)
114156
def shufflenet_v2_x0_5(
115-
pretrained: bool = False,
157+
*,
158+
weights: Optional[Union[ShuffleNet_V2_X0_5_QuantizedWeights, ShuffleNet_V2_X0_5_Weights]] = None,
116159
progress: bool = True,
117160
quantize: bool = False,
118161
**kwargs: Any,
@@ -123,17 +166,28 @@ def shufflenet_v2_x0_5(
123166
<https://arxiv.org/abs/1807.11164>`_.
124167
125168
Args:
126-
pretrained (bool): If True, returns a model pre-trained on ImageNet
169+
weights (ShuffleNet_V2_X0_5_QuantizedWeights or ShuffleNet_V2_X0_5_Weights, optional): The pretrained
170+
weights for the model
127171
progress (bool): If True, displays a progress bar of the download to stderr
128172
quantize (bool): If True, return a quantized version of the model
129173
"""
174+
weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights).verify(weights)
130175
return _shufflenetv2(
131-
"shufflenetv2_x0.5", pretrained, progress, quantize, [4, 8, 4], [24, 48, 96, 192, 1024], **kwargs
176+
[4, 8, 4], [24, 48, 96, 192, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
132177
)
133178

134179

180+
@handle_legacy_interface(
181+
weights=(
182+
"pretrained",
183+
lambda kwargs: ShuffleNet_V2_X1_0_QuantizedWeights.IMAGENET1K_FBGEMM_V1
184+
if kwargs.get("quantize", False)
185+
else ShuffleNet_V2_X1_0_Weights.IMAGENET1K_V1,
186+
)
187+
)
135188
def shufflenet_v2_x1_0(
136-
pretrained: bool = False,
189+
*,
190+
weights: Optional[Union[ShuffleNet_V2_X1_0_QuantizedWeights, ShuffleNet_V2_X1_0_Weights]] = None,
137191
progress: bool = True,
138192
quantize: bool = False,
139193
**kwargs: Any,
@@ -144,10 +198,12 @@ def shufflenet_v2_x1_0(
144198
<https://arxiv.org/abs/1807.11164>`_.
145199
146200
Args:
147-
pretrained (bool): If True, returns a model pre-trained on ImageNet
201+
weights (ShuffleNet_V2_X1_0_QuantizedWeights or ShuffleNet_V2_X1_0_Weights, optional): The pretrained
202+
weights for the model
148203
progress (bool): If True, displays a progress bar of the download to stderr
149204
quantize (bool): If True, return a quantized version of the model
150205
"""
206+
weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights).verify(weights)
151207
return _shufflenetv2(
152-
"shufflenetv2_x1.0", pretrained, progress, quantize, [4, 8, 4], [24, 116, 232, 464, 1024], **kwargs
208+
[4, 8, 4], [24, 116, 232, 464, 1024], weights=weights, progress=progress, quantize=quantize, **kwargs
153209
)

torchvision/prototype/models/quantization/shufflenetv2.py

Lines changed: 0 additions & 134 deletions
This file was deleted.

0 commit comments

Comments
 (0)