Skip to content

Commit f121ca7

Browse files
authored
Porting model tests (#5622)
* Porting tests * Remove unnecessary variable * Fix linter * Move prototype to extended tests * Fix download models job
1 parent c88b8dc commit f121ca7

File tree

11 files changed

+125
-134
lines changed

11 files changed

+125
-134
lines changed

.circleci/config.yml

Lines changed: 19 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.circleci/config.yml.in

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,20 @@ jobs:
335335
file_or_dir: test/test_onnx.py
336336

337337
unittest_prototype:
338+
docker:
339+
- image: circleci/python:3.7
340+
resource_class: xlarge
341+
steps:
342+
- checkout
343+
- install_torchvision
344+
- install_prototype_dependencies
345+
- pip_install:
346+
args: scipy pycocotools h5py
347+
descr: Install optional dependencies
348+
- run_tests_selective:
349+
file_or_dir: test/test_prototype_*.py
350+
351+
unittest_extended:
338352
docker:
339353
- image: circleci/python:3.7
340354
resource_class: xlarge
@@ -346,18 +360,14 @@ jobs:
346360
command: |
347361
sudo apt update -qy && sudo apt install -qy parallel wget
348362
mkdir -p ~/.cache/torch/hub/checkpoints
349-
python scripts/collect_model_urls.py torchvision/prototype/models \
363+
python scripts/collect_model_urls.py torchvision/models \
350364
| parallel -j0 'wget --no-verbose -O ~/.cache/torch/hub/checkpoints/`basename {}` {}\?source=ci'
351365
- install_torchvision
352-
- install_prototype_dependencies
353-
- pip_install:
354-
args: scipy pycocotools h5py
355-
descr: Install optional dependencies
356366
- run:
357-
name: Enable prototype tests
358-
command: echo 'export PYTORCH_TEST_WITH_PROTOTYPE=1' >> $BASH_ENV
367+
name: Enable extended tests
368+
command: echo 'export PYTORCH_TEST_WITH_EXTENDED=1' >> $BASH_ENV
359369
- run_tests_selective:
360-
file_or_dir: test/test_prototype_*.py
370+
file_or_dir: test/test_extended_*.py
361371

362372
binary_linux_wheel:
363373
<<: *binary_common
@@ -1094,6 +1104,7 @@ workflows:
10941104
- unittest_torchhub
10951105
- unittest_onnx
10961106
- unittest_prototype
1107+
- unittest_extended
10971108
{{ unittest_workflows() }}
10981109

10991110
cmake:

test/test_backbone_utils.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,30 +23,30 @@ def get_available_models():
2323
@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
2424
def test_resnet_fpn_backbone(backbone_name):
2525
x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
26-
model = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)
26+
model = resnet_fpn_backbone(backbone_name=backbone_name)
2727
assert isinstance(model, BackboneWithFPN)
2828
y = model(x)
2929
assert list(y.keys()) == ["0", "1", "2", "3", "pool"]
3030

3131
with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
32-
resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False, trainable_layers=6)
32+
resnet_fpn_backbone(backbone_name=backbone_name, trainable_layers=6)
3333
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
34-
resnet_fpn_backbone(backbone_name, False, returned_layers=[0, 1, 2, 3])
34+
resnet_fpn_backbone(backbone_name=backbone_name, returned_layers=[0, 1, 2, 3])
3535
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
36-
resnet_fpn_backbone(backbone_name, False, returned_layers=[2, 3, 4, 5])
36+
resnet_fpn_backbone(backbone_name=backbone_name, returned_layers=[2, 3, 4, 5])
3737

3838

