Skip to content

Commit 8230090

Browse files
committed
keep imagenet prototype as class but factor out methods
1 parent e7e921e commit 8230090

File tree

2 files changed

+85
-74
lines changed

2 files changed

+85
-74
lines changed

test/test_prototype_builtin_datasets.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,6 @@ def test_transformable(self, test_home, dataset_mock, config):
103103

104104
next(iter(dataset.map(transforms.Identity())))
105105

106-
@pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237")
107106
@parametrize_dataset_mocks(DATASET_MOCKS)
108107
def test_serializable(self, test_home, dataset_mock, config):
109108
dataset_mock.prepare(test_home, config)
@@ -115,7 +114,6 @@ def test_serializable(self, test_home, dataset_mock, config):
115114
# TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
116115
# that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
117116
# contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
118-
@pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237")
119117
@parametrize_dataset_mocks(DATASET_MOCKS)
120118
@pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
121119
def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
@@ -127,7 +125,7 @@ def scan(graph):
127125
dataset_mock.prepare(test_home, config)
128126
dataset = datasets.load(dataset_mock.name, **config)
129127

130-
if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
128+
if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset, only_datapipe=True))):
131129
raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
132130

133131
@parametrize_dataset_mocks(DATASET_MOCKS)

torchvision/prototype/datasets/_builtin/imagenet.py

Lines changed: 84 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -35,18 +35,92 @@
3535

3636
NAME = "imagenet"
3737

38+
# Human-readable category names and their WordNet IDs, read once at import time from the
# "<NAME>.categories" file in BUILTIN_DIR. The file yields (category, wnid) rows, which are
# transposed here into two parallel tuples.
CATEGORIES, WNIDS = cast(
    Tuple[Tuple[str, ...], Tuple[str, ...]],
    zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")),
)
# Lookup table from a WordNet ID to its category name.
WNID_TO_CATEGORY = {wnid: category for wnid, category in zip(WNIDS, CATEGORIES)}
43+
3844

3945
@register_info(NAME)
def _info() -> Dict[str, Any]:
    """Return the dataset info: the module-level categories and WordNet IDs.

    Returns:
        A dictionary with the keys ``"categories"`` and ``"wnids"``.
    """
    return {"categories": CATEGORIES, "wnids": WNIDS}
4348

4449

4550
class ImageNetResource(ManualDownloadResource):
    """Manual-download resource that points the user to the ImageNet website."""

    def __init__(self, **kwargs: Any) -> None:
        """Forward a fixed instruction message and ``kwargs`` to the base class."""
        instructions = "Register on https://image-net.org/ and follow the instructions there."
        super().__init__(instructions, **kwargs)
4853

4954

55+
# Train image file names start with the WordNet ID, e.g. "n01440764_10026.JPEG".
TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")


