Skip to content

Commit 19bb0f7

Browse files
authored
Merge branch 'main' into prototype/ssd_multiweight
2 parents c724c16 + 471f0fb commit 19bb0f7

File tree

7 files changed

+232
-37
lines changed

7 files changed

+232
-37
lines changed

test/builtin_dataset_mocks.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -452,11 +452,7 @@ def caltech256(info, root, config):
452452

453453
@dataset_mocks.register_mock_data_fn
454454
def imagenet(info, root, config):
455-
devkit_root = root / "ILSVRC2012_devkit_t12"
456-
devkit_root.mkdir()
457-
458455
wnids = tuple(info.extra.wnid_to_category.keys())
459-
460456
if config.split == "train":
461457
images_root = root / "ILSVRC2012_img_train"
462458

@@ -470,7 +466,7 @@ def imagenet(info, root, config):
470466
num_examples=1,
471467
)
472468
make_tar(images_root, f"{wnid}.tar", files[0].parent)
473-
else:
469+
elif config.split == "val":
474470
num_samples = 3
475471
files = create_image_folder(
476472
root=root,
@@ -479,14 +475,26 @@ def imagenet(info, root, config):
479475
num_examples=num_samples,
480476
)
481477
images_root = files[0].parent
478+
else: # config.split == "test"
479+
images_root = root / "ILSVRC2012_img_test_v10102019"
482480

483-
data_root = devkit_root / "data"
484-
data_root.mkdir()
485-
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
486-
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
487-
file.write(f"{label}\n")
481+
num_samples = 3
488482

483+
create_image_folder(
484+
root=images_root,
485+
name="test",
486+
file_name_fn=lambda image_idx: f"ILSVRC2012_test_{image_idx + 1:08d}.JPEG",
487+
num_examples=num_samples,
488+
)
489489
make_tar(root, f"{images_root.name}.tar", images_root)
490+
491+
devkit_root = root / "ILSVRC2012_devkit_t12"
492+
devkit_root.mkdir()
493+
data_root = devkit_root / "data"
494+
data_root.mkdir()
495+
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
496+
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
497+
file.write(f"{label}\n")
490498
make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz")
491499

492500
return num_samples

test/test_datasets.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1914,11 +1914,13 @@ def inject_fake_data(self, tmpdir, config):
19141914

19151915
def test_flow(self):
19161916
# Make sure flow exists for train split, and make sure there are as many flow values as (pairs of) images
1917+
h, w = self.FLOW_H, self.FLOW_W
1918+
expected_flow = np.arange(2 * h * w).reshape(h, w, 2).transpose(2, 0, 1)
19171919
with self.create_dataset(split="train") as (dataset, _):
19181920
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
19191921
for _, _, flow in dataset:
1920-
assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
1921-
np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
1922+
assert flow.shape == (2, h, w)
1923+
np.testing.assert_allclose(flow, expected_flow)
19221924

19231925
# Make sure flow is always None for test split
19241926
with self.create_dataset(split="test") as (dataset, _):
@@ -2041,11 +2043,14 @@ def inject_fake_data(self, tmpdir, config):
20412043
def test_flow(self, config):
20422044
# Make sure flow always exists, and make sure there are as many flow values as (pairs of) images
20432045
# Also make sure the flow is properly decoded
2046+
2047+
h, w = self.FLOW_H, self.FLOW_W
2048+
expected_flow = np.arange(2 * h * w).reshape(h, w, 2).transpose(2, 0, 1)
20442049
with self.create_dataset(config=config) as (dataset, _):
20452050
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
20462051
for _, _, flow in dataset:
2047-
assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
2048-
np.testing.assert_allclose(flow, np.arange(flow.size).reshape(flow.shape))
2052+
assert flow.shape == (2, h, w)
2053+
np.testing.assert_allclose(flow, expected_flow)
20492054

20502055

20512056
class FlyingThings3DTestCase(datasets_utils.ImageDatasetTestCase):
@@ -2095,11 +2100,16 @@ def inject_fake_data(self, tmpdir, config):
20952100