3939
@pytest.mark.parametrize("backbone_name", ("mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"))
4040
def test_mobilenet_backbone(backbone_name):
4141
with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
42-
mobilenet_backbone(backbone_name=backbone_name, pretrained=False, fpn=False, trainable_layers=-1)
42+
mobilenet_backbone(backbone_name=backbone_name, fpn=False, trainable_layers=-1)
4343
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
44-
mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[-1, 0, 1, 2])
44+
mobilenet_backbone(backbone_name=backbone_name, fpn=True, returned_layers=[-1, 0, 1, 2])
4545
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
46-
mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[3, 4, 5, 6])
47-
model_fpn = mobilenet_backbone(backbone_name, False, fpn=True)
46+
mobilenet_backbone(backbone_name=backbone_name, fpn=True, returned_layers=[3, 4, 5, 6])
47+
model_fpn = mobilenet_backbone(backbone_name=backbone_name, fpn=True)
4848
assert isinstance(model_fpn, BackboneWithFPN)
49-
model = mobilenet_backbone(backbone_name, False, fpn=False)
49+
model = mobilenet_backbone(backbone_name=backbone_name, fpn=False)
5050
assert isinstance(model, torch.nn.Sequential)
5151

5252

@@ -100,7 +100,7 @@ def forward(self, x):
100100

101101
class TestFxFeatureExtraction:
102102
inp = torch.rand(1, 3, 224, 224, dtype=torch.float32, device="cpu")
103-
model_defaults = {"num_classes": 1, "pretrained": False}
103+
model_defaults = {"num_classes": 1}
104104
leaf_modules = []
105105

106106
def _create_feature_extractor(self, *args, **kwargs):

test/test_cpp_models.py

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -53,50 +53,49 @@ def read_image2():
5353
"see https://github.com/pytorch/vision/issues/1191",
5454
)
5555
class Tester(unittest.TestCase):
56-
pretrained = False
5756
image = read_image1()
5857

5958
def test_alexnet(self):
60-
process_model(models.alexnet(self.pretrained), self.image, _C_tests.forward_alexnet, "Alexnet")
59+
process_model(models.alexnet(), self.image, _C_tests.forward_alexnet, "Alexnet")
6160

6261
def test_vgg11(self):
63-
process_model(models.vgg11(self.pretrained), self.image, _C_tests.forward_vgg11, "VGG11")
62+
process_model(models.vgg11(), self.image, _C_tests.forward_vgg11, "VGG11")
6463

6564
def test_vgg13(self):
66-
process_model(models.vgg13(self.pretrained), self.image, _C_tests.forward_vgg13, "VGG13")
65+
process_model(models.vgg13(), self.image, _C_tests.forward_vgg13, "VGG13")
6766

6867
def test_vgg16(self):
69-
process_model(models.vgg16(self.pretrained), self.image, _C_tests.forward_vgg16, "VGG16")
68+
process_model(models.vgg16(), self.image, _C_tests.forward_vgg16, "VGG16")
7069

7170
def test_vgg19(self):
72-
process_model(models.vgg19(self.pretrained), self.image, _C_tests.forward_vgg19, "VGG19")
71+
process_model(models.vgg19(), self.image, _C_tests.forward_vgg19, "VGG19")
7372

7473
def test_vgg11_bn(self):
75-
process_model(models.vgg11_bn(self.pretrained), self.image, _C_tests.forward_vgg11bn, "VGG11BN")
74+
process_model(models.vgg11_bn(), self.image, _C_tests.forward_vgg11bn, "VGG11BN")
7675

7776
def test_vgg13_bn(self):
78-
process_model(models.vgg13_bn(self.pretrained), self.image, _C_tests.forward_vgg13bn, "VGG13BN")
77+
process_model(models.vgg13_bn(), self.image, _C_tests.forward_vgg13bn, "VGG13BN")
7978

8079
def test_vgg16_bn(self):
81-
process_model(models.vgg16_bn(self.pretrained), self.image, _C_tests.forward_vgg16bn, "VGG16BN")
80+
process_model(models.vgg16_bn(), self.image, _C_tests.forward_vgg16bn, "VGG16BN")
8281