def prepare_train_data(data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
    """Derive the label of a train image from its file name.

    Args:
        data: ``(path, file handle)`` pair of a single train image.

    Returns:
        ``((label, wnid), data)`` where ``wnid`` is parsed from the file name and ``label``
        is built from the category that ``WNID_TO_CATEGORY`` maps the wnid to.
    """
    file_name = pathlib.Path(data[0]).name
    # The cast only silences the type checker: a non-matching name yields ``None`` here
    # and the subscript below then fails.
    name_match = cast(Match[str], TRAIN_IMAGE_NAME_PATTERN.match(file_name))
    wnid = name_match["wnid"]
    category = WNID_TO_CATEGORY[wnid]
    label = Label.from_category(category, categories=CATEGORIES)
    return (label, wnid), data
63+
64+
65+
def prepare_test_data(data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
    """Pair a test image with ``None`` in place of label data.

    Args:
        data: ``(path, file handle)`` pair of a single test image.

    Returns:
        ``(None, data)``, i.e. the same layout as the train / val preparation functions but
        without a label.
    """
    no_label = None
    return no_label, data
67+
68+
69+
def classifiy_devkit(data: Tuple[str, BinaryIO]) -> Optional[int]:
    """Classify a devkit file by its file name.

    Args:
        data: ``(path, file handle)`` pair of a single devkit file.

    Returns:
        ``0`` for ``meta.mat``, ``1`` for ``ILSVRC2012_validation_ground_truth.txt`` and
        ``None`` for every other file.
    """
    file_name = pathlib.Path(data[0]).name
    if file_name == "meta.mat":
        return 0
    if file_name == "ILSVRC2012_validation_ground_truth.txt":
        return 1
    return None
74+
75+
76+
# WordNet IDs (wnids) are unique, but the categories derived from them are not. For example,
# both n02012849 and n03126707 are labeled 'crane': the former is the bird, the latter the
# construction equipment. The wnids listed here get an explicit, unambiguous category instead.
WNID_MAP = {
    "n03126707": "construction crane",
    "n03710721": "tank suit",
}


def extract_categories_and_wnids(data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
    """Extract ``(category, wnid)`` pairs from the ``"synsets"`` entry of a ``.mat`` file.

    Args:
        data: ``(path, file handle)`` pair; only the file handle is read.

    Returns:
        One ``(category, wnid)`` pair per synset without children. The category is the
        override from ``WNID_MAP`` if there is one, otherwise the text before the first
        comma of the synset's category field.
    """
    synsets = read_mat(data[1], squeeze_me=True)["synsets"]

    categories_and_wnids = []
    for _, wnid, category, _, num_children, *_ in synsets:
        # if num_children > 0, we are looking at a superclass that has no direct instance
        if num_children == 0:
            first_name = category.split(",", 1)[0]
            categories_and_wnids.append((WNID_MAP.get(wnid, first_name), wnid))
    return categories_and_wnids
92+
93+
94+
def imagenet_label_to_wnid(imagenet_label: str) -> str:
    """Translate a 1-based label, given as a string, into its WordNet ID.

    Args:
        imagenet_label: Decimal string; ``"1"`` refers to the first entry of ``WNIDS``.

    Returns:
        The entry of ``WNIDS`` at position ``int(imagenet_label) - 1``.
    """
    # NOTE(review): this indexes the module-level WNIDS; confirm that their order matches
    # the numbering used by the label file this function is mapped over.
    zero_based_index = int(imagenet_label) - 1
    return WNIDS[zero_based_index]
96+
97+
98+
# Val / test image file names carry an eight-digit running ID, e.g. "ILSVRC2012_val_00000023.JPEG".
VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")


def val_test_image_key(path: pathlib.Path) -> int:
    """Return the numeric image ID encoded in the file name of a val / test image.

    Args:
        path: Path of the image; only its final component is inspected.

    Returns:
        The eight-digit ID from the file name as an integer.
    """
    name_match = VAL_TEST_IMAGE_NAME_PATTERN.match(path.name)
    # A non-matching file name leaves ``name_match`` as ``None`` and the subscript fails.
    return int(name_match["id"])  # type: ignore[index]
103+
104+
105+
def prepare_val_data(
    data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]]
) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
    """Attach a label to a val image that is already paired with its wnid.

    Args:
        data: ``((key, wnid), (path, file handle))``; the key is ignored here.

    Returns:
        ``((label, wnid), (path, file handle))`` where ``label`` is built from the category
        that ``WNID_TO_CATEGORY`` maps the wnid to.
    """
    (_, wnid), image_data = data
    category = WNID_TO_CATEGORY[wnid]
    return (Label.from_category(category, categories=CATEGORIES), wnid), image_data
112+
113+
114+
def prepare_sample(data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
    """Assemble the final sample dictionary.

    Args:
        data: ``(label_data, (path, file handle))`` where ``label_data`` is either a
            ``(label, wnid)`` pair or falsy (e.g. ``None``) if there is no label.

    Returns:
        A dictionary with the keys ``"label"``, ``"wnid"``, ``"path"`` and ``"image"``.
        ``"label"`` and ``"wnid"`` are ``None`` when no label data is given.
    """
    label_data, (path, buffer) = data

    label_fields = label_data if label_data else (None, None)
    sample: Dict[str, Any] = dict(zip(("label", "wnid"), label_fields))
    sample["path"] = path
    sample["image"] = EncodedImage.from_file(buffer)
    return sample
122+
123+
50124
@register_dataset(NAME)
51125
class ImageNet(Dataset2):
52126
def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None:
@@ -83,67 +157,6 @@ def _resources(self) -> List[OnlineResource]:
83157

84158
return resources
85159

86-
_TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")
87-
88-
def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
89-
path = pathlib.Path(data[0])
90-
wnid = cast(Match[str], self._TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"]
91-
label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories)
92-
return (label, wnid), data
93-
94-
def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
95-
return None, data
96-
97-
def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
98-
return {
99-
"meta.mat": 0,
100-
"ILSVRC2012_validation_ground_truth.txt": 1,
101-
}.get(pathlib.Path(data[0]).name)
102-
103-
# Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
104-
# and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
105-
_WNID_MAP = {
106-
"n03126707": "construction crane",
107-
"n03710721": "tank suit",
108-
}
109-
110-
def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
111-
synsets = read_mat(data[1], squeeze_me=True)["synsets"]
112-
return [
113-
(self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
114-
for _, wnid, category, _, num_children, *_ in synsets
115-
# if num_children > 0, we are looking at a superclass that has no direct instance
116-
if num_children == 0
117-
]
118-
119-
def _imagenet_label_to_wnid(self, imagenet_label: str) -> str:
120-
return self._wnids[int(imagenet_label) - 1]
121-
122-
_VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
123-
124-
def _val_test_image_key(self, path: pathlib.Path) -> int:
125-
return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name)["id"]) # type: ignore[index]
126-
127-
def _prepare_val_data(
128-
self, data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]]
129-
) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
130-
label_data, image_data = data
131-
_, wnid = label_data
132-
label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories)
133-
return (label, wnid), image_data
134-
135-
def _prepare_sample(
136-
self,
137-
data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]],
138-
) -> Dict[str, Any]:
139-
label_data, (path, buffer) = data
140-
141-
return dict(
142-
dict(zip(("label", "wnid"), label_data if label_data else (None, None))),
143-
path=path,
144-
image=EncodedImage.from_file(buffer),
145-
)
146-
147160
def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
148161
if self._split in {"train", "test"}:
149162
dp = resource_dps[0]
@@ -154,19 +167,19 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
154167