20962101
@datasets_utils.test_all_configs
20972102
def test_flow(self, config):
2103+
h, w = self.FLOW_H, self.FLOW_W
2104+
expected_flow = np.arange(3 * h * w).reshape(h, w, 3).transpose(2, 0, 1)
2105+
expected_flow = np.flip(expected_flow, axis=1)
2106+
expected_flow = expected_flow[:2, :, :]
2107+
20982108
with self.create_dataset(config=config) as (dataset, _):
20992109
assert dataset._flow_list and len(dataset._flow_list) == len(dataset._image_list)
21002110
for _, _, flow in dataset:
21012111
assert flow.shape == (2, self.FLOW_H, self.FLOW_W)
2102-
# We don't check the values because the reshaping and flipping makes it hard to figure out
2112+
np.testing.assert_allclose(flow, expected_flow)
21032113

21042114
def test_bad_input(self):
21052115
with pytest.raises(ValueError, match="Unknown value 'bad' for argument split"):

test/test_models.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import contextlib
12
import functools
23
import io
34
import operator
45
import os
6+
import pkgutil
7+
import sys
58
import traceback
69
import warnings
710
from collections import OrderedDict
@@ -14,7 +17,6 @@
1417
from common_utils import map_nested_tensor_object, freeze_rng_state, set_rng_seed, cpu_and_gpu, needs_cuda
1518
from torchvision import models
1619

17-
1820
ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
1921

2022

@@ -23,6 +25,51 @@ def get_models_from_module(module):
2325
return [v for k, v in module.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]
2426

2527

28+
@pytest.fixture
29+
def disable_weight_loading(mocker):
30+
"""When testing models, the two slowest operations are the downloading of the weights to a file and loading them
31+
into the model. Unless, you want to test against specific weights, these steps can be disabled without any
32+
drawbacks.
33+
34+
Including this fixture into the signature of your test, i.e. `test_foo(disable_weight_loading)`, will recurse
35+
through all models in `torchvision.models` and will patch all occurrences of the function
36+
`download_state_dict_from_url` as well as the method `load_state_dict` on all subclasses of `nn.Module` to be
37+
no-ops.
38+
39+
.. warning:
40+
41+
Loaded models are still executable as normal, but will always have random weights. Make sure to not use this
42+
fixture if you want to compare the model output against reference values.
43+
44+
"""
45+
starting_point = models
46+
function_name = "load_state_dict_from_url"
47+
method_name = "load_state_dict"
48+
49+
module_names = {info.name for info in pkgutil.walk_packages(starting_point.__path__, f"{starting_point.__name__}.")}
50+
targets = {f"torchvision._internally_replaced_utils.{function_name}", f"torch.nn.Module.{method_name}"}
51+
for name in module_names:
52+
module = sys.modules.get(name)
53+
if not module:
54+
continue
55+
56+
if function_name in module.__dict__:
57+
targets.add(f"{module.__name__}.{function_name}")
58+
59+
targets.update(
60+
{
61+
f"{module.__name__}.{obj.__name__}.{method_name}"
62+
for obj in module.__dict__.values()
63+
if isinstance(obj, type) and issubclass(obj, nn.Module) and method_name in obj.__dict__
64+
}
65+
)
66+
67+
for target in targets:
68+
# See https://github.com/pytorch/vision/pull/4867#discussion_r743677802 for details
69+
with contextlib.suppress(AttributeError):
70+
mocker.patch(target)
71+
72+
2673
def _get_expected_file(name=None):
2774
# Determine expected file based on environment
2875
expected_file_base = get_relative_path(os.path.realpath(__file__), "expect")
@@ -762,7 +809,7 @@ def test_quantized_classification_model(model_fn):
762809

763810

764811
@pytest.mark.parametrize("model_fn", get_models_from_module(models.detection))
765-
def test_detection_model_trainable_backbone_layers(model_fn):
812+
def test_detection_model_trainable_backbone_layers(model_fn, disable_weight_loading):
766813
model_name = model_fn.__name__
767814
max_trainable = _model_tests_values[model_name]["max_trainable"]
768815
n_trainable_params = []

torchvision/datasets/_optical_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ def _read_flo(file_name):
376376
w = int(np.fromfile(f, "<i4", count=1))
377377
h = int(np.fromfile(f, "<i4", count=1))
378378
data = np.fromfile(f, "<f4", count=2 * w * h)
379-
return data.reshape(2, h, w)
379+
return data.reshape(h, w, 2).transpose(2, 0, 1)
380380

381381

382382
def _read_16bits_png_with_flow_and_valid_mask(file_name):