8382
def test_vgg19_bn(self):
84-
process_model(models.vgg19_bn(self.pretrained), self.image, _C_tests.forward_vgg19bn, "VGG19BN")
83+
process_model(models.vgg19_bn(), self.image, _C_tests.forward_vgg19bn, "VGG19BN")
8584

8685
def test_resnet18(self):
87-
process_model(models.resnet18(self.pretrained), self.image, _C_tests.forward_resnet18, "Resnet18")
86+
process_model(models.resnet18(), self.image, _C_tests.forward_resnet18, "Resnet18")
8887

8988
def test_resnet34(self):
90-
process_model(models.resnet34(self.pretrained), self.image, _C_tests.forward_resnet34, "Resnet34")
89+
process_model(models.resnet34(), self.image, _C_tests.forward_resnet34, "Resnet34")
9190

9291
def test_resnet50(self):
93-
process_model(models.resnet50(self.pretrained), self.image, _C_tests.forward_resnet50, "Resnet50")
92+
process_model(models.resnet50(), self.image, _C_tests.forward_resnet50, "Resnet50")
9493

9594
def test_resnet101(self):
96-
process_model(models.resnet101(self.pretrained), self.image, _C_tests.forward_resnet101, "Resnet101")
95+
process_model(models.resnet101(), self.image, _C_tests.forward_resnet101, "Resnet101")
9796

9897
def test_resnet152(self):
99-
process_model(models.resnet152(self.pretrained), self.image, _C_tests.forward_resnet152, "Resnet152")
98+
process_model(models.resnet152(), self.image, _C_tests.forward_resnet152, "Resnet152")
10099

101100
def test_resnext50_32x4d(self):
102101
process_model(models.resnext50_32x4d(), self.image, _C_tests.forward_resnext50_32x4d, "ResNext50_32x4d")
@@ -111,48 +110,44 @@ def test_wide_resnet101_2(self):
111110
process_model(models.wide_resnet101_2(), self.image, _C_tests.forward_wide_resnet101_2, "WideResNet101_2")
112111

113112
def test_squeezenet1_0(self):
114-
process_model(
115-
models.squeezenet1_0(self.pretrained), self.image, _C_tests.forward_squeezenet1_0, "Squeezenet1.0"
116-
)
113+
process_model(models.squeezenet1_0(), self.image, _C_tests.forward_squeezenet1_0, "Squeezenet1.0")
117114

118115
def test_squeezenet1_1(self):
119-
process_model(
120-
models.squeezenet1_1(self.pretrained), self.image, _C_tests.forward_squeezenet1_1, "Squeezenet1.1"
121-
)
116+
process_model(models.squeezenet1_1(), self.image, _C_tests.forward_squeezenet1_1, "Squeezenet1.1")
122117

123118
def test_densenet121(self):
124-
process_model(models.densenet121(self.pretrained), self.image, _C_tests.forward_densenet121, "Densenet121")
119+
process_model(models.densenet121(), self.image, _C_tests.forward_densenet121, "Densenet121")
125120

126121
def test_densenet169(self):
127-
process_model(models.densenet169(self.pretrained), self.image, _C_tests.forward_densenet169, "Densenet169")
122+
process_model(models.densenet169(), self.image, _C_tests.forward_densenet169, "Densenet169")
128123

129124
def test_densenet201(self):
130-
process_model(models.densenet201(self.pretrained), self.image, _C_tests.forward_densenet201, "Densenet201")
125+
process_model(models.densenet201(), self.image, _C_tests.forward_densenet201, "Densenet201")
131126

132127
def test_densenet161(self):
133-
process_model(models.densenet161(self.pretrained), self.image, _C_tests.forward_densenet161, "Densenet161")
128+
process_model(models.densenet161(), self.image, _C_tests.forward_densenet161, "Densenet161")
134129

135130
def test_mobilenet_v2(self):
136-
process_model(models.mobilenet_v2(self.pretrained), self.image, _C_tests.forward_mobilenetv2, "MobileNet")
131+
process_model(models.mobilenet_v2(), self.image, _C_tests.forward_mobilenetv2, "MobileNet")
137132

