diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py
index ad979b6bd84..d178090c53c 100644
--- a/test/builtin_dataset_mocks.py
+++ b/test/builtin_dataset_mocks.py
@@ -214,58 +214,64 @@ def generate(
         return num_samples
 
 
-# @register_mock
-def mnist(info, root, config):
-    train = config.split == "train"
-    images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz"
-    labels_file = f"{'train' if train else 't10k'}-labels-idx1-ubyte.gz"
+def mnist(root, config):
+    prefix = "train" if config["split"] == "train" else "t10k"
     return MNISTMockData.generate(
         root,
-        num_categories=len(info.categories),
-        images_file=images_file,
-        labels_file=labels_file,
+        num_categories=10,
+        images_file=f"{prefix}-images-idx3-ubyte.gz",
+        labels_file=f"{prefix}-labels-idx1-ubyte.gz",
     )
 
 
-# DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]})
+DATASET_MOCKS.update(
+    {
+        name: DatasetMock(name, mock_data_fn=mnist, configs=combinations_grid(split=("train", "test")))
+        for name in ["mnist", "fashionmnist", "kmnist"]
+    }
+)
 
 
-# @register_mock
-def emnist(info, root, config):
-    # The image sets that merge some lower case letters in their respective upper case variant, still use dense
-    # labels in the data files. Thus, num_categories != len(categories) there.
-    num_categories = defaultdict(
-        lambda: len(info.categories), {image_set: 47 for image_set in ("Balanced", "By_Merge")}
+@register_mock(
+    configs=combinations_grid(
+        split=("train", "test"),
+        image_set=("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST"),
     )
-
+)
+def emnist(root, config):
     num_samples_map = {}
     file_names = set()
-    for config_ in info._configs:
-        prefix = f"emnist-{config_.image_set.replace('_', '').lower()}-{config_.split}"
+    for split, image_set in itertools.product(
+        ("train", "test"),
+        ("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST"),
+    ):
+        prefix = f"emnist-{image_set.replace('_', '').lower()}-{split}"
         images_file = f"{prefix}-images-idx3-ubyte.gz"
         labels_file = f"{prefix}-labels-idx1-ubyte.gz"
        file_names.update({images_file, labels_file})
 
-        num_samples_map[config_] = MNISTMockData.generate(
+        num_samples_map[(split, image_set)] = MNISTMockData.generate(
             root,
+            # The image sets that merge some lowercase letters into their respective uppercase variants still use
+            # dense labels in the data files. Thus, num_categories != len(categories) there.
-            num_categories=num_categories[config_.image_set],
+            num_categories=47 if image_set in ("Balanced", "By_Merge") else 62,
             images_file=images_file,
             labels_file=labels_file,
         )
 
     make_zip(root, "emnist-gzip.zip", *file_names)
 
-    return num_samples_map[config]
+    return num_samples_map[(config["split"], config["image_set"])]
 
 
-# @register_mock
-def qmnist(info, root, config):
-    num_categories = len(info.categories)
-    if config.split == "train":
+@register_mock(configs=combinations_grid(split=("train", "test", "test10k", "test50k", "nist")))
+def qmnist(root, config):
+    num_categories = 10
+    if config["split"] == "train":
         num_samples = num_samples_gen = num_categories + 2
         prefix = "qmnist-train"
         suffix = ".gz"
         compressor = gzip.open
-    elif config.split.startswith("test"):
+    elif config["split"].startswith("test"):
         # The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create
         # more than 10000 images for the dataset to not be empty.
         num_samples_gen = 10001
@@ -273,11 +279,11 @@ def qmnist(info, root, config):
             "test": num_samples_gen,
             "test10k": min(num_samples_gen, 10_000),
             "test50k": num_samples_gen - 10_000,
-        }[config.split]
+        }[config["split"]]
         prefix = "qmnist-test"
         suffix = ".gz"
         compressor = gzip.open
-    else:  # config.split == "nist"
+    else:  # config["split"] == "nist"
         num_samples = num_samples_gen = num_categories + 3
         prefix = "xnist"
         suffix = ".xz"
diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py
index 8a929b6907c..d9ad4885b57 100644
--- a/test/test_prototype_builtin_datasets.py
+++ b/test/test_prototype_builtin_datasets.py
@@ -176,8 +176,7 @@ def test_infinite_buffer_size(self, test_home, dataset_mock, config):
         assert dp.buffer_size == INFINITE_BUFFER_SIZE
 
 
-# FIXME: DATASET_MOCKS["qmnist"]
-@parametrize_dataset_mocks({})
+@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"])
 class TestQMNIST:
     def test_extra_label(self, test_home, dataset_mock, config):
         dataset_mock.prepare(test_home, config)
diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py
index 1e14e6dfc58..907faed49bd 100644
--- a/torchvision/prototype/datasets/_builtin/mnist.py
+++ b/torchvision/prototype/datasets/_builtin/mnist.py
@@ -7,12 +7,13 @@
 import torch
 from torchdata.datapipes.iter import IterDataPipe, Demultiplexer, Mapper, Zipper, Decompressor
 
-from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource
+from torchvision.prototype.datasets.utils import Dataset2, HttpResource, OnlineResource
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling
 from torchvision.prototype.features import Image, Label
 from torchvision.prototype.utils._internal import fromfile
 
-__all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"]
+from .._api import register_dataset, register_info
 
 prod = functools.partial(functools.reduce, operator.mul)
@@ -57,18 +58,18 @@ def __iter__(self) -> Iterator[torch.Tensor]:
             yield read(dtype=dtype, count=count).reshape(shape)
 
 
-class _MNISTBase(Dataset):
+class _MNISTBase(Dataset2):
     _URL_BASE: Union[str, Sequence[str]]
 
     @abc.abstractmethod
-    def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]:
+    def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
         pass
 
-    def resources(self, config: DatasetConfig) -> List[OnlineResource]:
+    def _resources(self) -> List[OnlineResource]:
         (images_file, images_sha256), (
             labels_file,
             labels_sha256,
-        ) = self._files_and_checksums(config)
+        ) = self._files_and_checksums()
 
         url_bases = self._URL_BASE
         if isinstance(url_bases, str):
@@ -82,21 +83,21 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]:
 
         return [images, labels]
 
-    def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]:
+    def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]:
         return None, None
 
-    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
+    _categories: List[str]
+
+    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
         image, label = data
         return dict(
             image=Image(image),
-            label=Label(label, dtype=torch.int64, categories=self.categories),
+            label=Label(label, dtype=torch.int64, categories=self._categories),
         )
 
-    def _make_datapipe(
-        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         images_dp, labels_dp = resource_dps
-        start, stop = self.start_and_stop(config)
+        start, stop = self.start_and_stop()
 
         images_dp = Decompressor(images_dp)
         images_dp = MNISTFileReader(images_dp, start=start, stop=stop)
@@ -107,19 +108,31 @@ def _make_datapipe(
         dp = Zipper(images_dp, labels_dp)
         dp = hint_shuffling(dp)
         dp = hint_sharding(dp)
-        return Mapper(dp, functools.partial(self._prepare_sample, config=config))
+        return Mapper(dp, self._prepare_sample)
+
+
+@register_info("mnist")
+def _mnist_info() -> Dict[str, Any]:
+    return dict(
+        categories=[str(label) for label in range(10)],
+    )
+
+
+@register_dataset("mnist")
 class MNIST(_MNISTBase):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "mnist",
-            categories=10,
-            homepage="http://yann.lecun.com/exdb/mnist",
-            valid_options=dict(
-                split=("train", "test"),
-            ),
-        )
+    """
+    - **homepage**: http://yann.lecun.com/exdb/mnist
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", ("train", "test"))
+        super().__init__(root, skip_integrity_check=skip_integrity_check)
 
     _URL_BASE: Union[str, Sequence[str]] = (
         "http://yann.lecun.com/exdb/mnist",
@@ -132,8 +145,8 @@ def _make_info(self) -> DatasetInfo:
         "t10k-labels-idx1-ubyte.gz": "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6",
     }
 
-    def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]:
-        prefix = "train" if config.split == "train" else "t10k"
+    def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
+        prefix = "train" if self._split == "train" else "t10k"
         images_file = f"{prefix}-images-idx3-ubyte.gz"
         labels_file = f"{prefix}-labels-idx1-ubyte.gz"
         return (images_file, self._CHECKSUMS[images_file]), (
@@ -141,28 +154,35 @@ def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str],
             self._CHECKSUMS[labels_file],
         )
 
+    _categories = _mnist_info()["categories"]
+
+    def __len__(self) -> int:
+        return 60_000 if self._split == "train" else 10_000
+
+
+@register_info("fashionmnist")
+def _fashionmnist_info() -> Dict[str, Any]:
+    return dict(
+        categories=[
+            "T-shirt/top",
+            "Trouser",
+            "Pullover",
+            "Dress",
+            "Coat",
+            "Sandal",
+            "Shirt",
+            "Sneaker",
+            "Bag",
+            "Ankle boot",
+        ],
+    )
+
+
+@register_dataset("fashionmnist")
 class FashionMNIST(MNIST):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "fashionmnist",
-            categories=(
-                "T-shirt/top",
-                "Trouser",
-                "Pullover",
-                "Dress",
-                "Coat",
-                "Sandal",
-                "Shirt",
-                "Sneaker",
-                "Bag",
-                "Ankle boot",
-            ),
-            homepage="https://github.com/zalandoresearch/fashion-mnist",
-            valid_options=dict(
-                split=("train", "test"),
-            ),
-        )
+    """
+    - **homepage**: https://github.com/zalandoresearch/fashion-mnist
+    """
 
     _URL_BASE = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com"
     _CHECKSUMS = {
@@ -172,17 +192,21 @@ def _make_info(self) -> DatasetInfo:
         "t10k-labels-idx1-ubyte.gz": "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5",
     }
 
+    _categories = _fashionmnist_info()["categories"]
+
+
+@register_info("kmnist")
+def _kmnist_info() -> Dict[str, Any]:
+    return dict(
+        categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"],
+    )
+
+
+@register_dataset("kmnist")
 class KMNIST(MNIST):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
"kmnist", - categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"], - homepage="http://codh.rois.ac.jp/kmnist/index.html.en", - valid_options=dict( - split=("train", "test"), - ), - ) + """ + - **homepage**: http://codh.rois.ac.jp/kmnist/index.html.en + """ _URL_BASE = "http://codh.rois.ac.jp/kmnist/dataset/kmnist" _CHECKSUMS = { @@ -192,36 +216,46 @@ def _make_info(self) -> DatasetInfo: "t10k-labels-idx1-ubyte.gz": "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c", } + _categories = _kmnist_info()["categories"] + + +@register_info("emnist") +def _emnist_info() -> Dict[str, Any]: + return dict( + categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase), + ) + +@register_dataset("emnist") class EMNIST(_MNISTBase): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "emnist", - categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase), - homepage="https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist", - valid_options=dict( - split=("train", "test"), - image_set=( - "Balanced", - "By_Merge", - "By_Class", - "Letters", - "Digits", - "MNIST", - ), - ), + """ + - **homepage**: https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + image_set: str = "Balanced", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "test")) + self._image_set = self._verify_str_arg( + image_set, "image_set", ("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST") ) + super().__init__(root, skip_integrity_check=skip_integrity_check) _URL_BASE = "https://rds.westernsydney.edu.au/Institutes/MARCS/BENS/EMNIST" - def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]: - prefix = f"emnist-{config.image_set.replace('_', '').lower()}-{config.split}" + def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: + prefix = f"emnist-{self._image_set.replace('_', '').lower()}-{self._split}" images_file = f"{prefix}-images-idx3-ubyte.gz" labels_file = f"{prefix}-labels-idx1-ubyte.gz" - # Since EMNIST provides the data files inside an archive, we don't need provide checksums for them + # Since EMNIST provides the data files inside an archive, we don't need to provide checksums for them return (images_file, ""), (labels_file, "") - def resources(self, config: Optional[DatasetConfig] = None) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: return [ HttpResource( f"{self._URL_BASE}/emnist-gzip.zip", @@ -229,9 +263,9 @@ def resources(self, config: Optional[DatasetConfig] = None) -> List[OnlineResour ) ] - def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> Optional[int]: + def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: path = pathlib.Path(data[0]) - (images_file, _), (labels_file, _) = self._files_and_checksums(config) + (images_file, _), (labels_file, _) = self._files_and_checksums() if path.name == images_file: return 0 elif path.name == labels_file: @@ -239,6 +273,8 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> else: return None + _categories = _emnist_info()["categories"] + _LABEL_OFFSETS = { 38: 1, 39: 1, @@ -251,45 +287,71 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> 46: 9, } - def 
+    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
         # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper).
         # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense,
-        # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create this gaps. For example,
-        # since there is no 'c', 'd' corresponds to
+        # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For
+        # example, since there is no 'c', 'd' corresponds to
         # label 38 (10 digits + 26 uppercase letters + 3rd unmerged lower case letter - 1 for zero indexing),
         # and at the same time corresponds to
         # index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing)
-        # in self.categories. Thus, we need to add 1 to the label to correct this.
-        if config.image_set in ("Balanced", "By_Merge"):
+        # in self._categories. Thus, we need to add 1 to the label to correct this.
+        if self._image_set in ("Balanced", "By_Merge"):
             image, label = data
             label += self._LABEL_OFFSETS.get(int(label), 0)
             data = (image, label)
-        return super()._prepare_sample(data, config=config)
+        return super()._prepare_sample(data)
 
-    def _make_datapipe(
-        self, resource_dps: List[IterDataPipe], *, config: DatasetConfig
-    ) -> IterDataPipe[Dict[str, Any]]:
+    def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
         archive_dp = resource_dps[0]
         images_dp, labels_dp = Demultiplexer(
             archive_dp,
             2,
-            functools.partial(self._classify_archive, config=config),
+            self._classify_archive,
             drop_none=True,
             buffer_size=INFINITE_BUFFER_SIZE,
         )
-        return super()._make_datapipe([images_dp, labels_dp], config=config)
+        return super()._datapipe([images_dp, labels_dp])
+
+    def __len__(self) -> int:
+        return {
+            ("train", "Balanced"): 112_800,
+            ("train", "By_Merge"): 697_932,
+            ("train", "By_Class"): 697_932,
+            ("train", "Letters"): 124_800,
+            ("train", "Digits"): 240_000,
+            ("train", "MNIST"): 60_000,
+            ("test", "Balanced"): 18_800,
+            ("test", "By_Merge"): 116_323,
+            ("test", "By_Class"): 116_323,
+            ("test", "Letters"): 20_800,
+            ("test", "Digits"): 40_000,
+            ("test", "MNIST"): 10_000,
+        }[(self._split, self._image_set)]
+
+
+@register_info("qmnist")
+def _qmnist_info() -> Dict[str, Any]:
+    return dict(
+        categories=[str(label) for label in range(10)],
+    )
+
+
+@register_dataset("qmnist")
 class QMNIST(_MNISTBase):
-    def _make_info(self) -> DatasetInfo:
-        return DatasetInfo(
-            "qmnist",
-            categories=10,
-            homepage="https://github.com/facebookresearch/qmnist",
-            valid_options=dict(
-                split=("train", "test", "test10k", "test50k", "nist"),
-            ),
-        )
+    """
+    - **homepage**: https://github.com/facebookresearch/qmnist
+    """
+
+    def __init__(
+        self,
+        root: Union[str, pathlib.Path],
+        *,
+        split: str = "train",
+        skip_integrity_check: bool = False,
+    ) -> None:
+        self._split = self._verify_str_arg(split, "split", ("train", "test", "test10k", "test50k", "nist"))
+        super().__init__(root, skip_integrity_check=skip_integrity_check)
 
     _URL_BASE = "https://raw.githubusercontent.com/facebookresearch/qmnist/master"
 
     _CHECKSUMS = {
@@ -301,9 +363,9 @@ def _make_info(self) -> DatasetInfo:
         "xnist-labels-idx2-int.xz": "db042968723ec2b7aed5f1beac25d2b6e983b9286d4f4bf725f1086e5ae55c4f",
     }
 
-    def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]:
-        prefix = "xnist" if config.split == "nist" else f"qmnist-{'train' if config.split== 'train' else 'test'}"
-        suffix = "xz" if config.split == "nist" else "gz"
+    def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]:
+        prefix = "xnist" if self._split == "nist" else f"qmnist-{'train' if self._split == 'train' else 'test'}"
+        suffix = "xz" if self._split == "nist" else "gz"
         images_file = f"{prefix}-images-idx3-ubyte.{suffix}"
         labels_file = f"{prefix}-labels-idx2-int.{suffix}"
         return (images_file, self._CHECKSUMS[images_file]), (
@@ -311,13 +373,13 @@ def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str],
             self._CHECKSUMS[labels_file],
         )
 
-    def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]:
+    def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]:
         start: Optional[int]
         stop: Optional[int]
-        if config.split == "test10k":
+        if self._split == "test10k":
            start = 0
            stop = 10000
-        elif config.split == "test50k":
+        elif self._split == "test50k":
             start = 10000
             stop = None
         else:
@@ -325,10 +387,12 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional
 
         return start, stop
 
-    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]:
+    _categories = _qmnist_info()["categories"]
+
+    def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]:
         image, ann = data
         label, *extra_anns = ann
-        sample = super()._prepare_sample((image, label), config=config)
+        sample = super()._prepare_sample((image, label))
 
         sample.update(
             dict(
@@ -340,3 +404,12 @@ def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: Da
             )
         )
         sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]])))
         return sample
+
+    def __len__(self) -> int:
+        return {
+            "train": 60_000,
+            "test": 60_000,
+            "test10k": 10_000,
+            "test50k": 50_000,
+            "nist": 402_953,
+        }[self._split]
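Reviewer note: as a quick sanity check of the ported API, here is a minimal usage sketch. It assumes the prototype `load()` and `info()` entry points in `torchvision.prototype.datasets` that the `register_dataset`/`register_info` calls above hook into; nothing in the sketch is part of this change itself.

```python
# Sketch only: `load` and `info` are assumed from the prototype
# torchvision.prototype.datasets API; they are not touched by this diff.
from torchvision.prototype import datasets

# Datasets are now constructed with keyword-only config arguments
# instead of a DatasetConfig, e.g. split="test".
dataset = datasets.load("mnist", split="test")

# __len__ is hard-coded per split, so len() works without iterating.
assert len(dataset) == 10_000

# Samples are still plain dicts with "image" and "label" entries.
sample = next(iter(dataset))
print(sample["image"].shape, int(sample["label"]))

# Category metadata now comes from the registered info function rather
# than a DatasetInfo instance.
assert datasets.info("qmnist")["categories"] == [str(c) for c in range(10)]
```

Hard-coding `__len__` per split replaces the old `DatasetInfo`-driven size handling, which is also why the mocks above pin exact sample counts per config.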