torchvision/prototype/datasets/_builtin/imagenet.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,17 @@ def _make_info(self) -> DatasetInfo:
3434
type=DatasetType.IMAGE,
3535
categories=categories,
3636
homepage="https://www.image-net.org/",
37-
valid_options=dict(split=("train", "val")),
37+
valid_options=dict(split=("train", "val", "test")),
3838
extra=dict(
3939
wnid_to_category=FrozenMapping(zip(wnids, categories)),
4040
category_to_wnid=FrozenMapping(zip(categories, wnids)),
41-
sizes=FrozenMapping([(DatasetConfig(split="train"), 1281167), (DatasetConfig(split="val"), 50000)]),
41+
sizes=FrozenMapping(
42+
[
43+
(DatasetConfig(split="train"), 1_281_167),
44+
(DatasetConfig(split="val"), 50_000),
45+
(DatasetConfig(split="test"), 100_000),
46+
]
47+
),
4248
),
4349
)
4450

@@ -53,17 +59,15 @@ def category_to_wnid(self) -> Dict[str, str]:
5359
def wnid_to_category(self) -> Dict[str, str]:
5460
return cast(Dict[str, str], self.info.extra.wnid_to_category)
5561

62+
_IMAGES_CHECKSUMS = {
63+
"train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
64+
"val": "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
65+
"test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4",
66+
}
67+
5668
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
57-
if config.split == "train":
58-
images = HttpResource(
59-
"ILSVRC2012_img_train.tar",
60-
sha256="b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
61-
)
62-
else: # config.split == "val"
63-
images = HttpResource(
64-
"ILSVRC2012_img_val.tar",
65-
sha256="c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
66-
)
69+
name = "test_v10102019" if config.split == "test" else config.split
70+
images = HttpResource(f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name])
6771

6872
devkit = HttpResource(
6973
"ILSVRC2012_devkit_t12.tar.gz",
@@ -81,11 +85,11 @@ def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[int, s
8185
label = self.categories.index(category)
8286
return (label, category, wnid), data
8387

84-
_VAL_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_val_(?P<id>\d{8})[.]JPEG")
88+
_VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
8589

86-
def _val_image_key(self, data: Tuple[str, Any]) -> int:
90+
def _val_test_image_key(self, data: Tuple[str, Any]) -> int:
8791
path = pathlib.Path(data[0])
88-
return int(self._VAL_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr]
92+
return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr]
8993

9094
def _collate_val_data(
9195
self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]]
@@ -96,9 +100,12 @@ def _collate_val_data(
96100
wnid = self.category_to_wnid[category]
97101
return (label, category, wnid), image_data
98102

103+
def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[None, None, None], Tuple[str, io.IOBase]]:
104+
return (None, None, None), data
105+
99106
def _collate_and_decode_sample(
100107
self,
101-
data: Tuple[Tuple[int, str, str], Tuple[str, io.IOBase]],
108+
data: Tuple[Tuple[Optional[int], Optional[str], Optional[str]], Tuple[str, io.IOBase]],
102109
*,
103110
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
104111
) -> Dict[str, Any]:
@@ -108,7 +115,7 @@ def _collate_and_decode_sample(
108115
return dict(
109116
path=path,
110117
image=decoder(buffer) if decoder else buffer,
111-
label=torch.tensor(label),
118+
label=label,
112119
category=category,
113120
wnid=wnid,
114121
)
@@ -129,7 +136,7 @@ def _make_datapipe(
129136
dp = TarArchiveReader(images_dp)
130137
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
131138
dp = Mapper(dp, self._collate_train_data)
132-
else:
139+
elif config.split == "val":
133140
devkit_dp = TarArchiveReader(devkit_dp)
134141
devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
135142
devkit_dp = LineReader(devkit_dp, return_path=False)
@@ -141,10 +148,13 @@ def _make_datapipe(
141148
devkit_dp,
142149
images_dp,
143150
key_fn=getitem(0),
144-
ref_key_fn=self._val_image_key,
151+
ref_key_fn=self._val_test_image_key,
145152
buffer_size=INFINITE_BUFFER_SIZE,
146153
)
147154
dp = Mapper(dp, self._collate_val_data)
155+
else: # config.split == "test"
156+
dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
157+
dp = Mapper(dp, self._collate_test_data)
148158

149159
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
150160

torchvision/prototype/models/detection/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .mask_rcnn import *
44
from .retinanet import *
55
from .ssd import *
6+
from .ssdlite import *

0 commit comments

Comments
 (0)