138133
def test_googlenet(self):
139-
process_model(models.googlenet(self.pretrained), self.image, _C_tests.forward_googlenet, "GoogLeNet")
134+
process_model(models.googlenet(), self.image, _C_tests.forward_googlenet, "GoogLeNet")
140135

141136
def test_mnasnet0_5(self):
142-
process_model(models.mnasnet0_5(self.pretrained), self.image, _C_tests.forward_mnasnet0_5, "MNASNet0_5")
137+
process_model(models.mnasnet0_5(), self.image, _C_tests.forward_mnasnet0_5, "MNASNet0_5")
143138

144139
def test_mnasnet0_75(self):
145-
process_model(models.mnasnet0_75(self.pretrained), self.image, _C_tests.forward_mnasnet0_75, "MNASNet0_75")
140+
process_model(models.mnasnet0_75(), self.image, _C_tests.forward_mnasnet0_75, "MNASNet0_75")
146141

147142
def test_mnasnet1_0(self):
148-
process_model(models.mnasnet1_0(self.pretrained), self.image, _C_tests.forward_mnasnet1_0, "MNASNet1_0")
143+
process_model(models.mnasnet1_0(), self.image, _C_tests.forward_mnasnet1_0, "MNASNet1_0")
149144

150145
def test_mnasnet1_3(self):
151-
process_model(models.mnasnet1_3(self.pretrained), self.image, _C_tests.forward_mnasnet1_3, "MNASNet1_3")
146+
process_model(models.mnasnet1_3(), self.image, _C_tests.forward_mnasnet1_3, "MNASNet1_3")
152147

153148
def test_inception_v3(self):
154149
self.image = read_image2()
155-
process_model(models.inception_v3(self.pretrained), self.image, _C_tests.forward_inceptionv3, "Inceptionv3")
150+
process_model(models.inception_v3(), self.image, _C_tests.forward_inceptionv3, "Inceptionv3")
156151

157152

158153
if __name__ == "__main__":

test/test_prototype_models.py renamed to test/test_extended_models.py

Lines changed: 20 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,16 @@
33

44
import pytest
55
import test_models as TM
6-
import torchvision
6+
from torchvision import models
77
from torchvision.models._api import WeightsEnum, Weights
88
from torchvision.models._utils import handle_legacy_interface
99

1010
run_if_test_with_prototype = pytest.mark.skipif(
11-
os.getenv("PYTORCH_TEST_WITH_PROTOTYPE") != "1",
12-
reason="Prototype tests are disabled by default. Set PYTORCH_TEST_WITH_PROTOTYPE=1 to run them.",
11+
os.getenv("PYTORCH_TEST_WITH_EXTENDED") != "1",
12+
reason="Extended tests are disabled by default. Set PYTORCH_TEST_WITH_EXTENDED=1 to run them.",
1313
)
1414

1515

16-
def _get_original_model(model_fn):
17-
original_module_name = model_fn.__module__.replace(".prototype", "")
18-
module = importlib.import_module(original_module_name)
19-
return module.__dict__[model_fn.__name__]
20-
21-
2216
def _get_parent_module(model_fn):
2317
parent_module_name = ".".join(model_fn.__module__.split(".")[:-1])
2418
module = importlib.import_module(parent_module_name)
@@ -38,44 +32,33 @@ def _get_model_weights(model_fn):
3832
return None
3933

4034

