Skip to content

Commit 8a674da

Browse files
kazhang and facebook-github-bot
authored and committed
[fbsync] add test split for imagenet (#4866)
Summary: * add test split for imagenet * add infinite buffer size to shuffler Reviewed By: datumbox Differential Revision: D32298970 fbshipit-source-id: 3566398a571df2469c597fcdf93af124ef99e9c9
1 parent 5bf1da1 commit 8a674da

File tree

2 files changed

+47
-29
lines changed

2 files changed

+47
-29
lines changed

test/builtin_dataset_mocks.py

Lines changed: 18 additions & 10 deletions
Original file line number · Diff line number · Diff line change
@@ -452,11 +452,7 @@ def caltech256(info, root, config):
452452

453453
@dataset_mocks.register_mock_data_fn
454454
def imagenet(info, root, config):
455-
devkit_root = root / "ILSVRC2012_devkit_t12"
456-
devkit_root.mkdir()
457-
458455
wnids = tuple(info.extra.wnid_to_category.keys())
459-
460456
if config.split == "train":
461457
images_root = root / "ILSVRC2012_img_train"
462458

@@ -470,7 +466,7 @@ def imagenet(info, root, config):
470466
num_examples=1,
471467
)
472468
make_tar(images_root, f"{wnid}.tar", files[0].parent)
473-
else:
469+
elif config.split == "val":
474470
num_samples = 3
475471
files = create_image_folder(
476472
root=root,
@@ -479,14 +475,26 @@ def imagenet(info, root, config):
479475
num_examples=num_samples,
480476
)
481477
images_root = files[0].parent
478+
else: # config.split == "test"
479+
images_root = root / "ILSVRC2012_img_test_v10102019"
482480

483-
data_root = devkit_root / "data"
484-
data_root.mkdir()
485-
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
486-
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
487-
file.write(f"{label}\n")
481+
num_samples = 3
488482

483+
create_image_folder(
484+
root=images_root,
485+
name="test",
486+
file_name_fn=lambda image_idx: f"ILSVRC2012_test_{image_idx + 1:08d}.JPEG",
487+
num_examples=num_samples,
488+
)
489489
make_tar(root, f"{images_root.name}.tar", images_root)
490+
491+
devkit_root = root / "ILSVRC2012_devkit_t12"
492+
devkit_root.mkdir()
493+
data_root = devkit_root / "data"
494+
data_root.mkdir()
495+
with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file:
496+
for label in torch.randint(0, len(wnids), (num_samples,)).tolist():
497+
file.write(f"{label}\n")
490498
make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz")
491499

492500
return num_samples

torchvision/prototype/datasets/_builtin/imagenet.py

Lines changed: 29 additions & 19 deletions
Original file line number · Diff line number · Diff line change
@@ -34,11 +34,17 @@ def _make_info(self) -> DatasetInfo:
3434
type=DatasetType.IMAGE,
3535
categories=categories,
3636
homepage="https://www.image-net.org/",
37-
valid_options=dict(split=("train", "val")),
37+
valid_options=dict(split=("train", "val", "test")),
3838
extra=dict(
3939
wnid_to_category=FrozenMapping(zip(wnids, categories)),
4040
category_to_wnid=FrozenMapping(zip(categories, wnids)),
41-
sizes=FrozenMapping([(DatasetConfig(split="train"), 1281167), (DatasetConfig(split="val"), 50000)]),
41+
sizes=FrozenMapping(
42+
[
43+
(DatasetConfig(split="train"), 1_281_167),
44+
(DatasetConfig(split="val"), 50_000),
45+
(DatasetConfig(split="test"), 100_000),
46+
]
47+
),
4248
),
4349
)
4450

@@ -53,17 +59,15 @@ def category_to_wnid(self) -> Dict[str, str]:
5359
def wnid_to_category(self) -> Dict[str, str]:
5460
return cast(Dict[str, str], self.info.extra.wnid_to_category)
5561

