From c293be96cd43200cf32a39ec5e8fa9a6ae0b9deb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 6 Apr 2022 09:25:17 +0200 Subject: [PATCH] migrate CIFAR prototype datasets --- test/builtin_dataset_mocks.py | 12 +-- .../prototype/datasets/_builtin/cifar.py | 99 +++++++++++-------- 2 files changed, 66 insertions(+), 45 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index ad979b6bd84..759f356b55a 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -334,8 +334,8 @@ def generate( make_tar(root, name, folder, compression="gz") -# @register_mock -def cifar10(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "test"))) +def cifar10(root, config): train_files = [f"data_batch_{idx}" for idx in range(1, 6)] test_files = ["test_batch"] @@ -349,11 +349,11 @@ def cifar10(info, root, config): labels_key="labels", ) - return len(train_files if config.split == "train" else test_files) + return len(train_files if config["split"] == "train" else test_files) -# @register_mock -def cifar100(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "test"))) +def cifar100(root, config): train_files = ["train"] test_files = ["test"] @@ -367,7 +367,7 @@ def cifar100(info, root, config): labels_key="fine_labels", ) - return len(train_files if config.split == "train" else test_files) + return len(train_files if config["split"] == "train" else test_files) # @register_mock diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 3d7acefb903..9274aa543d4 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -1,9 +1,8 @@ import abc -import functools import io import pathlib import pickle -from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO +from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO, 
Union import numpy as np from torchdata.datapipes.iter import ( @@ -11,20 +10,12 @@ Filter, Mapper, ) -from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, - HttpResource, - OnlineResource, -) -from torchvision.prototype.datasets.utils._internal import ( - hint_shuffling, - path_comparator, - hint_sharding, -) +from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import hint_shuffling, path_comparator, hint_sharding, BUILTIN_DIR from torchvision.prototype.features import Label, Image +from .._api import register_dataset, register_info + class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]): def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None: @@ -38,25 +29,29 @@ def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]: yield from iter(zip(image_arrays, category_idcs)) -class _CifarBase(Dataset): +class _CifarBase(Dataset2): _FILE_NAME: str _SHA256: str _LABELS_KEY: str _META_FILE_NAME: str _CATEGORIES_KEY: str + # _categories: List[str] + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "test")) + super().__init__(root, skip_integrity_check=skip_integrity_check) @abc.abstractmethod - def _is_data_file(self, data: Tuple[str, BinaryIO], *, split: str) -> Optional[int]: + def _is_data_file(self, data: Tuple[str, BinaryIO]) -> Optional[int]: pass - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - type(self).__name__.lower(), - homepage="https://www.cs.toronto.edu/~kriz/cifar.html", - valid_options=dict(split=("train", "test")), - ) - - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: return [ HttpResource( f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}", @@ 
-72,52 +67,78 @@ def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]: image_array, category_idx = data return dict( image=Image(image_array), - label=Label(category_idx, categories=self.categories), + label=Label(category_idx, categories=self._categories), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] - dp = Filter(dp, functools.partial(self._is_data_file, split=config.split)) + dp = Filter(dp, self._is_data_file) dp = Mapper(dp, self._unpickle) dp = CifarFileReader(dp, labels_key=self._LABELS_KEY) dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) - def _generate_categories(self, root: pathlib.Path) -> List[str]: - resources = self.resources(self.default_config) + def __len__(self) -> int: + return 50_000 if self._split == "train" else 10_000 + + def _generate_categories(self) -> List[str]: + resources = self._resources() - dp = resources[0].load(root) + dp = resources[0].load(self._root) dp = Filter(dp, path_comparator("name", self._META_FILE_NAME)) dp = Mapper(dp, self._unpickle) return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY]) +CIFAR10_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "cifar10.categories")) + + +@register_info("cifar10") +def _cifar10_info() -> Dict[str, Any]: + return dict(categories=CIFAR10_CATEGORIES) + + +@register_dataset("cifar10") class Cifar10(_CifarBase): + """ + - **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html + """ + _FILE_NAME = "cifar-10-python.tar.gz" _SHA256 = "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce" _LABELS_KEY = "labels" _META_FILE_NAME = "batches.meta" _CATEGORIES_KEY = "label_names" + _categories = _cifar10_info()["categories"] - def _is_data_file(self, data: Tuple[str, Any], *, split: str) -> bool: + def 
_is_data_file(self, data: Tuple[str, Any]) -> bool: path = pathlib.Path(data[0]) - return path.name.startswith("data" if split == "train" else "test") + return path.name.startswith("data" if self._split == "train" else "test") +CIFAR100_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "cifar100.categories")) + + +@register_info("cifar100") +def _cifar100_info() -> Dict[str, Any]: + return dict(categories=CIFAR100_CATEGORIES) + + +@register_dataset("cifar100") class Cifar100(_CifarBase): + """ + - **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html + """ + _FILE_NAME = "cifar-100-python.tar.gz" _SHA256 = "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7" _LABELS_KEY = "fine_labels" _META_FILE_NAME = "meta" _CATEGORIES_KEY = "fine_label_names" + _categories = _cifar100_info()["categories"] - def _is_data_file(self, data: Tuple[str, Any], *, split: str) -> bool: + def _is_data_file(self, data: Tuple[str, Any]) -> bool: path = pathlib.Path(data[0]) - return path.name == split + return path.name == self._split