Skip to content

Commit 469eadd

Browse files
committed
Fix test and linter
1 parent 829c5c1 commit 469eadd

File tree

3 files changed

+21
-14
lines changed

3 files changed

+21
-14
lines changed

test/test_prototype_models.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ def _build_model(fn, **kwargs):
5959
("ResNet50_Weights.DEFAULT", torchvision.models.ResNet50_Weights.IMAGENET1K_V2),
6060
(
6161
"ResNet50_QuantizedWeights.DEFAULT",
62-
models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
62+
torchvision.models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
6363
),
6464
(
6565
"ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1",
66-
models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
66+
torchvision.models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
6767
),
6868
],
6969
)
@@ -73,9 +73,9 @@ def test_get_weight(name, weight):
7373

7474
@pytest.mark.parametrize(
7575
"model_fn",
76-
TM.get_models_from_module(models)
76+
TM.get_models_from_module(torchvision.models)
7777
+ TM.get_models_from_module(models.detection)
78-
+ TM.get_models_from_module(models.quantization)
78+
+ TM.get_models_from_module(torchvision.models.quantization)
7979
+ TM.get_models_from_module(models.segmentation)
8080
+ TM.get_models_from_module(models.video)
8181
+ TM.get_models_from_module(models.optical_flow),
@@ -91,7 +91,7 @@ def test_naming_conventions(model_fn):
9191
"model_fn",
9292
TM.get_models_from_module(torchvision.models)
9393
+ TM.get_models_from_module(models.detection)
94-
+ TM.get_models_from_module(models.quantization)
94+
+ TM.get_models_from_module(torchvision.models.quantization)
9595
+ TM.get_models_from_module(models.segmentation)
9696
+ TM.get_models_from_module(models.video)
9797
+ TM.get_models_from_module(models.optical_flow),
@@ -150,12 +150,6 @@ def test_detection_model(model_fn, dev):
150150
TM.test_detection_model(model_fn, dev)
151151

152152

153-
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization))
154-
@run_if_test_with_prototype
155-
def test_quantized_classification_model(model_fn):
156-
TM.test_quantized_classification_model(model_fn)
157-
158-
159153
@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.segmentation))
160154
@pytest.mark.parametrize("dev", cpu_and_gpu())
161155
@run_if_test_with_prototype
@@ -181,7 +175,6 @@ def test_raft(model_builder, scripted):
181175
@pytest.mark.parametrize(
182176
"model_fn",
183177
TM.get_models_from_module(models.detection)
184-
+ TM.get_models_from_module(models.quantization)
185178
+ TM.get_models_from_module(models.segmentation)
186179
+ TM.get_models_from_module(models.video)
187180
+ TM.get_models_from_module(models.optical_flow),

torchvision/models/quantization/mobilenetv3.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@
1010
from .._api import WeightsEnum, Weights
1111
from .._meta import _IMAGENET_CATEGORIES
1212
from .._utils import handle_legacy_interface, _ovewrite_named_param
13-
from ..mobilenetv3 import InvertedResidual, InvertedResidualConfig, MobileNetV3, _mobilenet_v3_conf, MobileNet_V3_Large_Weights
13+
from ..mobilenetv3 import (
14+
InvertedResidual,
15+
InvertedResidualConfig,
16+
MobileNetV3,
17+
_mobilenet_v3_conf,
18+
MobileNet_V3_Large_Weights,
19+
)
1420
from .utils import _fuse_modules, _replace_relu
1521

1622

torchvision/models/quantization/resnet.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,14 @@
44
import torch
55
import torch.nn as nn
66
from torch import Tensor
7-
from torchvision.models.resnet import Bottleneck, BasicBlock, ResNet, ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights
7+
from torchvision.models.resnet import (
8+
Bottleneck,
9+
BasicBlock,
10+
ResNet,
11+
ResNet18_Weights,
12+
ResNet50_Weights,
13+
ResNeXt101_32X8D_Weights,
14+
)
815

916
from ...transforms import ImageClassificationEval, InterpolationMode
1017
from .._api import WeightsEnum, Weights
@@ -138,6 +145,7 @@ def _resnet(
138145

139146
return model
140147

148+
141149
_COMMON_META = {
142150
"task": "image_classification",
143151
"size": (224, 224),

0 commit comments

Comments
 (0)