155168
dp = hint_shuffling(dp)
156169
dp = hint_sharding(dp)
157-
dp = Mapper(dp, self._prepare_train_data if self._split == "train" else self._prepare_test_data)
170+
dp = Mapper(dp, prepare_train_data if self._split == "train" else prepare_test_data)
158171
else: # config.split == "val":
159172
images_dp, devkit_dp = resource_dps
160173

161174
meta_dp, label_dp = Demultiplexer(
162-
devkit_dp, 2, self._classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
175+
devkit_dp, 2, classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
163176
)
164177

165-
meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
178+
meta_dp = Mapper(meta_dp, extract_categories_and_wnids)
166179
_, wnids = zip(*next(iter(meta_dp)))
167180

168181
label_dp = LineReader(label_dp, decode=True, return_path=False)
169-
label_dp = Mapper(label_dp, self._imagenet_label_to_wnid)
182+
label_dp = Mapper(label_dp, imagenet_label_to_wnid)
170183
label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
171184
label_dp = hint_shuffling(label_dp)
172185
label_dp = hint_sharding(label_dp)
@@ -175,12 +188,12 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
175188
label_dp,
176189
images_dp,
177190
key_fn=getitem(0),
178-
ref_key_fn=path_accessor(self._val_test_image_key),
191+
ref_key_fn=path_accessor(val_test_image_key),
179192
buffer_size=INFINITE_BUFFER_SIZE,
180193
)
181-
dp = Mapper(dp, self._prepare_val_data)
194+
dp = Mapper(dp, prepare_val_data)
182195

183-
return Mapper(dp, self._prepare_sample)
196+
return Mapper(dp, prepare_sample)
184197

185198
def __len__(self) -> int:
186199
return {
@@ -195,7 +208,7 @@ def _generate_categories(self) -> List[Tuple[str, ...]]:
195208

196209
devkit_dp = resources[1].load(self._root)
197210
meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))
198-
meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
211+
meta_dp = Mapper(meta_dp, extract_categories_and_wnids)
199212

200213
categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp)))
201214
categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])

0 commit comments

Comments
 (0)