Patch summary (free text ahead of the first "diff --git" header).
test/test_prototype_builtin_datasets.py: the xfail markers referencing pytorch/data issue 237 are removed from
test_serializable and test_has_annotations, and traverse() is called with only_datapipe=True.
torchvision/prototype/datasets/_builtin/imagenet.py: the private helper methods of ImageNet become module-level functions.

diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py
index 956abf1a297..e3049bc555b 100644
--- a/test/test_prototype_builtin_datasets.py
+++ b/test/test_prototype_builtin_datasets.py
@@ -103,7 +103,6 @@ def test_transformable(self, test_home, dataset_mock, config):
 
         next(iter(dataset.map(transforms.Identity())))
 
-    @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237")
     @parametrize_dataset_mocks(DATASET_MOCKS)
     def test_serializable(self, test_home, dataset_mock, config):
         dataset_mock.prepare(test_home, config)
@@ -115,7 +114,6 @@ def test_serializable(self, test_home, dataset_mock, config):
     # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also
     #  that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680
     #  contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now.
-    @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237")
     @parametrize_dataset_mocks(DATASET_MOCKS)
     @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter))
     def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type):
@@ -127,7 +125,7 @@ def scan(graph):
         dataset_mock.prepare(test_home, config)
         dataset = datasets.load(dataset_mock.name, **config)
 
-        if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))):
+        if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset, only_datapipe=True))):
             raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.")
 
     @parametrize_dataset_mocks(DATASET_MOCKS)
diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py
index 220c5edf17a..83e4908cd59 100644
--- a/torchvision/prototype/datasets/_builtin/imagenet.py
+++ b/torchvision/prototype/datasets/_builtin/imagenet.py
@@ -35,11 +35,16 @@
 
 NAME = "imagenet"
 
+CATEGORIES, WNIDS = cast(
+    Tuple[Tuple[str, ...], Tuple[str, ...]],
+    zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")),
+)
+WNID_TO_CATEGORY = dict(zip(WNIDS, CATEGORIES))
+
 
 @register_info(NAME)
 def _info() -> Dict[str, Any]:
-    categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories"))
-    return dict(categories=categories, wnids=wnids)
+    return dict(categories=CATEGORIES, wnids=WNIDS)
 
 
 class ImageNetResource(ManualDownloadResource):
@@ -47,6 +52,75 @@ def __init__(self, **kwargs: Any) -> None:
         super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs)
 
 
+TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")
+
+
+def prepare_train_data(data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
+    path = pathlib.Path(data[0])
+    wnid = cast(Match[str], TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"]
+    label = Label.from_category(WNID_TO_CATEGORY[wnid], categories=CATEGORIES)
+    return (label, wnid), data
+
+
+def prepare_test_data(data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
+    return None, data
+
+
+def classifiy_devkit(data: Tuple[str, BinaryIO]) -> Optional[int]:
+    return {
+        "meta.mat": 0,
+        "ILSVRC2012_validation_ground_truth.txt": 1,
+    }.get(pathlib.Path(data[0]).name)
+
+
+# Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
+# and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
+WNID_MAP = {
+    "n03126707": "construction crane",
+    "n03710721": "tank suit",
+}
+
+
+def extract_categories_and_wnids(data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
+    synsets = read_mat(data[1], squeeze_me=True)["synsets"]
+    return [
+        (WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
+        for _, wnid, category, _, num_children, *_ in synsets
+        # if num_children > 0, we are looking at a superclass that has no direct instance
+        if num_children == 0
+    ]
+
+
+def imagenet_label_to_wnid(imagenet_label: str) -> str:
+    return WNIDS[int(imagenet_label) - 1]
+
+
+VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
+
+
+def val_test_image_key(path: pathlib.Path) -> int:
+    return int(VAL_TEST_IMAGE_NAME_PATTERN.match(path.name)["id"])  # type: ignore[index]
+
+
+def prepare_val_data(
+    data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]]
+) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
+    label_data, image_data = data
+    _, wnid = label_data
+    label = Label.from_category(WNID_TO_CATEGORY[wnid], categories=CATEGORIES)
+    return (label, wnid), image_data
+
+
+def prepare_sample(data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]]) -> Dict[str, Any]:
+    label_data, (path, buffer) = data
+
+    return dict(
+        dict(zip(("label", "wnid"), label_data if label_data else (None, None))),
+        path=path,
+        image=EncodedImage.from_file(buffer),
+    )
+
+
 @register_dataset(NAME)
 class ImageNet(Dataset2):
     def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None:
@@ -83,67 +157,6 @@ def _resources(self) -> List[OnlineResource]:
 
         return resources
 
-    _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?P<wnid>n\d{8})_\d+[.]JPEG")
-
-    def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
-        path = pathlib.Path(data[0])
-        wnid = cast(Match[str], self._TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"]
-        label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories)
-        return (label, wnid), data
-
-    def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]:
-        return None, data
-
-    def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]:
-        return {
-            "meta.mat": 0,
-            "ILSVRC2012_validation_ground_truth.txt": 1,
-        }.get(pathlib.Path(data[0]).name)
-
-    # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849
-    # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment
-    _WNID_MAP = {
-        "n03126707": "construction crane",
-        "n03710721": "tank suit",
-    }
-
-    def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]:
-        synsets = read_mat(data[1], squeeze_me=True)["synsets"]
-        return [
-            (self._WNID_MAP.get(wnid, category.split(",", 1)[0]), wnid)
-            for _, wnid, category, _, num_children, *_ in synsets
-            # if num_children > 0, we are looking at a superclass that has no direct instance
-            if num_children == 0
-        ]
-
-    def _imagenet_label_to_wnid(self, imagenet_label: str) -> str:
-        return self._wnids[int(imagenet_label) - 1]
-
-    _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P<id>\d{8})[.]JPEG")
-
-    def _val_test_image_key(self, path: pathlib.Path) -> int:
-        return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name)["id"])  # type: ignore[index]
-
-    def _prepare_val_data(
-        self, data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]]
-    ) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]:
-        label_data, image_data = data
-        _, wnid = label_data
-        label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories)
-        return (label, wnid), image_data
-
-    def _prepare_sample(
-        self,
-        data: Tuple[Optional[Tuple[Label, str]], Tuple[str, BinaryIO]],
-    ) -> Dict[str, Any]:
-        label_data, (path, buffer) = data
-
-        return dict(
-            dict(zip(("label", "wnid"), label_data if label_data else (None, None))),
-            path=path,
-            image=EncodedImage.from_file(buffer),
-        )
-
     def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         if self._split in {"train", "test"}:
             dp = resource_dps[0]
@@ -154,19 +167,19 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
 
             dp = hint_shuffling(dp)
             dp = hint_sharding(dp)
-            dp = Mapper(dp, self._prepare_train_data if self._split == "train" else self._prepare_test_data)
+            dp = Mapper(dp, prepare_train_data if self._split == "train" else prepare_test_data)
         else:  # config.split == "val":
             images_dp, devkit_dp = resource_dps
 
             meta_dp, label_dp = Demultiplexer(
-                devkit_dp, 2, self._classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
+                devkit_dp, 2, classifiy_devkit, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE
             )
 
-            meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
+            meta_dp = Mapper(meta_dp, extract_categories_and_wnids)
             _, wnids = zip(*next(iter(meta_dp)))
 
             label_dp = LineReader(label_dp, decode=True, return_path=False)
-            label_dp = Mapper(label_dp, self._imagenet_label_to_wnid)
+            label_dp = Mapper(label_dp, imagenet_label_to_wnid)
             label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1)
             label_dp = hint_shuffling(label_dp)
             label_dp = hint_sharding(label_dp)
@@ -175,12 +188,12 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str,
                 label_dp,
                 images_dp,
                 key_fn=getitem(0),
-                ref_key_fn=path_accessor(self._val_test_image_key),
+                ref_key_fn=path_accessor(val_test_image_key),
                 buffer_size=INFINITE_BUFFER_SIZE,
             )
-            dp = Mapper(dp, self._prepare_val_data)
+            dp = Mapper(dp, prepare_val_data)
 
-        return Mapper(dp, self._prepare_sample)
+        return Mapper(dp, prepare_sample)
 
     def __len__(self) -> int:
         return {
@@ -195,7 +208,7 @@ def _generate_categories(self) -> List[Tuple[str, ...]]:
 
         devkit_dp = resources[1].load(self._root)
         meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat"))
-        meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids)
+        meta_dp = Mapper(meta_dp, extract_categories_and_wnids)
 
         categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp)))
         categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1])