62+
_IMAGES_CHECKSUMS = {
63+
"train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
64+
"val": "c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
65+
"test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4",
66+
}
67+
5668
def resources(self, config: DatasetConfig) -> List[OnlineResource]:
57-
if config.split == "train":
58-
images = HttpResource(
59-
"ILSVRC2012_img_train.tar",
60-
sha256="b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb",
61-
)
62-
else: # config.split == "val"
63-
images = HttpResource(
64-
"ILSVRC2012_img_val.tar",
65-
sha256="c7e06a6c0baccf06d8dbeb6577d71efff84673a5dbdd50633ab44f8ea0456ae0",
66-
)
69+
name = "test_v10102019" if config.split == "test" else config.split
70+
images = HttpResource(f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name])
6771

6872
devkit = HttpResource(
6973
"ILSVRC2012_devkit_t12.tar.gz",
@@ -81,11 +85,11 @@ def _collate_train_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[int, s
8185
label = self.categories.index(category)
8286
return (label, category, wnid), data
8387

84-
_VAL_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_val_(?P<id>\d{8})[.]JPEG")
88+
_VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
8589

86-
def _val_image_key(self, data: Tuple[str, Any]) -> int:
90+
def _val_test_image_key(self, data: Tuple[str, Any]) -> int:
8791
path = pathlib.Path(data[0])
88-
return int(self._VAL_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr]
92+
return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr]
8993

9094
def _collate_val_data(
9195
self, data: Tuple[Tuple[int, int], Tuple[str, io.IOBase]]
@@ -96,9 +100,12 @@ def _collate_val_data(
96100
wnid = self.category_to_wnid[category]
97101
return (label, category, wnid), image_data
98102

103+
def _collate_test_data(self, data: Tuple[str, io.IOBase]) -> Tuple[Tuple[None, None, None], Tuple[str, io.IOBase]]:
104+
return (None, None, None), data
105+
99106
def _collate_and_decode_sample(
100107
self,
101-
data: Tuple[Tuple[int, str, str], Tuple[str, io.IOBase]],
108+
data: Tuple[Tuple[Optional[int], Optional[str], Optional[str]], Tuple[str, io.IOBase]],
102109
*,
103110
decoder: Optional[Callable[[io.IOBase], torch.Tensor]],
104111
) -> Dict[str, Any]:
@@ -108,7 +115,7 @@ def _collate_and_decode_sample(
108115
return dict(
109116
path=path,
110117
image=decoder(buffer) if decoder else buffer,
111-
label=torch.tensor(label),
118+
label=label,
112119
category=category,
113120
wnid=wnid,
114121
)
@@ -129,7 +136,7 @@ def _make_datapipe(
129136
dp = TarArchiveReader(images_dp)
130137
dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
131138
dp = Mapper(dp, self._collate_train_data)
132-
else:
139+
elif config.split == "val":
133140
devkit_dp = TarArchiveReader(devkit_dp)
134141
devkit_dp = Filter(devkit_dp, path_comparator("name", "ILSVRC2012_validation_ground_truth.txt"))
135142
devkit_dp = LineReader(devkit_dp, return_path=False)
@@ -141,10 +148,13 @@ def _make_datapipe(
141148
devkit_dp,
142149
images_dp,
143150
key_fn=getitem(0),
144-
ref_key_fn=self._val_image_key,
151+
ref_key_fn=self._val_test_image_key,
145152
buffer_size=INFINITE_BUFFER_SIZE,
146153
)
147154
dp = Mapper(dp, self._collate_val_data)
155+
else: # config.split == "test"
156+
dp = Shuffler(images_dp, buffer_size=INFINITE_BUFFER_SIZE)
157+
dp = Mapper(dp, self._collate_test_data)
148158

149159
return Mapper(dp, self._collate_and_decode_sample, fn_kwargs=dict(decoder=decoder))
150160

0 commit comments

Comments (0)