41-
def _build_model(fn, **kwargs):
42-
try:
43-
model = fn(**kwargs)
44-
except ValueError as e:
45-
msg = str(e)
46-
if "No checkpoint is available" in msg:
47-
pytest.skip(msg)
48-
raise e
49-
return model.eval()
50-
51-
5235
@pytest.mark.parametrize(
5336
"name, weight",
5437
[
55-
("ResNet50_Weights.IMAGENET1K_V1", torchvision.models.ResNet50_Weights.IMAGENET1K_V1),
56-
("ResNet50_Weights.DEFAULT", torchvision.models.ResNet50_Weights.IMAGENET1K_V2),
38+
("ResNet50_Weights.IMAGENET1K_V1", models.ResNet50_Weights.IMAGENET1K_V1),
39+
("ResNet50_Weights.DEFAULT", models.ResNet50_Weights.IMAGENET1K_V2),
5740
(
5841
"ResNet50_QuantizedWeights.DEFAULT",
59-
torchvision.models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
42+
models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V2,
6043
),
6144
(
6245
"ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1",
63-
torchvision.models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
46+
models.quantization.ResNet50_QuantizedWeights.IMAGENET1K_FBGEMM_V1,
6447
),
6548
],
6649
)
6750
def test_get_weight(name, weight):
68-
assert torchvision.models.get_weight(name) == weight
51+
assert models.get_weight(name) == weight
6952

7053

7154
@pytest.mark.parametrize(
7255
"model_fn",
73-
TM.get_models_from_module(torchvision.models)
74-
+ TM.get_models_from_module(torchvision.models.detection)
75-
+ TM.get_models_from_module(torchvision.models.quantization)
76-
+ TM.get_models_from_module(torchvision.models.segmentation)
77-
+ TM.get_models_from_module(torchvision.models.video)
78-
+ TM.get_models_from_module(torchvision.models.optical_flow),
56+
TM.get_models_from_module(models)
57+
+ TM.get_models_from_module(models.detection)
58+
+ TM.get_models_from_module(models.quantization)
59+
+ TM.get_models_from_module(models.segmentation)
60+
+ TM.get_models_from_module(models.video)
61+
+ TM.get_models_from_module(models.optical_flow),
7962
)
8063
def test_naming_conventions(model_fn):
8164
weights_enum = _get_model_weights(model_fn)
@@ -86,12 +69,12 @@ def test_naming_conventions(model_fn):
8669

8770
@pytest.mark.parametrize(
8871
"model_fn",
89-
TM.get_models_from_module(torchvision.models)
90-
+ TM.get_models_from_module(torchvision.models.detection)
91-
+ TM.get_models_from_module(torchvision.models.quantization)
92-
+ TM.get_models_from_module(torchvision.models.segmentation)
93-
+ TM.get_models_from_module(torchvision.models.video)
94-
+ TM.get_models_from_module(torchvision.models.optical_flow),
72+
TM.get_models_from_module(models)
73+
+ TM.get_models_from_module(models.detection)
74+
+ TM.get_models_from_module(models.quantization)
75+
+ TM.get_models_from_module(models.segmentation)
76+
+ TM.get_models_from_module(models.video)
77+
+ TM.get_models_from_module(models.optical_flow),
9578
)
9679
@run_if_test_with_prototype
9780
def test_schema_meta_validation(model_fn):

test/test_hub.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,13 @@ class TestHub:
2626
# Python cache as we run all hub tests in the same python process.
2727

2828
def test_load_from_github(self):
29-
hub_model = hub.load("pytorch/vision", "resnet18", pretrained=True, progress=False)
29+
hub_model = hub.load("pytorch/vision", "resnet18", weights="DEFAULT", progress=False)
3030
assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS)
3131

3232
def test_set_dir(self):
3333
temp_dir = tempfile.gettempdir()
3434
hub.set_dir(temp_dir)
35-
hub_model = hub.load("pytorch/vision", "resnet18", pretrained=True, progress=False)
35+
hub_model = hub.load("pytorch/vision", "resnet18", weights="DEFAULT", progress=False)
3636
assert sum_of_model_parameters(hub_model).item() == pytest.approx(SUM_OF_PRETRAINED_RESNET18_PARAMS)
3737
assert os.path.exists(temp_dir + "/pytorch_vision_master")
3838
shutil.rmtree(temp_dir + "/pytorch_vision_master")

0 commit comments

Comments
 (0)