1- from typing import Any , Optional
1+ from functools import partial
2+ from typing import Any , List , Optional , Union
23
34import torch
45import torch .nn as nn
56from torch import Tensor
67from torchvision .models import shufflenetv2
78
8- from ..._internally_replaced_utils import load_state_dict_from_url
9+ from ...transforms import ImageClassificationEval , InterpolationMode
10+ from .._api import WeightsEnum , Weights
11+ from .._meta import _IMAGENET_CATEGORIES
12+ from .._utils import handle_legacy_interface , _ovewrite_named_param
13+ from ..shufflenetv2 import ShuffleNet_V2_X0_5_Weights , ShuffleNet_V2_X1_0_Weights
914from .utils import _fuse_modules , _replace_relu , quantize_model
1015
16+
1117__all__ = [
1218 "QuantizableShuffleNetV2" ,
19+ "ShuffleNet_V2_X0_5_QuantizedWeights" ,
20+ "ShuffleNet_V2_X1_0_QuantizedWeights" ,
1321 "shufflenet_v2_x0_5" ,
1422 "shufflenet_v2_x1_0" ,
1523]
1624
1725
18- model_urls = {
19- "shufflenetv2_x0.5" : "https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth" ,
20- "shufflenetv2_x1.0" : "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth" ,
21- }
22-
23-
24- quant_model_urls = {
25- "shufflenetv2_x0.5_fbgemm" : "https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth" ,
26- "shufflenetv2_x1.0_fbgemm" : "https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth" ,
27- }
28-
29-
3026class QuantizableInvertedResidual (shufflenetv2 .InvertedResidual ):
3127 def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
3228 super ().__init__ (* args , ** kwargs )
@@ -80,39 +76,86 @@ def fuse_model(self, is_qat: Optional[bool] = None) -> None:
8076
8177
8278def _shufflenetv2 (
83- arch : str ,
84- pretrained : bool ,
79+ stages_repeats : List [int ],
80+ stages_out_channels : List [int ],
81+ * ,
82+ weights : Optional [WeightsEnum ],
8583 progress : bool ,
8684 quantize : bool ,
87- * args : Any ,
8885 ** kwargs : Any ,
8986) -> QuantizableShuffleNetV2 :
87+ if weights is not None :
88+ _ovewrite_named_param (kwargs , "num_classes" , len (weights .meta ["categories" ]))
89+ if "backend" in weights .meta :
90+ _ovewrite_named_param (kwargs , "backend" , weights .meta ["backend" ])
91+ backend = kwargs .pop ("backend" , "fbgemm" )
9092
91- model = QuantizableShuffleNetV2 (* args , ** kwargs )
93+ model = QuantizableShuffleNetV2 (stages_repeats , stages_out_channels , ** kwargs )
9294 _replace_relu (model )
93-
9495 if quantize :
95- # TODO use pretrained as a string to specify the backend
96- backend = "fbgemm"
9796 quantize_model (model , backend )
98- else :
99- assert pretrained in [True , False ]
100-
101- if pretrained :
102- model_url : Optional [str ] = None
103- if quantize :
104- model_url = quant_model_urls [arch + "_" + backend ]
105- else :
106- model_url = model_urls [arch ]
10797
108- state_dict = load_state_dict_from_url (model_url , progress = progress )
98+ if weights is not None :
99+ model .load_state_dict (weights .get_state_dict (progress = progress ))
109100
110- model .load_state_dict (state_dict )
111101 return model
112102
113103
104+ _COMMON_META = {
105+ "task" : "image_classification" ,
106+ "architecture" : "ShuffleNetV2" ,
107+ "publication_year" : 2018 ,
108+ "size" : (224 , 224 ),
109+ "min_size" : (1 , 1 ),
110+ "categories" : _IMAGENET_CATEGORIES ,
111+ "interpolation" : InterpolationMode .BILINEAR ,
112+ "backend" : "fbgemm" ,
113+ "quantization" : "ptq" ,
114+ "recipe" : "https://github.com/pytorch/vision/tree/main/references/classification#post-training-quantized-models" ,
115+ }
116+
117+
118+ class ShuffleNet_V2_X0_5_QuantizedWeights (WeightsEnum ):
119+ IMAGENET1K_FBGEMM_V1 = Weights (
120+ url = "https://download.pytorch.org/models/quantized/shufflenetv2_x0.5_fbgemm-00845098.pth" ,
121+ transforms = partial (ImageClassificationEval , crop_size = 224 ),
122+ meta = {
123+ ** _COMMON_META ,
124+ "num_params" : 1366792 ,
125+ "unquantized" : ShuffleNet_V2_X0_5_Weights .IMAGENET1K_V1 ,
126+ "acc@1" : 57.972 ,
127+ "acc@5" : 79.780 ,
128+ },
129+ )
130+ DEFAULT = IMAGENET1K_FBGEMM_V1
131+
132+
133+ class ShuffleNet_V2_X1_0_QuantizedWeights (WeightsEnum ):
134+ IMAGENET1K_FBGEMM_V1 = Weights (
135+ url = "https://download.pytorch.org/models/quantized/shufflenetv2_x1_fbgemm-db332c57.pth" ,
136+ transforms = partial (ImageClassificationEval , crop_size = 224 ),
137+ meta = {
138+ ** _COMMON_META ,
139+ "num_params" : 2278604 ,
140+ "unquantized" : ShuffleNet_V2_X1_0_Weights .IMAGENET1K_V1 ,
141+ "acc@1" : 68.360 ,
142+ "acc@5" : 87.582 ,
143+ },
144+ )
145+ DEFAULT = IMAGENET1K_FBGEMM_V1
146+
147+
148+ @handle_legacy_interface (
149+ weights = (
150+ "pretrained" ,
151+ lambda kwargs : ShuffleNet_V2_X0_5_QuantizedWeights .IMAGENET1K_FBGEMM_V1
152+ if kwargs .get ("quantize" , False )
153+ else ShuffleNet_V2_X0_5_Weights .IMAGENET1K_V1 ,
154+ )
155+ )
114156def shufflenet_v2_x0_5 (
115- pretrained : bool = False ,
157+ * ,
158+ weights : Optional [Union [ShuffleNet_V2_X0_5_QuantizedWeights , ShuffleNet_V2_X0_5_Weights ]] = None ,
116159 progress : bool = True ,
117160 quantize : bool = False ,
118161 ** kwargs : Any ,
@@ -123,17 +166,28 @@ def shufflenet_v2_x0_5(
123166 <https://arxiv.org/abs/1807.11164>`_.
124167
125168 Args:
126- pretrained (bool): If True, returns a model pre-trained on ImageNet
169+ weights (ShuffleNet_V2_X0_5_QuantizedWeights or ShuffleNet_V2_X0_5_Weights, optional): The pretrained
170+ weights for the model
127171 progress (bool): If True, displays a progress bar of the download to stderr
128172 quantize (bool): If True, return a quantized version of the model
129173 """
174+ weights = (ShuffleNet_V2_X0_5_QuantizedWeights if quantize else ShuffleNet_V2_X0_5_Weights ).verify (weights )
130175 return _shufflenetv2 (
131- "shufflenetv2_x0.5" , pretrained , progress , quantize , [4 , 8 , 4 ], [24 , 48 , 96 , 192 , 1024 ], ** kwargs
176+ [4 , 8 , 4 ], [24 , 48 , 96 , 192 , 1024 ], weights = weights , progress = progress , quantize = quantize , ** kwargs
132177 )
133178
134179
180+ @handle_legacy_interface (
181+ weights = (
182+ "pretrained" ,
183+ lambda kwargs : ShuffleNet_V2_X1_0_QuantizedWeights .IMAGENET1K_FBGEMM_V1
184+ if kwargs .get ("quantize" , False )
185+ else ShuffleNet_V2_X1_0_Weights .IMAGENET1K_V1 ,
186+ )
187+ )
135188def shufflenet_v2_x1_0 (
136- pretrained : bool = False ,
189+ * ,
190+ weights : Optional [Union [ShuffleNet_V2_X1_0_QuantizedWeights , ShuffleNet_V2_X1_0_Weights ]] = None ,
137191 progress : bool = True ,
138192 quantize : bool = False ,
139193 ** kwargs : Any ,
@@ -144,10 +198,12 @@ def shufflenet_v2_x1_0(
144198 <https://arxiv.org/abs/1807.11164>`_.
145199
146200 Args:
147- pretrained (bool): If True, returns a model pre-trained on ImageNet
201+ weights (ShuffleNet_V2_X1_0_QuantizedWeights or ShuffleNet_V2_X1_0_Weights, optional): The pretrained
202+ weights for the model
148203 progress (bool): If True, displays a progress bar of the download to stderr
149204 quantize (bool): If True, return a quantized version of the model
150205 """
206+ weights = (ShuffleNet_V2_X1_0_QuantizedWeights if quantize else ShuffleNet_V2_X1_0_Weights ).verify (weights )
151207 return _shufflenetv2 (
152- "shufflenetv2_x1.0" , pretrained , progress , quantize , [4 , 8 , 4 ], [24 , 116 , 232 , 464 , 1024 ], ** kwargs
208+ [4 , 8 , 4 ], [24 , 116 , 232 , 464 , 1024 ], weights = weights , progress = progress , quantize = quantize , ** kwargs
153209 )
0 commit comments