From 0aae4278a51702b4d65b5ae78aeda19964967cf2 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 24 Feb 2022 17:02:57 +0100 Subject: [PATCH 01/30] refactor prototype datasets to inherit from IterDataPipe (#5448) * refactor prototype datasets to inherit from IterDataPipe * depend on new architecture * fix missing file detection * remove unrelated file * reinstante decorator for mock registering * options -> config * remove passing of info to mock data functions * refactor categories file generation --- test/builtin_dataset_mocks.py | 100 +++++++------- test/test_prototype_builtin_datasets.py | 47 +++---- torchvision/prototype/datasets/__init__.py | 3 +- torchvision/prototype/datasets/_api.py | 60 +++++---- .../prototype/datasets/_builtin/imagenet.py | 122 ++++++++---------- .../datasets/generate_category_files.py | 10 +- .../prototype/datasets/utils/__init__.py | 2 +- .../prototype/datasets/utils/_dataset.py | 40 +++++- 8 files changed, 210 insertions(+), 174 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 123d8f29d3f..1d988196190 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -9,6 +9,7 @@ import pathlib import pickle import random +import unittest.mock import xml.etree.ElementTree as ET from collections import defaultdict, Counter @@ -16,10 +17,10 @@ import PIL.Image import pytest import torch -from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file +from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, combinations_grid from torch.nn.functional import one_hot from torch.testing import make_tensor as _make_tensor -from torchvision.prototype.datasets._api import find +from torchvision.prototype import datasets from torchvision.prototype.utils._internal import sequence_to_str make_tensor = functools.partial(_make_tensor, device="cpu") @@ -30,13 +31,11 @@ class DatasetMock: - def __init__(self, name, mock_data_fn): - self.dataset = find(name) - self.info = self.dataset.info - self.name = self.info.name - + def __init__(self, name, *, mock_data_fn, configs): + # FIXME: error handling for unknown names + self.name = name self.mock_data_fn = mock_data_fn - self.configs = self.info._configs + self.configs = configs def _parse_mock_info(self, mock_info): if mock_info is None: @@ -65,10 +64,13 @@ def prepare(self, home, config): root = home / self.name root.mkdir(exist_ok=True) - mock_info = self._parse_mock_info(self.mock_data_fn(self.info, root, config)) + mock_info = self._parse_mock_info(self.mock_data_fn(root, config)) + with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"): + required_file_names = { + resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources() + } available_file_names = {path.name for path in root.glob("*")} - required_file_names = {resource.file_name for resource in self.dataset.resources(config)} missing_file_names = required_file_names - available_file_names if missing_file_names: raise pytest.UsageError( @@ -123,10 +125,16 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None): DATASET_MOCKS = {} -def register_mock(fn): - name = fn.__name__.replace("_", "-") - DATASET_MOCKS[name] = DatasetMock(name, fn) - return fn +def register_mock(name=None, *, configs): + def wrapper(mock_data_fn): + nonlocal name + if name is None: + name = mock_data_fn.__name__ + DATASET_MOCKS[name] = DatasetMock(name, mock_data_fn=mock_data_fn, configs=configs) + + return mock_data_fn + + return 
wrapper class MNISTMockData: @@ -204,7 +212,7 @@ def generate( return num_samples -@register_mock +# @register_mock def mnist(info, root, config): train = config.split == "train" images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz" @@ -217,10 +225,10 @@ def mnist(info, root, config): ) -DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]}) +# DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]}) -@register_mock +# @register_mock def emnist(info, root, config): # The image sets that merge some lower case letters in their respective upper case variant, still use dense # labels in the data files. Thus, num_categories != len(categories) there. @@ -247,7 +255,7 @@ def emnist(info, root, config): return num_samples_map[config] -@register_mock +# @register_mock def qmnist(info, root, config): num_categories = len(info.categories) if config.split == "train": @@ -324,7 +332,7 @@ def generate( make_tar(root, name, folder, compression="gz") -@register_mock +# @register_mock def cifar10(info, root, config): train_files = [f"data_batch_{idx}" for idx in range(1, 6)] test_files = ["test_batch"] @@ -342,7 +350,7 @@ def cifar10(info, root, config): return len(train_files if config.split == "train" else test_files) -@register_mock +# @register_mock def cifar100(info, root, config): train_files = ["train"] test_files = ["test"] @@ -360,7 +368,7 @@ def cifar100(info, root, config): return len(train_files if config.split == "train" else test_files) -@register_mock +# @register_mock def caltech101(info, root, config): def create_ann_file(root, name): import scipy.io @@ -410,7 +418,7 @@ def create_ann_folder(root, name, file_name_fn, num_examples): return num_images_per_category * len(info.categories) -@register_mock +# @register_mock def caltech256(info, root, config): dir = root / "256_ObjectCategories" num_images_per_category = 2 @@ -430,18 +438,18 @@ def caltech256(info, root, config): return num_images_per_category * len(info.categories) -@register_mock -def imagenet(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "val", "test"))) +def imagenet(root, config): from scipy.io import savemat - categories = info.categories - wnids = [info.extra.category_to_wnid[category] for category in categories] - if config.split == "train": - num_samples = len(wnids) + info = datasets.info("imagenet") + + if config["split"] == "train": + num_samples = len(info["wnids"]) archive_name = "ILSVRC2012_img_train.tar" files = [] - for wnid in wnids: + for wnid in info["wnids"]: create_image_folder( root=root, name=wnid, @@ -449,7 +457,7 @@ def imagenet(info, root, config): num_examples=1, ) files.append(make_tar(root, f"{wnid}.tar")) - elif config.split == "val": + elif config["split"] == "val": num_samples = 3 archive_name = "ILSVRC2012_img_val.tar" files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)] @@ -459,20 +467,20 @@ def imagenet(info, root, config): data_root.mkdir(parents=True) with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file: - for label in torch.randint(0, len(wnids), (num_samples,)).tolist(): + for label in torch.randint(0, len(info["wnids"]), (num_samples,)).tolist(): file.write(f"{label}\n") num_children = 0 synsets = [ (idx, wnid, category, "", num_children, [], 0, 0) - for idx, (category, wnid) in enumerate(zip(categories, wnids), 1) + for idx, (category, wnid) in enumerate(zip(info["categories"], info["wnids"]), 
1) ] num_children = 1 synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5)) savemat(data_root / "meta.mat", dict(synsets=synsets)) make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz") - else: # config.split == "test" + else: # config["split"] == "test" num_samples = 5 archive_name = "ILSVRC2012_img_test_v10102019.tar" files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for idx in range(num_samples)] @@ -587,7 +595,7 @@ def generate( return num_samples -@register_mock +# @register_mock def coco(info, root, config): return CocoMockData.generate(root, year=config.year, num_samples=5) @@ -661,12 +669,12 @@ def generate(cls, root): return num_samples_map -@register_mock +# @register_mock def sbd(info, root, config): return SBDMockData.generate(root)[config.split] -@register_mock +# @register_mock def semeion(info, root, config): num_samples = 3 num_categories = len(info.categories) @@ -779,7 +787,7 @@ def generate(cls, root, *, year, trainval): return num_samples_map -@register_mock +# @register_mock def voc(info, root, config): trainval = config.split != "test" return VOCMockData.generate(root, year=config.year, trainval=trainval)[config.split] @@ -873,12 +881,12 @@ def generate(cls, root): return num_samples_map -@register_mock +# @register_mock def celeba(info, root, config): return CelebAMockData.generate(root)[config.split] -@register_mock +# @register_mock def dtd(info, root, config): data_folder = root / "dtd" @@ -926,7 +934,7 @@ def dtd(info, root, config): return num_samples_map[config] -@register_mock +# @register_mock def fer2013(info, root, config): num_samples = 5 if config.split == "train" else 3 @@ -951,7 +959,7 @@ def fer2013(info, root, config): return num_samples -@register_mock +# @register_mock def gtsrb(info, root, config): num_examples_per_class = 5 if config.split == "train" else 3 classes = ("00000", "00042", "00012") @@ -1021,7 +1029,7 @@ def _make_ann_file(path, num_examples, class_idx): return num_examples -@register_mock +# @register_mock def clevr(info, root, config): data_folder = root / "CLEVR_v1.0" @@ -1127,7 +1135,7 @@ def generate(self, root): return num_samples_map -@register_mock +# @register_mock def oxford_iiit_pet(info, root, config): return OxfordIIITPetMockData.generate(root)[config.split] @@ -1293,13 +1301,13 @@ def generate(cls, root): return num_samples_map -@register_mock +# @register_mock def cub200(info, root, config): num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root) return num_samples_map[config.split] -@register_mock +# @register_mock def svhn(info, root, config): import scipy.io as sio @@ -1319,7 +1327,7 @@ def svhn(info, root, config): return num_samples -@register_mock +# @register_mock def pcam(info, root, config): import h5py diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index eaa92094ad7..0ba042bcda5 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -6,9 +6,8 @@ import torch from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair -from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter from torch.utils.data.graph import traverse -from torchdata.datapipes.iter import IterDataPipe, Shuffler +from torchdata.datapipes.iter import Shuffler, ShardingFilter from torchvision.prototype import 
transforms, datasets from torchvision.prototype.utils._internal import sequence_to_str @@ -35,14 +34,24 @@ def test_coverage(): class TestCommon: + @pytest.mark.parametrize("name", datasets.list_datasets()) + def test_info(self, name): + try: + info = datasets.info(name) + except ValueError: + raise AssertionError("No info available.") from None + + if not (isinstance(info, dict) and all(isinstance(key, str) for key in info.keys())): + raise AssertionError("Info should be a dictionary with string keys.") + @parametrize_dataset_mocks(DATASET_MOCKS) def test_smoke(self, test_home, dataset_mock, config): dataset_mock.prepare(test_home, config) dataset = datasets.load(dataset_mock.name, **config) - if not isinstance(dataset, IterDataPipe): - raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.") + if not isinstance(dataset, datasets.utils.Dataset2): + raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.") @parametrize_dataset_mocks(DATASET_MOCKS) def test_sample(self, test_home, dataset_mock, config): @@ -67,24 +76,7 @@ def test_num_samples(self, test_home, dataset_mock, config): dataset = datasets.load(dataset_mock.name, **config) - num_samples = 0 - for _ in dataset: - num_samples += 1 - - assert num_samples == mock_info["num_samples"] - - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_decoding(self, test_home, dataset_mock, config): - dataset_mock.prepare(test_home, config) - - dataset = datasets.load(dataset_mock.name, **config) - - undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)} - if undecoded_features: - raise AssertionError( - f"The values of key(s) " - f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded." - ) + assert len(list(dataset)) == mock_info["num_samples"] @parametrize_dataset_mocks(DATASET_MOCKS) def test_no_vanilla_tensors(self, test_home, dataset_mock, config): @@ -107,6 +99,7 @@ def test_transformable(self, test_home, dataset_mock, config): next(iter(dataset.map(transforms.Identity()))) + @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237") @parametrize_dataset_mocks( DATASET_MOCKS, marks={ @@ -122,6 +115,7 @@ def test_traversable(self, test_home, dataset_mock, config): traverse(dataset) + @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237") @parametrize_dataset_mocks( DATASET_MOCKS, marks={ @@ -138,7 +132,6 @@ def scan(graph): yield from scan(sub_graph) dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **config) if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))): @@ -156,7 +149,8 @@ def test_save_load(self, test_home, dataset_mock, config): assert_samples_equal(torch.load(buffer), sample) -@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"]) +# FIXME: DATASET_MOCKS["qmnist"] +@parametrize_dataset_mocks({}) class TestQMNIST: def test_extra_label(self, test_home, dataset_mock, config): dataset_mock.prepare(test_home, config) @@ -176,12 +170,13 @@ def test_extra_label(self, test_home, dataset_mock, config): assert key in sample and isinstance(sample[key], type) -@parametrize_dataset_mocks(DATASET_MOCKS["gtsrb"]) +# FIXME: DATASET_MOCKS["gtsrb"] +@parametrize_dataset_mocks({}) class TestGTSRB: def test_label_matches_path(self, test_home, dataset_mock, config): # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path. 
# This test makes sure that they're both the same - if config.split != "train": + if config["split"] != "train": return dataset_mock.prepare(test_home, config) diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py index bf99e175d36..44c66e422f2 100644 --- a/torchvision/prototype/datasets/__init__.py +++ b/torchvision/prototype/datasets/__init__.py @@ -11,5 +11,6 @@ from ._home import home # Load this last, since some parts depend on the above being loaded first -from ._api import list_datasets, info, load # usort: skip +from ._api import list_datasets, info, load, register_info, register_dataset # usort: skip from ._folder import from_data_folder, from_image_folder +from ._builtin import * diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 13ee920cea2..8f8bb53deb4 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -1,39 +1,50 @@ -import os -from typing import Any, Dict, List +import pathlib +from typing import Any, Dict, List, Callable, Type, Optional, Union, TypeVar -from torch.utils.data import IterDataPipe from torchvision.prototype.datasets import home -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo +from torchvision.prototype.datasets.utils import Dataset2 from torchvision.prototype.utils._internal import add_suggestion -from . import _builtin -DATASETS: Dict[str, Dataset] = {} +T = TypeVar("T") +D = TypeVar("D", bound=Type[Dataset2]) +BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {} -def register(dataset: Dataset) -> None: - DATASETS[dataset.name] = dataset +def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: + def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: + BUILTIN_INFOS[name] = fn() + return fn -for name, obj in _builtin.__dict__.items(): - if not name.startswith("_") and isinstance(obj, type) and issubclass(obj, Dataset) and obj is not Dataset: - register(obj()) + return wrapper + + +BUILTIN_DATASETS = {} + + +def register_dataset(name: str) -> Callable[[D], D]: + def wrapper(dataset_cls: D) -> D: + BUILTIN_DATASETS[name] = dataset_cls + return dataset_cls + + return wrapper def list_datasets() -> List[str]: - return sorted(DATASETS.keys()) + return sorted(BUILTIN_DATASETS.keys()) -def find(name: str) -> Dataset: +def find(dct: Dict[str, T], name: str) -> T: name = name.lower() try: - return DATASETS[name] + return dct[name] except KeyError as error: raise ValueError( add_suggestion( f"Unknown dataset '{name}'.", word=name, - possibilities=DATASETS.keys(), + possibilities=dct.keys(), alternative_hint=lambda _: ( "You can use torchvision.datasets.list_datasets() to get a list of all available datasets." 
), @@ -41,19 +52,14 @@ def find(name: str) -> Dataset: ) from error -def info(name: str) -> DatasetInfo: - return find(name).info +def info(name: str) -> Dict[str, Any]: + return find(BUILTIN_INFOS, name) -def load( - name: str, - *, - skip_integrity_check: bool = False, - **options: Any, -) -> IterDataPipe[Dict[str, Any]]: - dataset = find(name) +def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset2: + dataset_cls = find(BUILTIN_DATASETS, name) - config = dataset.info.make_config(**options) - root = os.path.join(home(), dataset.name) + if root is None: + root = pathlib.Path(home()) / name - return dataset.load(root, config=config, skip_integrity_check=skip_integrity_check) + return dataset_cls(root, **config) diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 0d11b642c13..6f91d4c4a8d 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -1,16 +1,14 @@ -import functools import pathlib import re -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast, Union from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, Filter, Demultiplexer from torchdata.datapipes.iter import TarArchiveReader from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, DatasetInfo, OnlineResource, ManualDownloadResource, + Dataset2, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -21,9 +19,20 @@ read_mat, hint_sharding, hint_shuffling, + path_accessor, ) from torchvision.prototype.features import Label, EncodedImage -from torchvision.prototype.utils._internal import FrozenMapping + +from .._api import register_dataset, register_info + + +NAME = "imagenet" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) + return dict(categories=categories, wnids=wnids) class ImageNetResource(ManualDownloadResource): @@ -31,32 +40,18 @@ def __init__(self, **kwargs: Any) -> None: super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs) -class ImageNet(Dataset): - def _make_info(self) -> DatasetInfo: - name = "imagenet" - categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories")) - - return DatasetInfo( - name, - dependencies=("scipy",), - categories=categories, - homepage="https://www.image-net.org/", - valid_options=dict(split=("train", "val", "test")), - extra=dict( - wnid_to_category=FrozenMapping(zip(wnids, categories)), - category_to_wnid=FrozenMapping(zip(categories, wnids)), - sizes=FrozenMapping( - [ - (DatasetConfig(split="train"), 1_281_167), - (DatasetConfig(split="val"), 50_000), - (DatasetConfig(split="test"), 100_000), - ] - ), - ), - ) +@register_dataset(NAME) +class ImageNet(Dataset2): + def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None: + self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) - def supports_sharded(self) -> bool: - return True + info = _info() + categories, wnids = info["categories"], info["wnids"] + self._categories: List[str] = categories + self._wnids: List[str] = wnids + self._wnid_to_category = dict(zip(wnids, categories)) + + super().__init__(root) _IMAGES_CHECKSUMS = { "train": 
"b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb", @@ -64,15 +59,15 @@ def supports_sharded(self) -> bool: "test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4", } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - name = "test_v10102019" if config.split == "test" else config.split + def _resources(self) -> List[OnlineResource]: + name = "test_v10102019" if self._split == "test" else self._split images = ImageNetResource( file_name=f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name], ) resources: List[OnlineResource] = [images] - if config.split == "val": + if self._split == "val": devkit = ImageNetResource( file_name="ILSVRC2012_devkit_t12.tar.gz", sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953", @@ -81,19 +76,12 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: return resources - def num_samples(self, config: DatasetConfig) -> int: - return { - "train": 1_281_167, - "val": 50_000, - "test": 100_000, - }[config.split] - _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?Pn\d{8})_\d+[.]JPEG") def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]: path = pathlib.Path(data[0]) wnid = cast(Match[str], self._TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"] - label = Label.from_category(self.info.extra.wnid_to_category[wnid], categories=self.categories) + label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories) return (label, wnid), data def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]: @@ -105,6 +93,13 @@ def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]: "ILSVRC2012_validation_ground_truth.txt": 1, }.get(pathlib.Path(data[0]).name) + # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. 
For example, both n02012849 + # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment + _WNID_MAP = { + "n03126707": "construction crane", + "n03710721": "tank suit", + } + def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]: synsets = read_mat(data[1], squeeze_me=True)["synsets"] return [ @@ -114,21 +109,20 @@ def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tupl if num_children == 0 ] - def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: List[str]) -> str: - return wnids[int(imagenet_label) - 1] + def _imagenet_label_to_wnid(self, imagenet_label: str) -> str: + return self._wnids[int(imagenet_label) - 1] _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P\d{8})[.]JPEG") - def _val_test_image_key(self, data: Tuple[str, Any]) -> int: - path = pathlib.Path(data[0]) - return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr] + def _val_test_image_key(self, path: pathlib.Path) -> int: + return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name)["id"]) # type: ignore[index] def _prepare_val_data( self, data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]] ) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]: label_data, image_data = data _, wnid = label_data - label = Label.from_category(self.info.extra.wnid_to_category[wnid], categories=self.categories) + label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories) return (label, wnid), image_data def _prepare_sample( @@ -143,19 +137,17 @@ def _prepare_sample( image=EncodedImage.from_file(buffer), ) - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: - if config.split in {"train", "test"}: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: + if self._split in {"train", "test"}: dp = resource_dps[0] # the train archive is a tar of tars - if config.split == "train": + if self._split == "train": dp = TarArchiveReader(dp) dp = hint_sharding(dp) dp = hint_shuffling(dp) - dp = Mapper(dp, self._prepare_train_data if config.split == "train" else self._prepare_test_data) + dp = Mapper(dp, self._prepare_train_data if self._split == "train" else self._prepare_test_data) else: # config.split == "val": images_dp, devkit_dp = resource_dps @@ -167,7 +159,7 @@ def _make_datapipe( _, wnids = zip(*next(iter(meta_dp))) label_dp = LineReader(label_dp, decode=True, return_path=False) - label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids)) + label_dp = Mapper(label_dp, self._imagenet_label_to_wnid) label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1) label_dp = hint_sharding(label_dp) label_dp = hint_shuffling(label_dp) @@ -176,25 +168,25 @@ def _make_datapipe( label_dp, images_dp, key_fn=getitem(0), - ref_key_fn=self._val_test_image_key, + ref_key_fn=path_accessor(self._val_test_image_key), buffer_size=INFINITE_BUFFER_SIZE, ) dp = Mapper(dp, self._prepare_val_data) return Mapper(dp, self._prepare_sample) - # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. 
For example, both n02012849 - # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment - _WNID_MAP = { - "n03126707": "construction crane", - "n03710721": "tank suit", - } + def __len__(self) -> int: + return { + "train": 1_281_167, + "val": 50_000, + "test": 100_000, + }[self._split] - def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]: - config = self.info.make_config(split="val") - resources = self.resources(config) + def _generate_categories(self) -> List[Tuple[str, ...]]: + self._split = "val" + resources = self._resources() - devkit_dp = resources[1].load(root) + devkit_dp = resources[1].load(self._root) meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat")) meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids) diff --git a/torchvision/prototype/datasets/generate_category_files.py b/torchvision/prototype/datasets/generate_category_files.py index 3c2bf7e73cb..ac35eddb28b 100644 --- a/torchvision/prototype/datasets/generate_category_files.py +++ b/torchvision/prototype/datasets/generate_category_files.py @@ -2,25 +2,21 @@ import argparse import csv -import pathlib import sys from torchvision.prototype import datasets -from torchvision.prototype.datasets._api import find from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR def main(*names, force=False): - home = pathlib.Path(datasets.home()) - for name in names: path = BUILTIN_DIR / f"{name}.categories" if path.exists() and not force: continue - dataset = find(name) + dataset = datasets.load(name) try: - categories = dataset._generate_categories(home / name) + categories = dataset._generate_categories() except NotImplementedError: continue @@ -55,7 +51,7 @@ def parse_args(argv=None): if __name__ == "__main__": - args = parse_args() + args = parse_args(["-f", "imagenet"]) try: main(*args.names, force=args.force) diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index 9423b65a8ee..a16a839b594 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -1,4 +1,4 @@ from . 
import _internal # usort: skip -from ._dataset import DatasetConfig, DatasetInfo, Dataset +from ._dataset import DatasetConfig, DatasetInfo, Dataset, Dataset2 from ._query import SampleQuery from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 5ee7c5ccc60..7200f00fd02 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -4,9 +4,10 @@ import itertools import os import pathlib -from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection +from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection, Iterator from torch.utils.data import IterDataPipe +from torchvision.datasets.utils import verify_str_arg from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str from .._home import use_sharded_dataset @@ -181,3 +182,40 @@ def load( def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]: raise NotImplementedError + + +class Dataset2(IterDataPipe[Dict[str, Any]], abc.ABC): + @staticmethod + def _verify_str_arg( + value: str, + arg: Optional[str] = None, + valid_values: Optional[Collection[str]] = None, + *, + custom_msg: Optional[str] = None, + ) -> str: + return verify_str_arg(value, arg, valid_values, custom_msg=custom_msg) + + def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None: + self._root = pathlib.Path(root).expanduser().resolve() + resources = [ + resource.load(self._root, skip_integrity_check=skip_integrity_check) for resource in self._resources() + ] + self._dp = self._datapipe(resources) + + def __iter__(self) -> Iterator[Dict[str, Any]]: + yield from self._dp + + @abc.abstractmethod + def _resources(self) -> List[OnlineResource]: + pass + + @abc.abstractmethod + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: + pass + + @abc.abstractmethod + def __len__(self) -> int: + pass + + def _generate_categories(self) -> Sequence[Union[str, Sequence[str]]]: + raise NotImplementedError From c5f6c11c1b2b0b2e58647cf0af0df93860e25dae Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 31 Mar 2022 16:10:45 +0200 Subject: [PATCH 02/30] fix imagenet --- .../prototype/datasets/_builtin/imagenet.py | 25 +++++++++++++------ 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 220c5edf17a..fb507af01b0 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -1,3 +1,5 @@ +import enum +import functools import pathlib import re from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast, Union @@ -21,7 +23,6 @@ from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, BUILTIN_DIR, - path_comparator, getitem, read_mat, hint_sharding, @@ -32,7 +33,6 @@ from .._api import register_dataset, register_info - NAME = "imagenet" @@ -47,6 +47,11 @@ def __init__(self, **kwargs: Any) -> None: super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs) +class ImageNetDemux(enum.IntEnum): + META = 0 + LABEL = 1 + + @register_dataset(NAME) class ImageNet(Dataset2): def __init__(self, root: Union[str, 
pathlib.Path], *, split: str = "train") -> None: @@ -96,8 +101,8 @@ def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[st def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]: return { - "meta.mat": 0, - "ILSVRC2012_validation_ground_truth.txt": 1, + "meta.mat": ImageNetDemux.META, + "ILSVRC2012_validation_ground_truth.txt": ImageNetDemux.LABEL, }.get(pathlib.Path(data[0]).name) # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849 @@ -116,8 +121,8 @@ def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tupl if num_children == 0 ] - def _imagenet_label_to_wnid(self, imagenet_label: str) -> str: - return self._wnids[int(imagenet_label) - 1] + def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: Tuple[str, ...]) -> str: + return wnids[int(imagenet_label) - 1] _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P\d{8})[.]JPEG") @@ -166,7 +171,8 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, _, wnids = zip(*next(iter(meta_dp))) label_dp = LineReader(label_dp, decode=True, return_path=False) - label_dp = Mapper(label_dp, self._imagenet_label_to_wnid) + # We cannot use self._wnids here, since we use a different order than the dataset + label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids)) label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1) label_dp = hint_shuffling(label_dp) label_dp = hint_sharding(label_dp) @@ -189,12 +195,15 @@ def __len__(self) -> int: "test": 100_000, }[self._split] + def _filter_meta(self, data: Tuple[str, Any]) -> bool: + return self._classifiy_devkit(data) == ImageNetDemux.META + def _generate_categories(self) -> List[Tuple[str, ...]]: self._split = "val" resources = self._resources() devkit_dp = resources[1].load(self._root) - meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat")) + meta_dp = Filter(devkit_dp, self._filter_meta) meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids) categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp))) From 9f12ef4eac796cda9630cead727b274376c0143a Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Apr 2022 15:11:26 +0200 Subject: [PATCH 03/30] fix prototype datasets data loading tests (#5711) * reenable serialization test * cleanup * fix dill test * trigger CI * patch DILL_AVAILABLE for pickle serialization * revert CI changes * remove dill test and traversable test * add data loader test * parametrize over only_datapipe * draw one sample rather than exhaust data loader * cleanup * trigger CI --- test/test_prototype_builtin_datasets.py | 27 ++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index f8dc3a0542b..8a929b6907c 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -7,6 +7,7 @@ import torch from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair +from torch.utils.data import DataLoader from torch.utils.data.graph import traverse from torch.utils.data.graph_settings import get_all_graph_pipes from torchdata.datapipes.iter import Shuffler, ShardingFilter @@ -109,19 +110,39 @@ def test_transformable(self, test_home, dataset_mock, config): next(iter(dataset.map(transforms.Identity()))) - 
@pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237") + @pytest.mark.parametrize("only_datapipe", [False, True]) @parametrize_dataset_mocks(DATASET_MOCKS) - def test_serializable(self, test_home, dataset_mock, config): + def test_traversable(self, test_home, dataset_mock, config, only_datapipe): dataset_mock.prepare(test_home, config) + dataset = datasets.load(dataset_mock.name, **config) + traverse(dataset, only_datapipe=only_datapipe) + + @parametrize_dataset_mocks(DATASET_MOCKS) + def test_serializable(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) dataset = datasets.load(dataset_mock.name, **config) pickle.dumps(dataset) + @pytest.mark.parametrize("num_workers", [0, 1]) + @parametrize_dataset_mocks(DATASET_MOCKS) + def test_data_loader(self, test_home, dataset_mock, config, num_workers): + dataset_mock.prepare(test_home, config) + dataset = datasets.load(dataset_mock.name, **config) + + dl = DataLoader( + dataset, + batch_size=2, + num_workers=num_workers, + collate_fn=lambda batch: batch, + ) + + next(iter(dl)) + # TODO: we need to enforce not only that both a Shuffler and a ShardingFilter are part of the datapipe, but also # that the Shuffler comes before the ShardingFilter. Early commits in https://github.com/pytorch/vision/pull/5680 # contain a custom test for that, but we opted to wait for a potential solution / test from torchdata for now. - @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237") @parametrize_dataset_mocks(DATASET_MOCKS) @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter)) def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type): From aca416435193013aaf139c462ae3997dfe30252f Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 5 Apr 2022 17:02:43 +0200 Subject: [PATCH 04/30] migrate VOC prototype dataset (#5743) * migrate VOC prototype dataset * cleanup * revert unrelated mock data changes * remove categories annotations * move properties to constructor * readd homepage --- test/builtin_dataset_mocks.py | 21 +- .../prototype/datasets/_builtin/imagenet.py | 4 +- .../prototype/datasets/_builtin/voc.py | 186 +++++++++++------- 3 files changed, 129 insertions(+), 82 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index c4f51463e34..ad979b6bd84 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -792,10 +792,23 @@ def generate(cls, root, *, year, trainval): return num_samples_map -# @register_mock -def voc(info, root, config): - trainval = config.split != "test" - return VOCMockData.generate(root, year=config.year, trainval=trainval)[config.split] +@register_mock( + configs=[ + *combinations_grid( + split=("train", "val", "trainval"), + year=("2007", "2008", "2009", "2010", "2011", "2012"), + task=("detection", "segmentation"), + ), + *combinations_grid( + split=("test",), + year=("2007",), + task=("detection", "segmentation"), + ), + ], +) +def voc(root, config): + trainval = config["split"] != "test" + return VOCMockData.generate(root, year=config["year"], trainval=trainval)[config["split"]] class CelebAMockData: diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index fb507af01b0..638878d5ec3 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -59,8 +59,8 @@ def __init__(self, root: Union[str, pathlib.Path], *, split: str = 
"train") -> N info = _info() categories, wnids = info["categories"], info["wnids"] - self._categories: List[str] = categories - self._wnids: List[str] = wnids + self._categories = categories + self._wnids = wnids self._wnid_to_category = dict(zip(wnids, categories)) super().__init__(root) diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index 5c1d3f8c3a3..d000bdbe0e7 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -1,6 +1,7 @@ +import enum import functools import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Callable +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, cast, Union from xml.etree import ElementTree from torchdata.datapipes.iter import ( @@ -12,13 +13,7 @@ LineReader, ) from torchvision.datasets import VOCDetection -from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, - HttpResource, - OnlineResource, -) +from torchvision.prototype.datasets.utils import DatasetInfo, OnlineResource, HttpResource, Dataset2 from torchvision.prototype.datasets.utils._internal import ( path_accessor, getitem, @@ -26,34 +21,50 @@ path_comparator, hint_sharding, hint_shuffling, + BUILTIN_DIR, ) from torchvision.prototype.features import BoundingBox, Label, EncodedImage +from .._api import register_dataset, register_info + +NAME = "voc" + +CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=CATEGORIES) + -class VOCDatasetInfo(DatasetInfo): - def __init__(self, *args: Any, **kwargs: Any): - super().__init__(*args, **kwargs) - self._configs = tuple(config for config in self._configs if config.split != "test" or config.year == "2007") +@register_dataset(NAME) +class VOC(Dataset2): + """ + - **homepage**: http://host.robots.ox.ac.uk/pascal/VOC/ + """ - def make_config(self, **options: Any) -> DatasetConfig: - config = super().make_config(**options) - if config.split == "test" and config.year != "2007": + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + year: str = "2012", + task: str = "detection", + **kwargs: Any, + ) -> None: + self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012")) + if split == "test" and year != "2007": raise ValueError("`split='test'` is only available for `year='2007'`") + else: + self._split = self._verify_str_arg(split, "split", ("train", "val", "trainval", "test")) + self._task = self._verify_str_arg(task, "task", ("detection", "segmentation")) - return config + self._anns_folder = "Annotations" if task == "detection" else "SegmentationClass" + self._split_folder = "Main" if task == "detection" else "Segmentation" + self._categories = _info()["categories"] -class VOC(Dataset): - def _make_info(self) -> DatasetInfo: - return VOCDatasetInfo( - "voc", - homepage="http://host.robots.ox.ac.uk/pascal/VOC/", - valid_options=dict( - split=("train", "val", "trainval", "test"), - year=("2012", "2007", "2008", "2009", "2010", "2011"), - task=("detection", "segmentation"), - ), - ) + super().__init__(root, **kwargs) _TRAIN_VAL_ARCHIVES = { "2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"), @@ -67,31 +78,27 @@ def _make_info(self) -> DatasetInfo: "2007": ("VOCtest_06-Nov-2007.tar", 
"6836888e2e01dca84577a849d339fa4f73e1e4f135d312430c4856b5609b4892") } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - file_name, sha256 = (self._TEST_ARCHIVES if config.split == "test" else self._TRAIN_VAL_ARCHIVES)[config.year] - archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{config.year}/{file_name}", sha256=sha256) + def _resources(self) -> List[OnlineResource]: + file_name, sha256 = (self._TEST_ARCHIVES if self._split == "test" else self._TRAIN_VAL_ARCHIVES)[self._year] + archive = HttpResource(f"http://host.robots.ox.ac.uk/pascal/VOC/voc{self._year}/{file_name}", sha256=sha256) return [archive] - _ANNS_FOLDER = dict( - detection="Annotations", - segmentation="SegmentationClass", - ) - _SPLIT_FOLDER = dict( - detection="Main", - segmentation="Segmentation", - ) - def _is_in_folder(self, data: Tuple[str, Any], *, name: str, depth: int = 1) -> bool: path = pathlib.Path(data[0]) return name in path.parent.parts[-depth:] - def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> Optional[int]: + class _Demux(enum.IntEnum): + SPLIT = 0 + IMAGES = 1 + ANNS = 2 + + def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: if self._is_in_folder(data, name="ImageSets", depth=2): - return 0 + return self._Demux.SPLIT elif self._is_in_folder(data, name="JPEGImages"): - return 1 - elif self._is_in_folder(data, name=self._ANNS_FOLDER[config.task]): - return 2 + return self._Demux.IMAGES + elif self._is_in_folder(data, name=self._anns_folder): + return self._Demux.ANNS else: return None @@ -111,7 +118,7 @@ def _prepare_detection_ann(self, buffer: BinaryIO) -> Dict[str, Any]: image_size=cast(Tuple[int, int], tuple(int(anns["size"][dim]) for dim in ("height", "width"))), ), labels=Label( - [self.categories.index(instance["name"]) for instance in instances], categories=self.categories + [self._categories.index(instance["name"]) for instance in instances], categories=self._categories ), ) @@ -121,8 +128,6 @@ def _prepare_segmentation_ann(self, buffer: BinaryIO) -> Dict[str, Any]: def _prepare_sample( self, data: Tuple[Tuple[Tuple[str, str], Tuple[str, BinaryIO]], Tuple[str, BinaryIO]], - *, - prepare_ann_fn: Callable[[BinaryIO], Dict[str, Any]], ) -> Dict[str, Any]: split_and_image_data, ann_data = data _, image_data = split_and_image_data @@ -130,29 +135,24 @@ def _prepare_sample( ann_path, ann_buffer = ann_data return dict( - prepare_ann_fn(ann_buffer), + (self._prepare_detection_ann if self._task == "detection" else self._prepare_segmentation_ann)(ann_buffer), image_path=image_path, image=EncodedImage.from_file(image_buffer), ann_path=ann_path, ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] split_dp, images_dp, anns_dp = Demultiplexer( archive_dp, 3, - functools.partial(self._classify_archive, config=config), + self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE, ) - split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._SPLIT_FOLDER[config.task])) - split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt")) + split_dp = Filter(split_dp, functools.partial(self._is_in_folder, name=self._split_folder)) + split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt")) split_dp = LineReader(split_dp, decode=True) split_dp = hint_shuffling(split_dp) 
split_dp = hint_sharding(split_dp) @@ -166,25 +166,59 @@ def _make_datapipe( ref_key_fn=path_accessor("stem"), buffer_size=INFINITE_BUFFER_SIZE, ) - return Mapper( - dp, - functools.partial( - self._prepare_sample, - prepare_ann_fn=self._prepare_detection_ann - if config.task == "detection" - else self._prepare_segmentation_ann, - ), - ) - - def _filter_detection_anns(self, data: Tuple[str, Any], *, config: DatasetConfig) -> bool: - return self._classify_archive(data, config=config) == 2 - - def _generate_categories(self, root: pathlib.Path) -> List[str]: - config = self.info.make_config(task="detection") - - resource = self.resources(config)[0] - dp = resource.load(pathlib.Path(root) / self.name) - dp = Filter(dp, self._filter_detection_anns, fn_kwargs=dict(config=config)) + return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return { + ("train", "2007", "detection"): 2_501, + ("train", "2007", "segmentation"): 209, + ("train", "2008", "detection"): 2_111, + ("train", "2008", "segmentation"): 511, + ("train", "2009", "detection"): 3_473, + ("train", "2009", "segmentation"): 749, + ("train", "2010", "detection"): 4_998, + ("train", "2010", "segmentation"): 964, + ("train", "2011", "detection"): 5_717, + ("train", "2011", "segmentation"): 1_112, + ("train", "2012", "detection"): 5_717, + ("train", "2012", "segmentation"): 1_464, + ("val", "2007", "detection"): 2_510, + ("val", "2007", "segmentation"): 213, + ("val", "2008", "detection"): 2_221, + ("val", "2008", "segmentation"): 512, + ("val", "2009", "detection"): 3_581, + ("val", "2009", "segmentation"): 750, + ("val", "2010", "detection"): 5_105, + ("val", "2010", "segmentation"): 964, + ("val", "2011", "detection"): 5_823, + ("val", "2011", "segmentation"): 1_111, + ("val", "2012", "detection"): 5_823, + ("val", "2012", "segmentation"): 1_449, + ("trainval", "2007", "detection"): 5_011, + ("trainval", "2007", "segmentation"): 422, + ("trainval", "2008", "detection"): 4_332, + ("trainval", "2008", "segmentation"): 1_023, + ("trainval", "2009", "detection"): 7_054, + ("trainval", "2009", "segmentation"): 1_499, + ("trainval", "2010", "detection"): 10_103, + ("trainval", "2010", "segmentation"): 1_928, + ("trainval", "2011", "detection"): 11_540, + ("trainval", "2011", "segmentation"): 2_223, + ("trainval", "2012", "detection"): 11_540, + ("trainval", "2012", "segmentation"): 2_913, + ("test", "2007", "detection"): 4_952, + ("test", "2007", "segmentation"): 210, + }[(self._split, self._year, self._task)] + + def _filter_anns(self, data: Tuple[str, Any]) -> bool: + return self._classify_archive(data) == self._Demux.ANNS + + def _generate_categories(self) -> List[str]: + self._task = "detection" + resources = self._resources() + + archive_dp = resources[0].load(self._root) + dp = Filter(archive_dp, self._filter_detection_anns) dp = Mapper(dp, self._parse_detection_ann, input_col=1) return sorted({instance["name"] for _, anns in dp for instance in anns["object"]}) From dead87dbe7f36714c2fd82d9fb90ffbe25555e6b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 6 Apr 2022 11:01:34 +0200 Subject: [PATCH 05/30] migrate CIFAR prototype datasets (#5751) --- test/builtin_dataset_mocks.py | 12 +-- .../prototype/datasets/_builtin/cifar.py | 99 +++++++++++-------- 2 files changed, 66 insertions(+), 45 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index ad979b6bd84..759f356b55a 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -334,8 +334,8 @@ def generate( 
make_tar(root, name, folder, compression="gz") -# @register_mock -def cifar10(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "test"))) +def cifar10(root, config): train_files = [f"data_batch_{idx}" for idx in range(1, 6)] test_files = ["test_batch"] @@ -349,11 +349,11 @@ def cifar10(info, root, config): labels_key="labels", ) - return len(train_files if config.split == "train" else test_files) + return len(train_files if config["split"] == "train" else test_files) -# @register_mock -def cifar100(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "test"))) +def cifar100(root, config): train_files = ["train"] test_files = ["test"] @@ -367,7 +367,7 @@ def cifar100(info, root, config): labels_key="fine_labels", ) - return len(train_files if config.split == "train" else test_files) + return len(train_files if config["split"] == "train" else test_files) # @register_mock diff --git a/torchvision/prototype/datasets/_builtin/cifar.py b/torchvision/prototype/datasets/_builtin/cifar.py index 3d7acefb903..9274aa543d4 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -1,9 +1,8 @@ import abc -import functools import io import pathlib import pickle -from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO +from typing import Any, Dict, List, Optional, Tuple, Iterator, cast, BinaryIO, Union import numpy as np from torchdata.datapipes.iter import ( @@ -11,20 +10,12 @@ Filter, Mapper, ) -from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, - HttpResource, - OnlineResource, -) -from torchvision.prototype.datasets.utils._internal import ( - hint_shuffling, - path_comparator, - hint_sharding, -) +from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import hint_shuffling, path_comparator, hint_sharding, BUILTIN_DIR from torchvision.prototype.features import Label, Image +from .._api import register_dataset, register_info + class CifarFileReader(IterDataPipe[Tuple[np.ndarray, int]]): def __init__(self, datapipe: IterDataPipe[Dict[str, Any]], *, labels_key: str) -> None: @@ -38,25 +29,29 @@ def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]: yield from iter(zip(image_arrays, category_idcs)) -class _CifarBase(Dataset): +class _CifarBase(Dataset2): _FILE_NAME: str _SHA256: str _LABELS_KEY: str _META_FILE_NAME: str _CATEGORIES_KEY: str + # _categories: List[str] + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "test")) + super().__init__(root, skip_integrity_check=skip_integrity_check) @abc.abstractmethod - def _is_data_file(self, data: Tuple[str, BinaryIO], *, split: str) -> Optional[int]: + def _is_data_file(self, data: Tuple[str, BinaryIO]) -> Optional[int]: pass - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - type(self).__name__.lower(), - homepage="https://www.cs.toronto.edu/~kriz/cifar.html", - valid_options=dict(split=("train", "test")), - ) - - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: return [ HttpResource( f"https://www.cs.toronto.edu/~kriz/{self._FILE_NAME}", @@ -72,52 +67,78 @@ def _prepare_sample(self, data: Tuple[np.ndarray, int]) -> Dict[str, Any]: image_array, 
category_idx = data return dict( image=Image(image_array), - label=Label(category_idx, categories=self.categories), + label=Label(category_idx, categories=self._categories), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] - dp = Filter(dp, functools.partial(self._is_data_file, split=config.split)) + dp = Filter(dp, self._is_data_file) dp = Mapper(dp, self._unpickle) dp = CifarFileReader(dp, labels_key=self._LABELS_KEY) dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) - def _generate_categories(self, root: pathlib.Path) -> List[str]: - resources = self.resources(self.default_config) + def __len__(self) -> int: + return 50_000 if self._split == "train" else 10_000 + + def _generate_categories(self) -> List[str]: + resources = self._resources() - dp = resources[0].load(root) + dp = resources[0].load(self._root) dp = Filter(dp, path_comparator("name", self._META_FILE_NAME)) dp = Mapper(dp, self._unpickle) return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY]) +CIFAR10_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "cifar10.categories")) + + +@register_info("cifar10") +def _cifar10_info() -> Dict[str, Any]: + return dict(categories=CIFAR10_CATEGORIES) + + +@register_dataset("cifar10") class Cifar10(_CifarBase): + """ + - **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html + """ + _FILE_NAME = "cifar-10-python.tar.gz" _SHA256 = "6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce" _LABELS_KEY = "labels" _META_FILE_NAME = "batches.meta" _CATEGORIES_KEY = "label_names" + _categories = _cifar10_info()["categories"] - def _is_data_file(self, data: Tuple[str, Any], *, split: str) -> bool: + def _is_data_file(self, data: Tuple[str, Any]) -> bool: path = pathlib.Path(data[0]) - return path.name.startswith("data" if split == "train" else "test") + return path.name.startswith("data" if self._split == "train" else "test") +CIFAR100_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "cifar100.categories")) + + +@register_info("cifar100") +def _cifar100_info() -> Dict[str, Any]: + return dict(categories=CIFAR10_CATEGORIES) + + +@register_dataset("cifar100") class Cifar100(_CifarBase): + """ + - **homepage**: https://www.cs.toronto.edu/~kriz/cifar.html + """ + _FILE_NAME = "cifar-100-python.tar.gz" _SHA256 = "85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7" _LABELS_KEY = "fine_labels" _META_FILE_NAME = "meta" _CATEGORIES_KEY = "fine_label_names" + _categories = _cifar100_info()["categories"] - def _is_data_file(self, data: Tuple[str, Any], *, split: str) -> bool: + def _is_data_file(self, data: Tuple[str, Any]) -> bool: path = pathlib.Path(data[0]) - return path.name == split + return path.name == self._split From 6a0592fd5099e6a701f7be733ba746a3c4e98c7e Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 6 Apr 2022 11:02:13 +0200 Subject: [PATCH 06/30] migrate country211 prototype dataset (#5753) --- test/builtin_dataset_mocks.py | 11 +-- .../prototype/datasets/_builtin/country211.py | 72 ++++++++++++------- 2 files changed, 50 insertions(+), 33 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 759f356b55a..cbe54cd4b75 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -904,14 +904,9 @@ def celeba(info, root, config): 
return CelebAMockData.generate(root)[config.split] -# @register_mock -def country211(info, root, config): - split_name_mapper = { - "train": "train", - "val": "valid", - "test": "test", - } - split_folder = pathlib.Path(root, "country211", split_name_mapper[config["split"]]) +@register_mock(configs=combinations_grid(split=("train", "val", "test"))) +def country211(root, config): + split_folder = pathlib.Path(root, "country211", "valid" if config["split"] == "val" else config["split"]) split_folder.mkdir(parents=True, exist_ok=True) num_examples = { diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py index 0b4dc306734..ae0564b224b 100644 --- a/torchvision/prototype/datasets/_builtin/country211.py +++ b/torchvision/prototype/datasets/_builtin/country211.py @@ -1,21 +1,44 @@ import pathlib -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter -from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import path_comparator, hint_sharding, hint_shuffling +from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import path_comparator, hint_sharding, hint_shuffling, BUILTIN_DIR from torchvision.prototype.features import EncodedImage, Label +from .._api import register_dataset, register_info -class Country211(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "country211", - homepage="https://github.com/openai/CLIP/blob/main/data/country211.md", - valid_options=dict(split=("train", "val", "test")), - ) +NAME = "country211" + +CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=CATEGORIES) + + +@register_dataset(NAME) +class Country211(Dataset2): + """ + - **homepage**: https://github.com/openai/CLIP/blob/main/data/country211.md + """ - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "val", "test")) + self._split_folder_name = "valid" if split == "val" else split + + self._categories = _info()["categories"] + + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: return [ HttpResource( "https://openaipublic.azureedge.net/clip/data/country211.tgz", @@ -23,17 +46,11 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: ) ] - _SPLIT_NAME_MAPPER = { - "train": "train", - "val": "valid", - "test": "test", - } - def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]: path, buffer = data category = pathlib.Path(path).parent.name return dict( - label=Label.from_category(category, categories=self.categories), + label=Label.from_category(category, categories=self._categories), path=path, image=EncodedImage.from_file(buffer), ) @@ -41,16 +58,21 @@ def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]: def _filter_split(self, data: Tuple[str, Any], *, split: str) -> bool: return pathlib.Path(data[0]).parent.parent.name == split - def _make_datapipe( - self, resource_dps: 
List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] - dp = Filter(dp, path_comparator("parent.parent.name", self._SPLIT_NAME_MAPPER[config.split])) + dp = Filter(dp, path_comparator("parent.parent.name", self._split_folder_name)) dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) - def _generate_categories(self, root: pathlib.Path) -> List[str]: - resources = self.resources(self.default_config) - dp = resources[0].load(root) + def __len__(self) -> int: + return { + "train": 31_650, + "val": 10_550, + "test": 21_100, + }[self._split] + + def _generate_categories(self) -> List[str]: + resources = self.resources() + dp = resources[0].load(self.root) return sorted({pathlib.Path(path).parent.name for path, _ in dp}) From 2ed549dcfc4333b47dced1f3820dd6148ca846aa Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 6 Apr 2022 11:02:28 +0200 Subject: [PATCH 07/30] migrate CLEVR prototype datsaet (#5752) --- test/builtin_dataset_mocks.py | 6 +- .../prototype/datasets/_builtin/clevr.py | 56 ++++++++++--------- 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index cbe54cd4b75..0210a4dacec 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1108,8 +1108,8 @@ def _make_ann_file(path, num_examples, class_idx): return num_examples -# @register_mock -def clevr(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "val", "test"))) +def clevr(root, config): data_folder = root / "CLEVR_v1.0" num_samples_map = { @@ -1150,7 +1150,7 @@ def clevr(info, root, config): make_zip(root, f"{data_folder.name}.zip", data_folder) - return num_samples_map[config.split] + return num_samples_map[config["split"]] class OxfordIIITPetMockData: diff --git a/torchvision/prototype/datasets/_builtin/clevr.py b/torchvision/prototype/datasets/_builtin/clevr.py index dd08a257a5b..9d322de084c 100644 --- a/torchvision/prototype/datasets/_builtin/clevr.py +++ b/torchvision/prototype/datasets/_builtin/clevr.py @@ -1,14 +1,8 @@ import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher -from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, - HttpResource, - OnlineResource, -) +from torchvision.prototype.datasets.utils import Dataset2, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, hint_sharding, @@ -19,16 +13,30 @@ ) from torchvision.prototype.features import Label, EncodedImage +from .._api import register_dataset, register_info + +NAME = "clevr" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict() -class CLEVR(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "clevr", - homepage="https://cs.stanford.edu/people/jcjohns/clevr/", - valid_options=dict(split=("train", "val", "test")), - ) - def resources(self, config: DatasetConfig) -> List[OnlineResource]: +@register_dataset(NAME) +class CLEVR(Dataset2): + """ + - **homepage**: https://cs.stanford.edu/people/jcjohns/clevr/ + """ + + def __init__( + self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: 
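# Usage sketch (assumed, based on the registry decorators above and the test helper in
# this patch series): once a dataset is decorated with @register_dataset, it should be
# reachable through `datasets.load(name, root=..., **options)`, and its length should
# match the __len__ added above.  The root path below is only a placeholder.
from torchvision.prototype import datasets

dataset = datasets.load("country211", root="~/datasets", split="val")
assert len(dataset) == 10_550        # value hard-coded in __len__ above
sample = next(iter(dataset))         # dict with "image", "label" and "path" keys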
bool = False + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "val", "test")) + + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: archive = HttpResource( "https://dl.fbaipublicfiles.com/clevr/CLEVR_v1.0.zip", sha256="5cd61cf1096ed20944df93c9adb31e74d189b8459a94f54ba00090e5c59936d1", @@ -61,12 +69,7 @@ def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Optional[Dict[str, A label=Label(len(scenes_data["objects"])) if scenes_data else None, ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] images_dp, scenes_dp = Demultiplexer( archive_dp, @@ -76,12 +79,12 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, ) - images_dp = Filter(images_dp, path_comparator("parent.name", config.split)) + images_dp = Filter(images_dp, path_comparator("parent.name", self._split)) images_dp = hint_shuffling(images_dp) images_dp = hint_sharding(images_dp) - if config.split != "test": - scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{config.split}_scenes.json")) + if self._split != "test": + scenes_dp = Filter(scenes_dp, path_comparator("name", f"CLEVR_{self._split}_scenes.json")) scenes_dp = JsonParser(scenes_dp) scenes_dp = Mapper(scenes_dp, getitem(1, "scenes")) scenes_dp = UnBatcher(scenes_dp) @@ -97,3 +100,6 @@ def _make_datapipe( dp = Mapper(images_dp, self._add_empty_anns) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return 70_000 if self._split == "train" else 15_000 From 42bc682241c129b1596123bd737f0a07daeb9809 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 6 Apr 2022 14:58:46 +0200 Subject: [PATCH 08/30] migrate coco prototype (#5473) * migrate coco prototype * revert unrelated change * add kwargs to super constructor call * remove unneeded changes * fix docstring position * make kwargs explicit * add dependencies to docstring * fix missing dependency message --- test/builtin_dataset_mocks.py | 12 +- .../prototype/datasets/_builtin/coco.py | 145 ++++++++++-------- .../prototype/datasets/_builtin/imagenet.py | 14 +- .../prototype/datasets/_builtin/voc.py | 4 +- .../datasets/generate_category_files.py | 2 +- .../prototype/datasets/utils/_dataset.py | 13 +- 6 files changed, 118 insertions(+), 72 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 0210a4dacec..eef7275f967 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -600,9 +600,15 @@ def generate( return num_samples -# @register_mock -def coco(info, root, config): - return CocoMockData.generate(root, year=config.year, num_samples=5) +@register_mock( + configs=combinations_grid( + split=("train", "val"), + year=("2017", "2014"), + annotations=("instances", "captions", None), + ) +) +def coco(root, config): + return CocoMockData.generate(root, year=config["year"], num_samples=5) class SBDMockData: diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index 1005c7b3130..75896a8db08 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -1,8 +1,8 @@ -import functools import pathlib import re from collections import OrderedDict -from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO +from collections 
import defaultdict +from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union import torch from torchdata.datapipes.iter import ( @@ -16,11 +16,10 @@ UnBatcher, ) from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, DatasetInfo, HttpResource, OnlineResource, + Dataset2, ) from torchvision.prototype.datasets.utils._internal import ( MappingIterator, @@ -32,27 +31,51 @@ hint_shuffling, ) from torchvision.prototype.features import BoundingBox, Label, _Feature, EncodedImage -from torchvision.prototype.utils._internal import FrozenMapping - - -class Coco(Dataset): - def _make_info(self) -> DatasetInfo: - name = "coco" - categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories")) - - return DatasetInfo( - name, - dependencies=("pycocotools",), - categories=categories, - homepage="https://cocodataset.org/", - valid_options=dict( - split=("train", "val"), - year=("2017", "2014"), - annotations=(*self._ANN_DECODERS.keys(), None), - ), - extra=dict(category_to_super_category=FrozenMapping(zip(categories, super_categories))), + +from .._api import register_dataset, register_info + + +NAME = "coco" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) + return dict(categories=categories, super_categories=super_categories) + + +@register_dataset(NAME) +class Coco(Dataset2): + """ + - **homepage**: https://cocodataset.org/ + - **dependencies**: + - _ + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + year: str = "2017", + annotations: Optional[str] = "instances", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "val"}) + self._year = self._verify_str_arg(year, "year", {"2017", "2014"}) + self._annotations = ( + self._verify_str_arg(annotations, "annotations", self._ANN_DECODERS.keys()) + if annotations is not None + else None ) + info = _info() + categories, super_categories = info["categories"], info["super_categories"] + self._categories = categories + self._category_to_super_category = dict(zip(categories, super_categories)) + + super().__init__(root, dependencies=("pycocotools",), skip_integrity_check=skip_integrity_check) + _IMAGE_URL_BASE = "http://images.cocodataset.org/zips" _IMAGES_CHECKSUMS = { @@ -69,14 +92,14 @@ def _make_info(self) -> DatasetInfo: "2017": "113a836d90195ee1f884e704da6304dfaaecff1f023f49b6ca93c4aaae470268", } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: images = HttpResource( - f"{self._IMAGE_URL_BASE}/{config.split}{config.year}.zip", - sha256=self._IMAGES_CHECKSUMS[(config.year, config.split)], + f"{self._IMAGE_URL_BASE}/{self._split}{self._year}.zip", + sha256=self._IMAGES_CHECKSUMS[(self._year, self._split)], ) meta = HttpResource( - f"{self._META_URL_BASE}/annotations_trainval{config.year}.zip", - sha256=self._META_CHECKSUMS[config.year], + f"{self._META_URL_BASE}/annotations_trainval{self._year}.zip", + sha256=self._META_CHECKSUMS[self._year], ) return [images, meta] @@ -110,10 +133,8 @@ def _decode_instances_anns(self, anns: List[Dict[str, Any]], image_meta: Dict[st format="xywh", image_size=image_size, ), - labels=Label(labels, categories=self.categories), - super_categories=[ - self.info.extra.category_to_super_category[self.info.categories[label]] for label in labels - ], + 
labels=Label(labels, categories=self._categories), + super_categories=[self._category_to_super_category[self._categories[label]] for label in labels], ann_ids=[ann["id"] for ann in anns], ) @@ -134,9 +155,14 @@ def _decode_captions_ann(self, anns: List[Dict[str, Any]], image_meta: Dict[str, fr"(?P({'|'.join(_ANN_DECODERS.keys())}))_(?P[a-zA-Z]+)(?P\d+)[.]json" ) - def _filter_meta_files(self, data: Tuple[str, Any], *, split: str, year: str, annotations: str) -> bool: + def _filter_meta_files(self, data: Tuple[str, Any]) -> bool: match = self._META_FILE_PATTERN.match(pathlib.Path(data[0]).name) - return bool(match and match["split"] == split and match["year"] == year and match["annotations"] == annotations) + return bool( + match + and match["split"] == self._split + and match["year"] == self._year + and match["annotations"] == self._annotations + ) def _classify_meta(self, data: Tuple[str, Any]) -> Optional[int]: key, _ = data @@ -157,38 +183,26 @@ def _prepare_image(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]: def _prepare_sample( self, data: Tuple[Tuple[List[Dict[str, Any]], Dict[str, Any]], Tuple[str, BinaryIO]], - *, - annotations: str, ) -> Dict[str, Any]: ann_data, image_data = data anns, image_meta = ann_data sample = self._prepare_image(image_data) + # this method is only called if we have annotations + annotations = cast(str, self._annotations) sample.update(self._ANN_DECODERS[annotations](self, anns, image_meta)) return sample - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: images_dp, meta_dp = resource_dps - if config.annotations is None: + if self._annotations is None: dp = hint_shuffling(images_dp) dp = hint_sharding(dp) + dp = hint_shuffling(dp) return Mapper(dp, self._prepare_image) - meta_dp = Filter( - meta_dp, - functools.partial( - self._filter_meta_files, - split=config.split, - year=config.year, - annotations=config.annotations, - ), - ) + meta_dp = Filter(meta_dp, self._filter_meta_files) meta_dp = JsonParser(meta_dp) meta_dp = Mapper(meta_dp, getitem(1)) meta_dp: IterDataPipe[Dict[str, Dict[str, Any]]] = MappingIterator(meta_dp) @@ -216,7 +230,6 @@ def _make_datapipe( ref_key_fn=getitem("id"), buffer_size=INFINITE_BUFFER_SIZE, ) - dp = IterKeyZipper( anns_dp, images_dp, @@ -224,18 +237,24 @@ def _make_datapipe( ref_key_fn=path_accessor("name"), buffer_size=INFINITE_BUFFER_SIZE, ) + return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return { + ("train", "2017"): defaultdict(lambda: 118_287, instances=117_266), + ("train", "2014"): defaultdict(lambda: 82_783, instances=82_081), + ("val", "2017"): defaultdict(lambda: 5_000, instances=4_952), + ("val", "2014"): defaultdict(lambda: 40_504, instances=40_137), + }[(self._split, self._year)][ + self._annotations # type: ignore[index] + ] - return Mapper(dp, functools.partial(self._prepare_sample, annotations=config.annotations)) - - def _generate_categories(self, root: pathlib.Path) -> Tuple[Tuple[str, str]]: - config = self.default_config - resources = self.resources(config) + def _generate_categories(self) -> Tuple[Tuple[str, str]]: + self._annotations = "instances" + resources = self._resources() - dp = resources[1].load(root) - dp = Filter( - dp, - functools.partial(self._filter_meta_files, split=config.split, year=config.year, annotations="instances"), - ) + dp = resources[1].load(self._root) + dp = Filter(dp, 
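# Small sketch of the __len__ lookup above: the outer dict is keyed by (split, year),
# and the inner defaultdict returns the plain image count for `annotations=None` or
# "captions", while "instances" overrides it with the number of annotated images.
# Numbers are the ones hard-coded above.
from collections import defaultdict

num_samples = {
    ("train", "2017"): defaultdict(lambda: 118_287, instances=117_266),
}
assert num_samples[("train", "2017")]["instances"] == 117_266
assert num_samples[("train", "2017")]["captions"] == 118_287
assert num_samples[("train", "2017")][None] == 118_287   # annotations=None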
self._filter_meta_files) dp = JsonParser(dp) _, meta = next(iter(dp)) diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 638878d5ec3..56accca02b4 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -54,7 +54,17 @@ class ImageNetDemux(enum.IntEnum): @register_dataset(NAME) class ImageNet(Dataset2): - def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None: + """ + - **homepage**: https://www.image-net.org/ + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) info = _info() @@ -63,7 +73,7 @@ def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> N self._wnids = wnids self._wnid_to_category = dict(zip(wnids, categories)) - super().__init__(root) + super().__init__(root, skip_integrity_check=skip_integrity_check) _IMAGES_CHECKSUMS = { "train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb", diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index d000bdbe0e7..91b82794e27 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -50,7 +50,7 @@ def __init__( split: str = "train", year: str = "2012", task: str = "detection", - **kwargs: Any, + skip_integrity_check: bool = False, ) -> None: self._year = self._verify_str_arg(year, "year", ("2007", "2008", "2009", "2010", "2011", "2012")) if split == "test" and year != "2007": @@ -64,7 +64,7 @@ def __init__( self._categories = _info()["categories"] - super().__init__(root, **kwargs) + super().__init__(root, skip_integrity_check=skip_integrity_check) _TRAIN_VAL_ARCHIVES = { "2007": ("VOCtrainval_06-Nov-2007.tar", "7d8cd951101b0957ddfd7a530bdc8a94f06121cfc1e511bb5937e973020c7508"), diff --git a/torchvision/prototype/datasets/generate_category_files.py b/torchvision/prototype/datasets/generate_category_files.py index ac35eddb28b..6d4e854fe34 100644 --- a/torchvision/prototype/datasets/generate_category_files.py +++ b/torchvision/prototype/datasets/generate_category_files.py @@ -51,7 +51,7 @@ def parse_args(argv=None): if __name__ == "__main__": - args = parse_args(["-f", "imagenet"]) + args = parse_args() try: main(*args.names, force=args.force) diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 69180040194..a6ec05c3ff4 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -196,7 +196,18 @@ def _verify_str_arg( ) -> str: return verify_str_arg(value, arg, valid_values, custom_msg=custom_msg) - def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None: + def __init__( + self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False, dependencies: Collection[str] = () + ) -> None: + for dependency in dependencies: + try: + importlib.import_module(dependency) + except ModuleNotFoundError: + raise ModuleNotFoundError( + f"{type(self).__name__}() depends on the third-party package '{dependency}'. " + f"Please install it, for example with `pip install {dependency}`." 
+ ) from None + self._root = pathlib.Path(root).expanduser().resolve() resources = [ resource.load(self._root, skip_integrity_check=skip_integrity_check) for resource in self._resources() From 27104fe98be3d1f6730f1418a6e410aec3b191b1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 14:00:18 +0100 Subject: [PATCH 09/30] Migrate PCAM prototype dataset (#5745) * Port PCAM * skip_integrity_check * Update torchvision/prototype/datasets/_builtin/pcam.py Co-authored-by: Philip Meier * Address comments Co-authored-by: Philip Meier --- test/builtin_dataset_mocks.py | 8 +-- .../prototype/datasets/_builtin/pcam.py | 54 ++++++++++++------- 2 files changed, 39 insertions(+), 23 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index eef7275f967..cc8568154ed 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1430,13 +1430,13 @@ def svhn(info, root, config): return num_samples -# @register_mock -def pcam(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "val", "test"))) +def pcam(root, config): import h5py - num_images = {"train": 2, "test": 3, "val": 4}[config.split] + num_images = {"train": 2, "test": 3, "val": 4}[config["split"]] - split = "valid" if config.split == "val" else config.split + split = "valid" if config["split"] == "val" else config["split"] images_io = io.BytesIO() with h5py.File(images_io, "w") as f: diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py index 3d7b9547a76..1ae94da5665 100644 --- a/torchvision/prototype/datasets/_builtin/pcam.py +++ b/torchvision/prototype/datasets/_builtin/pcam.py @@ -1,13 +1,13 @@ import io +import pathlib from collections import namedtuple -from typing import Any, Dict, List, Optional, Tuple, Iterator +from typing import Any, Dict, List, Optional, Tuple, Iterator, Union +from unicodedata import category from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper from torchvision.prototype import features from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, + Dataset2, OnlineResource, GDriveResource, ) @@ -17,6 +17,11 @@ ) from torchvision.prototype.features import Label +from .._api import register_dataset, register_info + + +NAME = "pcam" + class PCAMH5Reader(IterDataPipe[Tuple[str, io.IOBase]]): def __init__( @@ -40,15 +45,25 @@ def __iter__(self) -> Iterator[Tuple[str, io.IOBase]]: _Resource = namedtuple("_Resource", ("file_name", "gdrive_id", "sha256")) -class PCAM(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "pcam", - homepage="https://github.com/basveeling/pcam", - categories=2, - valid_options=dict(split=("train", "test", "val")), - dependencies=["h5py"], - ) +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=["0", "1"]) + + +@register_dataset(NAME) +class PCAM(Dataset2): + # TODO write proper docstring + """PCAM Dataset + + homepage="https://github.com/basveeling/pcam" + """ + + def __init__( + self, root: Union[str, pathlib.Path], split: str = "train", *, skip_integrity_check: bool = False + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("h5py",)) _RESOURCES = { "train": ( @@ -89,10 +104,10 @@ def _make_info(self) -> DatasetInfo: ), } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + 
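# Self-contained sketch of the dependency handling added to Dataset2.__init__ above:
# every name in `dependencies` must be importable, otherwise a ModuleNotFoundError with
# an install hint is raised.  The helper below only mirrors that loop for illustration
# and is not the real Dataset2.
import importlib

def check_dependencies(dataset_name, dependencies=()):
    for dependency in dependencies:
        try:
            importlib.import_module(dependency)
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                f"{dataset_name}() depends on the third-party package '{dependency}'. "
                f"Please install it, for example with `pip install {dependency}`."
            ) from None

check_dependencies("PCAM", dependencies=("h5py",))   # raises only if h5py is missing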
def _resources(self) -> List[OnlineResource]: return [ # = [images resource, targets resource] GDriveResource(file_name=file_name, id=gdrive_id, sha256=sha256, preprocess="decompress") - for file_name, gdrive_id, sha256 in self._RESOURCES[config.split] + for file_name, gdrive_id, sha256 in self._RESOURCES[self._split] ] def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]: @@ -100,12 +115,10 @@ def _prepare_sample(self, data: Tuple[Any, Any]) -> Dict[str, Any]: return { "image": features.Image(image.transpose(2, 0, 1)), - "label": Label(target.item()), + "label": Label(target.item(), categories=self._categories), } - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: images_dp, targets_dp = resource_dps @@ -116,3 +129,6 @@ def _make_datapipe( dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self): + return 262_144 if self._split == "train" else 32_768 From 291be31c448eed4300b79c8cd37ba32c548845fe Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 14:03:04 +0100 Subject: [PATCH 10/30] Migrate DTD prototype dataset (#5757) * Migrate DTD prototype dataset * Docstring * Apply suggestions from code review Co-authored-by: Philip Meier Co-authored-by: Philip Meier --- test/builtin_dataset_mocks.py | 8 +-- .../prototype/datasets/_builtin/dtd.py | 72 ++++++++++++------- 2 files changed, 51 insertions(+), 29 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index cc8568154ed..5c9e657c2c7 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -971,8 +971,8 @@ def food101(info, root, config): return num_samples_map[config.split] -# @register_mock -def dtd(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "val", "test"), fold=(1, 4, 10))) +def dtd(root, config): data_folder = root / "dtd" num_images_per_class = 3 @@ -1012,11 +1012,11 @@ def dtd(info, root, config): with open(meta_folder / f"{split}{fold}.txt", "w") as file: file.write("\n".join(image_ids_in_config) + "\n") - num_samples_map[info.make_config(split=split, fold=str(fold))] = len(image_ids_in_config) + num_samples_map[(split, fold)] = len(image_ids_in_config) make_tar(root, "dtd-r1.0.1.tar.gz", data_folder, compression="gz") - return num_samples_map[config] + return num_samples_map[config["split"], config["fold"]] # @register_mock diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py index 682fed2d9c2..a5de1359e4e 100644 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -1,11 +1,10 @@ import enum import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, LineReader, CSVParser from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, + Dataset2, DatasetInfo, HttpResource, OnlineResource, @@ -14,11 +13,17 @@ INFINITE_BUFFER_SIZE, hint_sharding, path_comparator, + BUILTIN_DIR, getitem, hint_shuffling, ) from torchvision.prototype.features import Label, EncodedImage +from .._api import register_dataset, register_info + + +NAME = "dtd" + class DTDDemux(enum.IntEnum): SPLIT = 0 @@ -26,18 +31,37 @@ 
class DTDDemux(enum.IntEnum): IMAGES = 2 -class DTD(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "dtd", - homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/", - valid_options=dict( - split=("train", "test", "val"), - fold=tuple(str(fold) for fold in range(1, 11)), - ), - ) +@register_info(NAME) +def _info() -> Dict[str, Any]: + categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories") + categories = [c[0] for c in categories] + return dict(categories=categories) + + +@register_dataset(NAME) +class DTD(Dataset2): + """DTD Dataset. + homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/", + """ + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + fold: int = 1, + skip_validation_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) + + if not (1 <= fold <= 10): + raise ValueError(f"The fold parameter should be an integer in [1, 10]. Got {fold}") + self._fold = fold + + self._categories = _info()["categories"] + + super().__init__(root, skip_integrity_check=skip_validation_check) - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: archive = HttpResource( "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz", sha256="e42855a52a4950a3b59612834602aa253914755c95b0cff9ead6d07395f8e205", @@ -71,24 +95,19 @@ def _prepare_sample(self, data: Tuple[Tuple[str, List[str]], Tuple[str, BinaryIO return dict( joint_categories={category for category in joint_categories if category}, - label=Label.from_category(category, categories=self.categories), + label=Label.from_category(category, categories=self._categories), path=path, image=EncodedImage.from_file(buffer), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] splits_dp, joint_categories_dp, images_dp = Demultiplexer( archive_dp, 3, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE ) - splits_dp = Filter(splits_dp, path_comparator("name", f"{config.split}{config.fold}.txt")) + splits_dp = Filter(splits_dp, path_comparator("name", f"{self._split}{self._fold}.txt")) splits_dp = LineReader(splits_dp, decode=True, return_path=False) splits_dp = hint_shuffling(splits_dp) splits_dp = hint_sharding(splits_dp) @@ -114,10 +133,13 @@ def _make_datapipe( def _filter_images(self, data: Tuple[str, Any]) -> bool: return self._classify_archive(data) == DTDDemux.IMAGES - def _generate_categories(self, root: pathlib.Path) -> List[str]: - resources = self.resources(self.default_config) + def _generate_categories(self) -> List[str]: + resources = self.resources() - dp = resources[0].load(root) + dp = resources[0].load(self._root) dp = Filter(dp, self._filter_images) return sorted({pathlib.Path(path).parent.name for path, _ in dp}) + + def __len__(self) -> int: + return 1_880 # All splits have the same length From 217616b4dc62f684c0901a81682ac3b917d5f987 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 14:03:34 +0100 Subject: [PATCH 11/30] Migrate GTSRB prototype dataset (#5746) * Migrate GTSRB prototype dataset * ufmt * Address comments * Apparently mypy doesn't know that __len__ returns ints. How cute. * why is the CI not triggered?? 
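# Minimal, self-contained sketch of the Demultiplexer pattern used in the DTD datapipe
# above: a classifier function routes every archive entry to one of N child datapipes,
# and drop_none=True silently discards entries that belong to neither.  The file names
# below are made up.
from torchdata.datapipes.iter import IterableWrapper, Demultiplexer

def classify(path):
    if path.endswith(".txt"):
        return 0            # split lists
    if path.endswith(".jpg"):
        return 1            # images
    return None             # everything else is dropped

entries = IterableWrapper(["labels/train1.txt", "images/banded/banded_0002.jpg", "README"])
splits_dp, images_dp = Demultiplexer(entries, 2, classify, drop_none=True, buffer_size=10)
assert list(splits_dp) == ["labels/train1.txt"]
assert list(images_dp) == ["images/banded/banded_0002.jpg"]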
* Update torchvision/prototype/datasets/_builtin/gtsrb.py Co-authored-by: Philip Meier Co-authored-by: Philip Meier --- test/builtin_dataset_mocks.py | 6 +- .../prototype/datasets/_builtin/gtsrb.py | 55 ++++++++++++------- 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 5c9e657c2c7..0d2d8f5f76e 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1044,9 +1044,9 @@ def fer2013(info, root, config): return num_samples -# @register_mock -def gtsrb(info, root, config): - num_examples_per_class = 5 if config.split == "train" else 3 +@register_mock(configs=combinations_grid(split=("train", "test"))) +def gtsrb(root, config): + num_examples_per_class = 5 if config["split"] == "train" else 3 classes = ("00000", "00042", "00012") num_examples = num_examples_per_class * len(classes) diff --git a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py index c08d8947292..fa29f3be780 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -1,11 +1,9 @@ import pathlib -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, Zipper, Demultiplexer from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, + Dataset2, OnlineResource, HttpResource, ) @@ -17,15 +15,31 @@ ) from torchvision.prototype.features import Label, BoundingBox, EncodedImage +from .._api import register_dataset, register_info + +NAME = "gtsrb" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict( + categories=[f"{label:05d}" for label in range(43)], + ) -class GTSRB(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "gtsrb", - homepage="https://benchmark.ini.rub.de", - categories=[f"{label:05d}" for label in range(43)], - valid_options=dict(split=("train", "test")), - ) + +@register_dataset(NAME) +class GTSRB(Dataset2): + """GTSRB Dataset + + homepage="https://benchmark.ini.rub.de" + """ + + def __init__( + self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "test"}) + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check) _URL_ROOT = "https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/" _URLS = { @@ -39,10 +53,10 @@ def _make_info(self) -> DatasetInfo: "test_ground_truth": "f94e5a7614d75845c74c04ddb26b8796b9e483f43541dd95dd5b726504e16d6d", } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - rsrcs: List[OnlineResource] = [HttpResource(self._URLS[config.split], sha256=self._CHECKSUMS[config.split])] + def _resources(self) -> List[OnlineResource]: + rsrcs: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUMS[self._split])] - if config.split == "test": + if self._split == "test": rsrcs.append( HttpResource( self._URLS["test_ground_truth"], @@ -74,14 +88,12 @@ def _prepare_sample(self, data: Tuple[Tuple[str, Any], Dict[str, Any]]) -> Dict[ return { "path": path, "image": EncodedImage.from_file(buffer), - "label": Label(label, categories=self.categories), + "label": Label(label, categories=self._categories), "bounding_box": bounding_box, } - def _make_datapipe( - self, 
resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: - if config.split == "train": + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: + if self._split == "train": images_dp, ann_dp = Demultiplexer( resource_dps[0], 2, self._classify_train_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE ) @@ -98,3 +110,6 @@ def _make_datapipe( dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return 26_640 if self._split == "train" else 12_630 From 2612c4cae8c31a65b49363280c4f0275c167d804 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 6 Apr 2022 15:04:11 +0200 Subject: [PATCH 12/30] migrate CelebA prototype dataset (#5750) * migrate CelebA prototype dataset * inline split_id --- test/builtin_dataset_mocks.py | 6 +- .../prototype/datasets/_builtin/celeba.py | 73 ++++++++++++------- 2 files changed, 48 insertions(+), 31 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 0d2d8f5f76e..f0ee2370aa1 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -905,9 +905,9 @@ def generate(cls, root): return num_samples_map -# @register_mock -def celeba(info, root, config): - return CelebAMockData.generate(root)[config.split] +@register_mock(configs=combinations_grid(split=("train", "val", "test"))) +def celeba(root, config): + return CelebAMockData.generate(root)[config["split"]] @register_mock(configs=combinations_grid(split=("train", "val", "test"))) diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index 854c705b746..17a42082f3f 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -1,6 +1,6 @@ import csv -import functools -from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO +import pathlib +from typing import Any, Dict, List, Optional, Tuple, Iterator, Sequence, BinaryIO, Union from torchdata.datapipes.iter import ( IterDataPipe, @@ -10,9 +10,7 @@ IterKeyZipper, ) from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, + Dataset2, GDriveResource, OnlineResource, ) @@ -25,6 +23,7 @@ ) from torchvision.prototype.features import EncodedImage, _Feature, Label, BoundingBox +from .._api import register_dataset, register_info csv.register_dialect("celeba", delimiter=" ", skipinitialspace=True) @@ -60,15 +59,32 @@ def __iter__(self) -> Iterator[Tuple[str, Dict[str, str]]]: yield line.pop("image_id"), line -class CelebA(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "celeba", - homepage="https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html", - valid_options=dict(split=("train", "val", "test")), - ) +NAME = "celeba" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict() + - def resources(self, config: DatasetConfig) -> List[OnlineResource]: +@register_dataset(NAME) +class CelebA(Dataset2): + """ + - **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "val", "test")) + + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: splits = GDriveResource( "0B7EVK8r0v71pY0NSMzRuSXJEVkk", 
sha256="fc955bcb3ef8fbdf7d5640d9a8693a8431b5f2ee291a5c1449a1549e7e073fe7", @@ -101,14 +117,13 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: ) return [splits, images, identities, attributes, bounding_boxes, landmarks] - _SPLIT_ID_TO_NAME = { - "0": "train", - "1": "val", - "2": "test", - } - - def _filter_split(self, data: Tuple[str, Dict[str, str]], *, split: str) -> bool: - return self._SPLIT_ID_TO_NAME[data[1]["split_id"]] == split + def _filter_split(self, data: Tuple[str, Dict[str, str]]) -> bool: + split_id = { + "train": "0", + "val": "1", + "test": "2", + }[self._split] + return data[1]["split_id"] == split_id def _prepare_sample( self, @@ -145,16 +160,11 @@ def _prepare_sample( }, ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: splits_dp, images_dp, identities_dp, attributes_dp, bounding_boxes_dp, landmarks_dp = resource_dps splits_dp = CelebACSVParser(splits_dp, fieldnames=("image_id", "split_id")) - splits_dp = Filter(splits_dp, functools.partial(self._filter_split, split=config.split)) + splits_dp = Filter(splits_dp, self._filter_split) splits_dp = hint_shuffling(splits_dp) splits_dp = hint_sharding(splits_dp) @@ -186,3 +196,10 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, ) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return { + "train": 162_770, + "val": 19_867, + "test": 19_962, + }[self._split] From 6de6ec4a8cf496cef321f0e7a865307728d7c9c4 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 14:04:52 +0100 Subject: [PATCH 13/30] Migrate Food101 prototype dataset (#5758) * Migrate Food101 dataset * Added length * Update torchvision/prototype/datasets/_builtin/food101.py Co-authored-by: Philip Meier Co-authored-by: Philip Meier --- test/builtin_dataset_mocks.py | 6 +- .../prototype/datasets/_builtin/food101.py | 57 ++++++++++++------- 2 files changed, 39 insertions(+), 24 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index f0ee2370aa1..c362b53981f 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -933,8 +933,8 @@ def country211(root, config): return num_examples * len(classes) -# @register_mock -def food101(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "test"))) +def food101(root, config): data_folder = root / "food-101" num_images_per_class = 3 @@ -968,7 +968,7 @@ def food101(info, root, config): make_tar(root, f"{data_folder.name}.tar.gz", compression="gz") - return num_samples_map[config.split] + return num_samples_map[config["split"]] @register_mock(configs=combinations_grid(split=("train", "val", "test"), fold=(1, 4, 10))) diff --git a/torchvision/prototype/datasets/_builtin/food101.py b/torchvision/prototype/datasets/_builtin/food101.py index cb720f137d9..36b2acca4d0 100644 --- a/torchvision/prototype/datasets/_builtin/food101.py +++ b/torchvision/prototype/datasets/_builtin/food101.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Tuple, List, Dict, Optional, BinaryIO +from typing import Any, Tuple, List, Dict, Optional, BinaryIO, Union from torchdata.datapipes.iter import ( IterDataPipe, @@ -9,9 +9,10 @@ Demultiplexer, IterKeyZipper, ) -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import 
Dataset2, DatasetInfo, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_shuffling, + BUILTIN_DIR, hint_sharding, path_comparator, getitem, @@ -19,16 +20,32 @@ ) from torchvision.prototype.features import Label, EncodedImage +from .._api import register_dataset, register_info + + +NAME = "food101" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories") + categories = [c[0] for c in categories] + return dict(categories=categories) -class Food101(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "food101", - homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101", - valid_options=dict(split=("train", "test")), - ) - def resources(self, config: DatasetConfig) -> List[OnlineResource]: +@register_dataset(NAME) +class Food101(Dataset2): + """Food 101 dataset + homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101", + """ + + def __init__(self, root: Union[str, Path], *, split: str = "train", skip_integrity_check: bool = False) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "test"}) + self._categories = _info()["categories"] + + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: return [ HttpResource( url="http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz", @@ -49,7 +66,7 @@ def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: def _prepare_sample(self, data: Tuple[str, Tuple[str, BinaryIO]]) -> Dict[str, Any]: id, (path, buffer) = data return dict( - label=Label.from_category(id.split("/", 1)[0], categories=self.categories), + label=Label.from_category(id.split("/", 1)[0], categories=self._categories), path=path, image=EncodedImage.from_file(buffer), ) @@ -58,17 +75,12 @@ def _image_key(self, data: Tuple[str, Any]) -> str: path = Path(data[0]) return path.relative_to(path.parents[1]).with_suffix("").as_posix() - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] images_dp, split_dp = Demultiplexer( archive_dp, 2, self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE ) - split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt")) + split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt")) split_dp = LineReader(split_dp, decode=True, return_path=False) split_dp = hint_sharding(split_dp) split_dp = hint_shuffling(split_dp) @@ -83,9 +95,12 @@ def _make_datapipe( return Mapper(dp, self._prepare_sample) - def _generate_categories(self, root: Path) -> List[str]: - resources = self.resources(self.default_config) - dp = resources[0].load(root) + def _generate_categories(self) -> List[str]: + resources = self.resources() + dp = resources[0].load(self._root) dp = Filter(dp, path_comparator("name", "classes.txt")) dp = LineReader(dp, decode=True, return_path=False) return list(dp) + + def __len__(self) -> int: + return 75_750 if self._split == "train" else 25_250 From ebe9006ab3a41fcee75580c715a15c1789a6537c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 14:05:46 +0100 Subject: [PATCH 14/30] Migrate Fer2013 prototype dataset (#5759) * Migrate Fer2013 prototype dataset * Update torchvision/prototype/datasets/_builtin/fer2013.py Co-authored-by: Philip Meier 
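# Quick sketch of the _image_key helper above: it reduces an image path to the
# "<class>/<image id>" string used to match entries from the train/test split lists.
# The path below is a made-up example in the food-101 archive layout.
from pathlib import Path

path = Path("food-101/images/apple_pie/1005649.jpg")
key = path.relative_to(path.parents[1]).with_suffix("").as_posix()
assert key == "apple_pie/1005649"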
Co-authored-by: Philip Meier --- test/builtin_dataset_mocks.py | 13 +++-- .../prototype/datasets/_builtin/fer2013.py | 57 +++++++++++-------- 2 files changed, 41 insertions(+), 29 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index c362b53981f..f0ff8ef17cb 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1019,13 +1019,14 @@ def dtd(root, config): return num_samples_map[config["split"], config["fold"]] -# @register_mock -def fer2013(info, root, config): - num_samples = 5 if config.split == "train" else 3 +@register_mock(configs=combinations_grid(split=("train", "test"))) +def fer2013(root, config): + split = config["split"] + num_samples = 5 if split == "train" else 3 - path = root / f"{config.split}.csv" + path = root / f"{split}.csv" with open(path, "w", newline="") as file: - field_names = ["emotion"] if config.split == "train" else [] + field_names = ["emotion"] if split == "train" else [] field_names.append("pixels") file.write(",".join(field_names) + "\n") @@ -1035,7 +1036,7 @@ def fer2013(info, root, config): rowdict = { "pixels": " ".join([str(int(pixel)) for pixel in torch.randint(256, (48 * 48,), dtype=torch.uint8)]) } - if config.split == "train": + if split == "train": rowdict["emotion"] = int(torch.randint(7, ())) writer.writerow(rowdict) diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py index a5bfa681d02..ca30b78e609 100644 --- a/torchvision/prototype/datasets/_builtin/fer2013.py +++ b/torchvision/prototype/datasets/_builtin/fer2013.py @@ -1,11 +1,10 @@ -from typing import Any, Dict, List, cast +import pathlib +from typing import Any, Dict, List, cast, Union import torch from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, + Dataset2, OnlineResource, KaggleDownloadResource, ) @@ -15,26 +14,40 @@ ) from torchvision.prototype.features import Label, Image +from .._api import register_dataset, register_info + +NAME = "fer2013" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral")) -class FER2013(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "fer2013", - homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge", - categories=("angry", "disgust", "fear", "happy", "sad", "surprise", "neutral"), - valid_options=dict(split=("train", "test")), - ) + +@register_dataset(NAME) +class FER2013(Dataset2): + """FER 2013 Dataset + homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge" + """ + + def __init__( + self, root: Union[str, pathlib.Path], *, split: str = "train", skip_integrity_check: bool = False + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "test"}) + self._categories = _info()["categories"] + + super().__init__(root, skip_integrity_check=skip_integrity_check) _CHECKSUMS = { "train": "a2b7c9360cc0b38d21187e5eece01c2799fce5426cdeecf746889cc96cda2d10", "test": "dec8dfe8021e30cd6704b85ec813042b4a5d99d81cb55e023291a94104f575c3", } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: archive = KaggleDownloadResource( - cast(str, self.info.homepage), - file_name=f"{config.split}.csv.zip", - 
sha256=self._CHECKSUMS[config.split], + "https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge", + file_name=f"{self._split}.csv.zip", + sha256=self._CHECKSUMS[self._split], ) return [archive] @@ -43,17 +56,15 @@ def _prepare_sample(self, data: Dict[str, Any]) -> Dict[str, Any]: return dict( image=Image(torch.tensor([int(idx) for idx in data["pixels"].split()], dtype=torch.uint8).reshape(48, 48)), - label=Label(int(label_id), categories=self.categories) if label_id is not None else None, + label=Label(int(label_id), categories=self._categories) if label_id is not None else None, ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = CSVDictParser(dp) dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return 28_709 if self._split == "train" else 3_589 From 8194b178cedbd570803c43a7f66cc9ea8d3286ca Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 14:06:02 +0100 Subject: [PATCH 15/30] Migrate EuroSAT prototype dataset (#5760) --- test/builtin_dataset_mocks.py | 4 +- .../prototype/datasets/_builtin/eurosat.py | 60 ++++++++++++------- 2 files changed, 39 insertions(+), 25 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index f0ff8ef17cb..cc246d10691 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1393,8 +1393,8 @@ def cub200(info, root, config): return num_samples_map[config.split] -# @register_mock -def eurosat(info, root, config): +@register_mock(configs=[dict()]) +def eurosat(root, config): data_folder = root / "2750" data_folder.mkdir(parents=True) diff --git a/torchvision/prototype/datasets/_builtin/eurosat.py b/torchvision/prototype/datasets/_builtin/eurosat.py index 336f35de968..00d6a04f320 100644 --- a/torchvision/prototype/datasets/_builtin/eurosat.py +++ b/torchvision/prototype/datasets/_builtin/eurosat.py @@ -1,31 +1,44 @@ import pathlib -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper -from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset2, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import EncodedImage, Label +from .._api import register_dataset, register_info -class EuroSAT(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "eurosat", - homepage="https://github.com/phelber/eurosat", - categories=( - "AnnualCrop", - "Forest", - "HerbaceousVegetation", - "Highway", - "Industrial," "Pasture", - "PermanentCrop", - "Residential", - "River", - "SeaLake", - ), +NAME = "eurosat" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict( + categories=( + "AnnualCrop", + "Forest", + "HerbaceousVegetation", + "Highway", + "Industrial," "Pasture", + "PermanentCrop", + "Residential", + "River", + "SeaLake", ) + ) + - def resources(self, config: DatasetConfig) -> List[OnlineResource]: +@register_dataset(NAME) +class EuroSAT(Dataset2): + """EuroSAT Dataset. 
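# Small sketch of the FER2013 row decoding above: the CSV "pixels" column is a
# whitespace-separated string of 48*48 grayscale values that is reshaped into an image
# tensor, and "emotion" is only present in the train split.  The row below is made up.
import torch

row = {"emotion": "3", "pixels": " ".join(["128"] * (48 * 48))}

image = torch.tensor([int(p) for p in row["pixels"].split()], dtype=torch.uint8).reshape(48, 48)
label_id = row.get("emotion")

assert image.shape == (48, 48)
assert label_id == "3"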
+ homepage="https://github.com/phelber/eurosat", + """ + + def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None: + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: return [ HttpResource( "https://madm.dfki.de/files/sentinel/EuroSAT.zip", @@ -37,15 +50,16 @@ def _prepare_sample(self, data: Tuple[str, Any]) -> Dict[str, Any]: path, buffer = data category = pathlib.Path(path).parent.name return dict( - label=Label.from_category(category, categories=self.categories), + label=Label.from_category(category, categories=self._categories), path=path, image=EncodedImage.from_file(buffer), ) - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return 27_000 From 4c9cbab25f4bbc2cf53d0ba22b0dbb34c04ef1d0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 14:06:34 +0100 Subject: [PATCH 16/30] Migrate Semeion prototype dataset (#5761) --- test/builtin_dataset_mocks.py | 6 +-- .../prototype/datasets/_builtin/semeion.py | 50 +++++++++++-------- 2 files changed, 33 insertions(+), 23 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index cc246d10691..b4cbadcf6bf 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -685,10 +685,10 @@ def sbd(info, root, config): return SBDMockData.generate(root)[config.split] -# @register_mock -def semeion(info, root, config): +@register_mock(configs=[dict()]) +def semeion(root, config): num_samples = 3 - num_categories = len(info.categories) + num_categories = 10 images = torch.rand(num_samples, 256) labels = one_hot(torch.randint(num_categories, size=(num_samples,)), num_classes=num_categories) diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py index fb64c051d6c..e3a802d3cee 100644 --- a/torchvision/prototype/datasets/_builtin/semeion.py +++ b/torchvision/prototype/datasets/_builtin/semeion.py @@ -1,31 +1,43 @@ -from typing import Any, Dict, List, Tuple +import pathlib +from typing import Any, Dict, List, Tuple, Union import torch +from pytest import skip from torchdata.datapipes.iter import ( IterDataPipe, Mapper, CSVParser, ) from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, + Dataset2, HttpResource, OnlineResource, ) from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import Image, OneHotLabel +from .._api import register_dataset, register_info + +NAME = "semeion" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=[str(i) for i in range(10)]) -class SEMEION(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "semeion", - categories=10, - homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit", - ) - def resources(self, config: DatasetConfig) -> List[OnlineResource]: +@register_dataset(NAME) +class SEMEION(Dataset2): + """Semeion dataset + homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit", + """ + + def __init__(self, root: Union[str, pathlib.Path], *, 
skip_integrity_check: bool = False) -> None: + + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: data = HttpResource( "http://archive.ics.uci.edu/ml/machine-learning-databases/semeion/semeion.data", sha256="f43228ae3da5ea6a3c95069d53450b86166770e3b719dcc333182128fe08d4b1", @@ -36,18 +48,16 @@ def _prepare_sample(self, data: Tuple[str, ...]) -> Dict[str, Any]: image_data, label_data = data[:256], data[256:-1] return dict( - image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.uint8).reshape(16, 16)), - label=OneHotLabel([int(label) for label in label_data], categories=self.categories), + image=Image(torch.tensor([float(pixel) for pixel in image_data], dtype=torch.float).reshape(16, 16)), + label=OneHotLabel([int(label) for label in label_data], categories=self._categories), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = CSVParser(dp, delimiter=" ") dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return 1_593 From 5cd572237e75513765328ec15839c3beda8c7789 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 6 Apr 2022 15:07:06 +0200 Subject: [PATCH 17/30] migrate caltech prototype datasets (#5749) * migrate caltech prototype datasets * resolve third party dependencies --- test/builtin_dataset_mocks.py | 41 +++--- .../prototype/datasets/_builtin/caltech.py | 117 ++++++++++++------ 2 files changed, 101 insertions(+), 57 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index b4cbadcf6bf..9cd5b96f2e7 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -370,8 +370,8 @@ def cifar100(root, config): return len(train_files if config["split"] == "train" else test_files) -# @register_mock -def caltech101(info, root, config): +@register_mock(configs=[dict()]) +def caltech101(root, config): def create_ann_file(root, name): import scipy.io @@ -390,15 +390,17 @@ def create_ann_folder(root, name, file_name_fn, num_examples): images_root = root / "101_ObjectCategories" anns_root = root / "Annotations" - ann_category_map = { - "Faces_2": "Faces", - "Faces_3": "Faces_easy", - "Motorbikes_16": "Motorbikes", - "Airplanes_Side_2": "airplanes", + image_category_map = { + "Faces": "Faces_2", + "Faces_easy": "Faces_3", + "Motorbikes": "Motorbikes_16", + "airplanes": "Airplanes_Side_2", } + categories = ["Faces", "Faces_easy", "Motorbikes", "airplanes", "yin_yang"] + num_images_per_category = 2 - for category in info.categories: + for category in categories: create_image_folder( root=images_root, name=category, @@ -407,7 +409,7 @@ def create_ann_folder(root, name, file_name_fn, num_examples): ) create_ann_folder( root=anns_root, - name=ann_category_map.get(category, category), + name=image_category_map.get(category, category), file_name_fn=lambda idx: f"annotation_{idx + 1:04d}.mat", num_examples=num_images_per_category, ) @@ -417,19 +419,26 @@ def create_ann_folder(root, name, file_name_fn, num_examples): make_tar(root, f"{anns_root.name}.tar", anns_root) - return num_images_per_category * len(info.categories) + return num_images_per_category * len(categories) -# @register_mock -def caltech256(info, root, config): 
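# Short sketch of the label layout assumed by the SEMEION mock above: every row holds
# 256 pixel values followed by a one-hot encoded digit, which the dataset turns into a
# OneHotLabel over 10 categories.
import torch
from torch.nn.functional import one_hot

labels = torch.tensor([3, 0, 7])
encoded = one_hot(labels, num_classes=10)

assert encoded.shape == (3, 10)
assert encoded[0].argmax().item() == 3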
+@register_mock(configs=[dict()]) +def caltech256(root, config): dir = root / "256_ObjectCategories" num_images_per_category = 2 - for idx, category in enumerate(info.categories, 1): + categories = [ + (1, "ak47"), + (127, "laptop-101"), + (198, "spider"), + (257, "clutter"), + ] + + for category_idx, category in categories: files = create_image_folder( dir, - name=f"{idx:03d}.{category}", - file_name_fn=lambda image_idx: f"{idx:03d}_{image_idx + 1:04d}.jpg", + name=f"{category_idx:03d}.{category}", + file_name_fn=lambda image_idx: f"{category_idx:03d}_{image_idx + 1:04d}.jpg", num_examples=num_images_per_category, ) if category == "spider": @@ -437,7 +446,7 @@ def caltech256(info, root, config): make_tar(root, f"{dir.name}.tar", dir) - return num_images_per_category * len(info.categories) + return num_images_per_category * len(categories) @register_mock(configs=combinations_grid(split=("train", "val", "test"))) diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index 4a409835b5e..3701063504f 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -1,6 +1,6 @@ import pathlib import re -from typing import Any, Dict, List, Tuple, BinaryIO +from typing import Any, Dict, List, Tuple, BinaryIO, Union import numpy as np from torchdata.datapipes.iter import ( @@ -9,26 +9,49 @@ Filter, IterKeyZipper, ) -from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, - HttpResource, - OnlineResource, +from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import ( + INFINITE_BUFFER_SIZE, + read_mat, + hint_sharding, + hint_shuffling, + BUILTIN_DIR, ) -from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, read_mat, hint_sharding, hint_shuffling from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage +from .._api import register_dataset, register_info -class Caltech101(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "caltech101", + +CALTECH101_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech101.categories")) + + +@register_info("caltech101") +def _caltech101_info() -> Dict[str, Any]: + return dict(categories=CALTECH101_CATEGORIES) + + +@register_dataset("caltech101") +class Caltech101(Dataset2): + """ + - **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech101 + - **dependencies**: + - _ + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + skip_integrity_check: bool = False, + ) -> None: + self._categories = _caltech101_info()["categories"] + + super().__init__( + root, dependencies=("scipy",), - homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech101", + skip_integrity_check=skip_integrity_check, ) - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: images = HttpResource( "http://www.vision.caltech.edu/Image_Datasets/Caltech101/101_ObjectCategories.tar.gz", sha256="af6ece2f339791ca20f855943d8b55dd60892c0a25105fcd631ee3d6430f9926", @@ -88,7 +111,7 @@ def _prepare_sample( ann = read_mat(ann_buffer) return dict( - label=Label.from_category(category, categories=self.categories), + label=Label.from_category(category, categories=self._categories), image_path=image_path, image=image, ann_path=ann_path, @@ -98,12 +121,7 @@ def 
_prepare_sample( contour=_Feature(ann["obj_contour"].T), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: images_dp, anns_dp = resource_dps images_dp = Filter(images_dp, self._is_not_background_image) @@ -122,23 +140,42 @@ def _make_datapipe( ) return Mapper(dp, self._prepare_sample) - def _generate_categories(self, root: pathlib.Path) -> List[str]: - resources = self.resources(self.default_config) + def __len__(self) -> int: + return 8677 + + def _generate_categories(self) -> List[str]: + resources = self._resources() - dp = resources[0].load(root) + dp = resources[0].load(self._root) dp = Filter(dp, self._is_not_background_image) return sorted({pathlib.Path(path).parent.name for path, _ in dp}) -class Caltech256(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "caltech256", - homepage="http://www.vision.caltech.edu/Image_Datasets/Caltech256", - ) +CALTECH256_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech256.categories")) + + +@register_info("caltech256") +def _caltech256_info() -> Dict[str, Any]: + return dict(categories=CALTECH256_CATEGORIES) + + +@register_dataset("caltech256") +class Caltech256(Dataset2): + """ + - **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech256 + """ - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def __init__( + self, + root: Union[str, pathlib.Path], + skip_integrity_check: bool = False, + ) -> None: + self._categories = _caltech256_info()["categories"] + + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: return [ HttpResource( "http://www.vision.caltech.edu/Image_Datasets/Caltech256/256_ObjectCategories.tar", @@ -156,25 +193,23 @@ def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]: return dict( path=path, image=EncodedImage.from_file(buffer), - label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self.categories), + label=Label(int(pathlib.Path(path).parent.name.split(".", 1)[0]) - 1, categories=self._categories), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = Filter(dp, self._is_not_rogue_file) dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) - def _generate_categories(self, root: pathlib.Path) -> List[str]: - resources = self.resources(self.default_config) + def __len__(self) -> int: + return 30607 + + def _generate_categories(self) -> List[str]: + resources = self._resources() - dp = resources[0].load(root) + dp = resources[0].load(self._root) dir_names = {pathlib.Path(path).parent.name for path, _ in dp} return [name.split(".")[1] for name in sorted(dir_names)] From 70cd406087e3deef8ebe59624921179e3b087f83 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 14:07:40 +0100 Subject: [PATCH 18/30] Migrate Oxford Pets prototype dataset (#5764) * Migrate Oxford Pets prototype dataset * Update torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py Co-authored-by: Philip Meier Co-authored-by: Philip Meier --- test/builtin_dataset_mocks.py | 6 +- .../prototype/datasets/_builtin/__init__.py | 2 +- 
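The Caltech256 sample preparation and category generation above both lean on the "NNN.category" directory convention that the mock reproduces; a small stand-alone illustration with one hypothetical archive path:

    import pathlib

    path = "256_ObjectCategories/198.spider/198_0001.jpg"  # hypothetical entry
    folder = pathlib.Path(path).parent.name                # "198.spider"
    label = int(folder.split(".", 1)[0]) - 1               # zero-based label: 197
    category = folder.split(".")[1]                        # "spider", as _generate_categories derives it
    print(label, category)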
.../datasets/_builtin/oxford_iiit_pet.py | 66 +++++++++++-------- 3 files changed, 44 insertions(+), 30 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 9cd5b96f2e7..0f06443016d 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1230,9 +1230,9 @@ def generate(self, root): return num_samples_map -# @register_mock -def oxford_iiit_pet(info, root, config): - return OxfordIIITPetMockData.generate(root)[config.split] +@register_mock(name="oxford-iiit-pet", configs=combinations_grid(split=("trainval", "test"))) +def oxford_iiit_pet(root, config): + return OxfordIIITPetMockData.generate(root)[config["split"]] class _CUB200MockData: diff --git a/torchvision/prototype/datasets/_builtin/__init__.py b/torchvision/prototype/datasets/_builtin/__init__.py index b2beddc7f2b..4acc1d53b4d 100644 --- a/torchvision/prototype/datasets/_builtin/__init__.py +++ b/torchvision/prototype/datasets/_builtin/__init__.py @@ -12,7 +12,7 @@ from .gtsrb import GTSRB from .imagenet import ImageNet from .mnist import MNIST, FashionMNIST, KMNIST, EMNIST, QMNIST -from .oxford_iiit_pet import OxfordIITPet +from .oxford_iiit_pet import OxfordIIITPet from .pcam import PCAM from .sbd import SBD from .semeion import SEMEION diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py index 8d4fc00dbb0..714360c24f6 100644 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py @@ -1,11 +1,10 @@ import enum import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, CSVDictParser from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, + Dataset2, DatasetInfo, HttpResource, OnlineResource, @@ -14,29 +13,45 @@ INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling, + BUILTIN_DIR, getitem, path_accessor, path_comparator, ) from torchvision.prototype.features import Label, EncodedImage +from .._api import register_dataset, register_info -class OxfordIITPetDemux(enum.IntEnum): + +NAME = "oxford-iiit-pet" + + +class OxfordIIITPetDemux(enum.IntEnum): SPLIT_AND_CLASSIFICATION = 0 SEGMENTATIONS = 1 -class OxfordIITPet(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "oxford-iiit-pet", - homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/", - valid_options=dict( - split=("trainval", "test"), - ), - ) +@register_info(NAME) +def _info() -> Dict[str, Any]: + categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories") + categories = [c[0] for c in categories] + return dict(categories=categories) + - def resources(self, config: DatasetConfig) -> List[OnlineResource]: +@register_dataset(NAME) +class OxfordIIITPet(Dataset2): + """Oxford IIIT Pet Dataset + homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/", + """ + + def __init__( + self, root: Union[str, pathlib.Path], *, split: str = "trainval", skip_integrity_check: bool = False + ) -> None: + self._split = self._verify_str_arg(split, "split", {"trainval", "test"}) + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: images = HttpResource( "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", 
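The register_info/register_dataset pair above is what hooks the class into the high-level registry, so the migrated dataset is constructed through plain keyword arguments instead of a config object. A hedged usage sketch (the root path is a placeholder, and the archives are assumed to be present under it or downloadable):

    from torchvision.prototype import datasets

    info = datasets.info("oxford-iiit-pet")      # the dict built by the registered info function
    print(info["categories"][:3])

    dataset = datasets.load("oxford-iiit-pet", root="path/to/data", split="trainval")
    sample = next(iter(dataset))                 # samples are plain dicts
    print(sorted(sample.keys()))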
sha256="67195c5e1c01f1ab5f9b6a5d22b8c27a580d896ece458917e61d459337fa318d", @@ -51,8 +66,8 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: def _classify_anns(self, data: Tuple[str, Any]) -> Optional[int]: return { - "annotations": OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION, - "trimaps": OxfordIITPetDemux.SEGMENTATIONS, + "annotations": OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION, + "trimaps": OxfordIIITPetDemux.SEGMENTATIONS, }.get(pathlib.Path(data[0]).parent.name) def _filter_images(self, data: Tuple[str, Any]) -> bool: @@ -70,7 +85,7 @@ def _prepare_sample( image_path, image_buffer = image_data return dict( - label=Label(int(classification_data["label"]) - 1, categories=self.categories), + label=Label(int(classification_data["label"]) - 1, categories=self._categories), species="cat" if classification_data["species"] == "1" else "dog", segmentation_path=segmentation_path, segmentation=EncodedImage.from_file(segmentation_buffer), @@ -78,9 +93,7 @@ def _prepare_sample( image=EncodedImage.from_file(image_buffer), ) - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: images_dp, anns_dp = resource_dps images_dp = Filter(images_dp, self._filter_images) @@ -93,9 +106,7 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, ) - split_and_classification_dp = Filter( - split_and_classification_dp, path_comparator("name", f"{config.split}.txt") - ) + split_and_classification_dp = Filter(split_and_classification_dp, path_comparator("name", f"{self._split}.txt")) split_and_classification_dp = CSVDictParser( split_and_classification_dp, fieldnames=("image_id", "label", "species"), delimiter=" " ) @@ -122,13 +133,13 @@ def _make_datapipe( return Mapper(dp, self._prepare_sample) def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool: - return self._classify_anns(data) == OxfordIITPetDemux.SPLIT_AND_CLASSIFICATION + return self._classify_anns(data) == OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION - def _generate_categories(self, root: pathlib.Path) -> List[str]: + def _generate_categories(self) -> List[str]: config = self.default_config resources = self.resources(config) - dp = resources[1].load(root) + dp = resources[1].load(self._root) dp = Filter(dp, self._filter_split_and_classification_anns) dp = Filter(dp, path_comparator("name", f"{config.split}.txt")) dp = CSVDictParser(dp, fieldnames=("image_id", "label"), delimiter=" ") @@ -138,3 +149,6 @@ def _generate_categories(self, root: pathlib.Path) -> List[str]: *sorted(raw_categories_and_labels, key=lambda raw_category_and_label: int(raw_category_and_label[1])) ) return [" ".join(part.title() for part in raw_category.split("_")) for raw_category in raw_categories] + + def __len__(self) -> int: + return 3_680 if self._split == "trainval" else 3_669 From ccfcaa581ef66c4ac4ba80df4dc15f04e6dbe548 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 6 Apr 2022 15:10:47 +0200 Subject: [PATCH 19/30] migrate mnist prototype datasets (#5480) * migrate MNIST prototype datasets * Update torchvision/prototype/datasets/_builtin/mnist.py Co-authored-by: Nicolas Hug Co-authored-by: Nicolas Hug --- test/builtin_dataset_mocks.py | 62 ++-- test/test_prototype_builtin_datasets.py | 3 +- .../prototype/datasets/_builtin/mnist.py | 281 +++++++++++------- 3 files changed, 212 insertions(+), 134 deletions(-) diff --git a/test/builtin_dataset_mocks.py 
b/test/builtin_dataset_mocks.py index 0f06443016d..bc117072df3 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -214,58 +214,64 @@ def generate( return num_samples -# @register_mock -def mnist(info, root, config): - train = config.split == "train" - images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz" - labels_file = f"{'train' if train else 't10k'}-labels-idx1-ubyte.gz" +def mnist(root, config): + prefix = "train" if config["split"] == "train" else "t10k" return MNISTMockData.generate( root, - num_categories=len(info.categories), - images_file=images_file, - labels_file=labels_file, + num_categories=10, + images_file=f"{prefix}-images-idx3-ubyte.gz", + labels_file=f"{prefix}-labels-idx1-ubyte.gz", ) -# DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]}) +DATASET_MOCKS.update( + { + name: DatasetMock(name, mock_data_fn=mnist, configs=combinations_grid(split=("train", "test"))) + for name in ["mnist", "fashionmnist", "kmnist"] + } +) -# @register_mock -def emnist(info, root, config): - # The image sets that merge some lower case letters in their respective upper case variant, still use dense - # labels in the data files. Thus, num_categories != len(categories) there. - num_categories = defaultdict( - lambda: len(info.categories), {image_set: 47 for image_set in ("Balanced", "By_Merge")} +@register_mock( + configs=combinations_grid( + split=("train", "test"), + image_set=("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST"), ) - +) +def emnist(root, config): num_samples_map = {} file_names = set() - for config_ in info._configs: - prefix = f"emnist-{config_.image_set.replace('_', '').lower()}-{config_.split}" + for split, image_set in itertools.product( + ("train", "test"), + ("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST"), + ): + prefix = f"emnist-{image_set.replace('_', '').lower()}-{split}" images_file = f"{prefix}-images-idx3-ubyte.gz" labels_file = f"{prefix}-labels-idx1-ubyte.gz" file_names.update({images_file, labels_file}) - num_samples_map[config_] = MNISTMockData.generate( + num_samples_map[(split, image_set)] = MNISTMockData.generate( root, - num_categories=num_categories[config_.image_set], + # The image sets that merge some lower case letters in their respective upper case variant, still use dense + # labels in the data files. Thus, num_categories != len(categories) there. + num_categories=47 if config["image_set"] in ("Balanced", "By_Merge") else 62, images_file=images_file, labels_file=labels_file, ) make_zip(root, "emnist-gzip.zip", *file_names) - return num_samples_map[config] + return num_samples_map[(config["split"], config["image_set"])] -# @register_mock -def qmnist(info, root, config): - num_categories = len(info.categories) - if config.split == "train": +@register_mock(configs=combinations_grid(split=("train", "test", "test10k", "test50k", "nist"))) +def qmnist(root, config): + num_categories = 10 + if config["split"] == "train": num_samples = num_samples_gen = num_categories + 2 prefix = "qmnist-train" suffix = ".gz" compressor = gzip.open - elif config.split.startswith("test"): + elif config["split"].startswith("test"): # The split 'test50k' is defined as the last 50k images beginning at index 10000. Thus, we need to create # more than 10000 images for the dataset to not be empty. 
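The configs handed to register_mock and DatasetMock above are plain lists of keyword dicts; combinations_grid, imported from datasets_utils, presumably expands keyword tuples into their cartesian product, roughly like this minimal stand-in:

    import itertools

    def combinations_grid(**kwargs):
        # cartesian product of the given option tuples, one dict per combination
        return [dict(zip(kwargs, values)) for values in itertools.product(*kwargs.values())]

    print(combinations_grid(split=("train", "test"), image_set=("Balanced", "Digits")))
    # [{'split': 'train', 'image_set': 'Balanced'}, {'split': 'train', 'image_set': 'Digits'},
    #  {'split': 'test', 'image_set': 'Balanced'}, {'split': 'test', 'image_set': 'Digits'}]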
num_samples_gen = 10001 @@ -273,11 +279,11 @@ def qmnist(info, root, config): "test": num_samples_gen, "test10k": min(num_samples_gen, 10_000), "test50k": num_samples_gen - 10_000, - }[config.split] + }[config["split"]] prefix = "qmnist-test" suffix = ".gz" compressor = gzip.open - else: # config.split == "nist" + else: # config["split"] == "nist" num_samples = num_samples_gen = num_categories + 3 prefix = "xnist" suffix = ".xz" diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 8a929b6907c..d9ad4885b57 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -176,8 +176,7 @@ def test_infinite_buffer_size(self, test_home, dataset_mock, config): assert dp.buffer_size == INFINITE_BUFFER_SIZE -# FIXME: DATASET_MOCKS["qmnist"] -@parametrize_dataset_mocks({}) +@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"]) class TestQMNIST: def test_extra_label(self, test_home, dataset_mock, config): dataset_mock.prepare(test_home, config) diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index 1e14e6dfc58..907faed49bd 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -7,12 +7,13 @@ import torch from torchdata.datapipes.iter import IterDataPipe, Demultiplexer, Mapper, Zipper, Decompressor -from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset2, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling from torchvision.prototype.features import Image, Label from torchvision.prototype.utils._internal import fromfile -__all__ = ["MNIST", "FashionMNIST", "KMNIST", "EMNIST", "QMNIST"] +from .._api import register_dataset, register_info + prod = functools.partial(functools.reduce, operator.mul) @@ -57,18 +58,18 @@ def __iter__(self) -> Iterator[torch.Tensor]: yield read(dtype=dtype, count=count).reshape(shape) -class _MNISTBase(Dataset): +class _MNISTBase(Dataset2): _URL_BASE: Union[str, Sequence[str]] @abc.abstractmethod - def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]: + def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: pass - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: (images_file, images_sha256), ( labels_file, labels_sha256, - ) = self._files_and_checksums(config) + ) = self._files_and_checksums() url_bases = self._URL_BASE if isinstance(url_bases, str): @@ -82,21 +83,21 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: return [images, labels] - def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]: + def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]: return None, None - def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]: + _categories: List[str] + + def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]: image, label = data return dict( image=Image(image), - label=Label(label, dtype=torch.int64, categories=self.categories), + label=Label(label, dtype=torch.int64, categories=self._categories), ) - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) 
-> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: images_dp, labels_dp = resource_dps - start, stop = self.start_and_stop(config) + start, stop = self.start_and_stop() images_dp = Decompressor(images_dp) images_dp = MNISTFileReader(images_dp, start=start, stop=stop) @@ -107,19 +108,31 @@ def _make_datapipe( dp = Zipper(images_dp, labels_dp) dp = hint_shuffling(dp) dp = hint_sharding(dp) - return Mapper(dp, functools.partial(self._prepare_sample, config=config)) + return Mapper(dp, self._prepare_sample) + + +@register_info("mnist") +def _mnist_info() -> Dict[str, Any]: + return dict( + categories=[str(label) for label in range(10)], + ) +@register_dataset("mnist") class MNIST(_MNISTBase): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "mnist", - categories=10, - homepage="http://yann.lecun.com/exdb/mnist", - valid_options=dict( - split=("train", "test"), - ), - ) + """ + - **homepage**: http://yann.lecun.com/exdb/mnist + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "test")) + super().__init__(root, skip_integrity_check=skip_integrity_check) _URL_BASE: Union[str, Sequence[str]] = ( "http://yann.lecun.com/exdb/mnist", @@ -132,8 +145,8 @@ def _make_info(self) -> DatasetInfo: "t10k-labels-idx1-ubyte.gz": "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6", } - def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]: - prefix = "train" if config.split == "train" else "t10k" + def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: + prefix = "train" if self._split == "train" else "t10k" images_file = f"{prefix}-images-idx3-ubyte.gz" labels_file = f"{prefix}-labels-idx1-ubyte.gz" return (images_file, self._CHECKSUMS[images_file]), ( @@ -141,28 +154,35 @@ def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], self._CHECKSUMS[labels_file], ) + _categories = _mnist_info()["categories"] + + def __len__(self) -> int: + return 60_000 if self._split == "train" else 10_000 + + +@register_info("fashionmnist") +def _fashionmnist_info() -> Dict[str, Any]: + return dict( + categories=[ + "T-shirt/top", + "Trouser", + "Pullover", + "Dress", + "Coat", + "Sandal", + "Shirt", + "Sneaker", + "Bag", + "Ankle boot", + ], + ) + +@register_dataset("fashionmnist") class FashionMNIST(MNIST): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "fashionmnist", - categories=( - "T-shirt/top", - "Trouser", - "Pullover", - "Dress", - "Coat", - "Sandal", - "Shirt", - "Sneaker", - "Bag", - "Ankle boot", - ), - homepage="https://github.com/zalandoresearch/fashion-mnist", - valid_options=dict( - split=("train", "test"), - ), - ) + """ + - **homepage**: https://github.com/zalandoresearch/fashion-mnist + """ _URL_BASE = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com" _CHECKSUMS = { @@ -172,17 +192,21 @@ def _make_info(self) -> DatasetInfo: "t10k-labels-idx1-ubyte.gz": "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5", } + _categories = _fashionmnist_info()["categories"] + + +@register_info("kmnist") +def _kmnist_info() -> Dict[str, Any]: + return dict( + categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"], + ) + +@register_dataset("kmnist") class KMNIST(MNIST): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - 
"kmnist", - categories=["o", "ki", "su", "tsu", "na", "ha", "ma", "ya", "re", "wo"], - homepage="http://codh.rois.ac.jp/kmnist/index.html.en", - valid_options=dict( - split=("train", "test"), - ), - ) + """ + - **homepage**: http://codh.rois.ac.jp/kmnist/index.html.en + """ _URL_BASE = "http://codh.rois.ac.jp/kmnist/dataset/kmnist" _CHECKSUMS = { @@ -192,36 +216,46 @@ def _make_info(self) -> DatasetInfo: "t10k-labels-idx1-ubyte.gz": "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c", } + _categories = _kmnist_info()["categories"] + + +@register_info("emnist") +def _emnist_info() -> Dict[str, Any]: + return dict( + categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase), + ) + +@register_dataset("emnist") class EMNIST(_MNISTBase): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "emnist", - categories=list(string.digits + string.ascii_uppercase + string.ascii_lowercase), - homepage="https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist", - valid_options=dict( - split=("train", "test"), - image_set=( - "Balanced", - "By_Merge", - "By_Class", - "Letters", - "Digits", - "MNIST", - ), - ), + """ + - **homepage**: https://www.westernsydney.edu.au/icns/reproducible_research/publication_support_materials/emnist + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + image_set: str = "Balanced", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "test")) + self._image_set = self._verify_str_arg( + image_set, "image_set", ("Balanced", "By_Merge", "By_Class", "Letters", "Digits", "MNIST") ) + super().__init__(root, skip_integrity_check=skip_integrity_check) _URL_BASE = "https://rds.westernsydney.edu.au/Institutes/MARCS/BENS/EMNIST" - def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], Tuple[str, str]]: - prefix = f"emnist-{config.image_set.replace('_', '').lower()}-{config.split}" + def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: + prefix = f"emnist-{self._image_set.replace('_', '').lower()}-{self._split}" images_file = f"{prefix}-images-idx3-ubyte.gz" labels_file = f"{prefix}-labels-idx1-ubyte.gz" - # Since EMNIST provides the data files inside an archive, we don't need provide checksums for them + # Since EMNIST provides the data files inside an archive, we don't need to provide checksums for them return (images_file, ""), (labels_file, "") - def resources(self, config: Optional[DatasetConfig] = None) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: return [ HttpResource( f"{self._URL_BASE}/emnist-gzip.zip", @@ -229,9 +263,9 @@ def resources(self, config: Optional[DatasetConfig] = None) -> List[OnlineResour ) ] - def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> Optional[int]: + def _classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: path = pathlib.Path(data[0]) - (images_file, _), (labels_file, _) = self._files_and_checksums(config) + (images_file, _), (labels_file, _) = self._files_and_checksums() if path.name == images_file: return 0 elif path.name == labels_file: @@ -239,6 +273,8 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> else: return None + _categories = _emnist_info()["categories"] + _LABEL_OFFSETS = { 38: 1, 39: 1, @@ -251,45 +287,71 @@ def _classify_archive(self, data: Tuple[str, Any], *, config: DatasetConfig) -> 46: 9, } - def 
_prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]: + def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]: # In these two splits, some lowercase letters are merged into their uppercase ones (see Fig 2. in the paper). # That means for example that there is 'D', 'd', and 'C', but not 'c'. Since the labels are nevertheless dense, - # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create this gaps. For example, - # since there is no 'c', 'd' corresponds to + # i.e. no gaps between 0 and 46 for 47 total classes, we need to add an offset to create these gaps. For + # example, since there is no 'c', 'd' corresponds to # label 38 (10 digits + 26 uppercase letters + 3rd unmerged lower case letter - 1 for zero indexing), # and at the same time corresponds to # index 39 (10 digits + 26 uppercase letters + 4th lower case letter - 1 for zero indexing) - # in self.categories. Thus, we need to add 1 to the label to correct this. - if config.image_set in ("Balanced", "By_Merge"): + # in self._categories. Thus, we need to add 1 to the label to correct this. + if self._image_set in ("Balanced", "By_Merge"): image, label = data label += self._LABEL_OFFSETS.get(int(label), 0) data = (image, label) - return super()._prepare_sample(data, config=config) + return super()._prepare_sample(data) - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: archive_dp = resource_dps[0] images_dp, labels_dp = Demultiplexer( archive_dp, 2, - functools.partial(self._classify_archive, config=config), + self._classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE, ) - return super()._make_datapipe([images_dp, labels_dp], config=config) + return super()._datapipe([images_dp, labels_dp]) + + def __len__(self) -> int: + return { + ("train", "Balanced"): 112_800, + ("train", "By_Merge"): 697_932, + ("train", "By_Class"): 697_932, + ("train", "Letters"): 124_800, + ("train", "Digits"): 240_000, + ("train", "MNIST"): 60_000, + ("test", "Balanced"): 18_800, + ("test", "By_Merge"): 116_323, + ("test", "By_Class"): 116_323, + ("test", "Letters"): 20_800, + ("test", "Digits"): 40_000, + ("test", "MNIST"): 10_000, + }[(self._split, self._image_set)] + + +@register_info("qmnist") +def _qmnist_info() -> Dict[str, Any]: + return dict( + categories=[str(label) for label in range(10)], + ) +@register_dataset("qmnist") class QMNIST(_MNISTBase): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "qmnist", - categories=10, - homepage="https://github.com/facebookresearch/qmnist", - valid_options=dict( - split=("train", "test", "test10k", "test50k", "nist"), - ), - ) + """ + - **homepage**: https://github.com/facebookresearch/qmnist + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "test", "test10k", "test50k", "nist")) + super().__init__(root, skip_integrity_check=skip_integrity_check) _URL_BASE = "https://raw.githubusercontent.com/facebookresearch/qmnist/master" _CHECKSUMS = { @@ -301,9 +363,9 @@ def _make_info(self) -> DatasetInfo: "xnist-labels-idx2-int.xz": "db042968723ec2b7aed5f1beac25d2b6e983b9286d4f4bf725f1086e5ae55c4f", } - def _files_and_checksums(self, config: DatasetConfig) -> 
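The label-offset comment above can be checked in isolation; a tiny verification that only uses the offset for 'd' visible in the hunk (the middle entries of the offset table are cut by the diff context):

    import string

    categories = list(string.digits + string.ascii_uppercase + string.ascii_lowercase)
    raw_label = 38   # dense Balanced/By_Merge label for 'd' (these image sets have no 'c')
    offset = 1       # _LABEL_OFFSETS[38]
    assert categories[raw_label + offset] == "d"  # index 39: 10 digits + 26 uppercase + 'a', 'b', 'c' come first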
Tuple[Tuple[str, str], Tuple[str, str]]: - prefix = "xnist" if config.split == "nist" else f"qmnist-{'train' if config.split== 'train' else 'test'}" - suffix = "xz" if config.split == "nist" else "gz" + def _files_and_checksums(self) -> Tuple[Tuple[str, str], Tuple[str, str]]: + prefix = "xnist" if self._split == "nist" else f"qmnist-{'train' if self._split == 'train' else 'test'}" + suffix = "xz" if self._split == "nist" else "gz" images_file = f"{prefix}-images-idx3-ubyte.{suffix}" labels_file = f"{prefix}-labels-idx2-int.{suffix}" return (images_file, self._CHECKSUMS[images_file]), ( @@ -311,13 +373,13 @@ def _files_and_checksums(self, config: DatasetConfig) -> Tuple[Tuple[str, str], self._CHECKSUMS[labels_file], ) - def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional[int]]: + def start_and_stop(self) -> Tuple[Optional[int], Optional[int]]: start: Optional[int] stop: Optional[int] - if config.split == "test10k": + if self._split == "test10k": start = 0 stop = 10000 - elif config.split == "test50k": + elif self._split == "test50k": start = 10000 stop = None else: @@ -325,10 +387,12 @@ def start_and_stop(self, config: DatasetConfig) -> Tuple[Optional[int], Optional return start, stop - def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: DatasetConfig) -> Dict[str, Any]: + _categories = _emnist_info()["categories"] + + def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor]) -> Dict[str, Any]: image, ann = data label, *extra_anns = ann - sample = super()._prepare_sample((image, label), config=config) + sample = super()._prepare_sample((image, label)) sample.update( dict( @@ -340,3 +404,12 @@ def _prepare_sample(self, data: Tuple[torch.Tensor, torch.Tensor], *, config: Da ) sample.update(dict(zip(("duplicate", "unused"), [bool(value) for value in extra_anns[-2:]]))) return sample + + def __len__(self) -> int: + return { + "train": 60_000, + "test": 60_000, + "test10k": 10_000, + "test50k": 50_000, + "nist": 402_953, + }[self._split] From 9ea341a90662d058f521fda2bcc059a2eb3d5f5a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 14:34:46 +0100 Subject: [PATCH 20/30] Migrate Stanford Cars prototype dataset (#5767) * Migrate Stanford Cars prototype dataset * Address comments --- test/builtin_dataset_mocks.py | 13 ++-- .../prototype/datasets/_builtin/dtd.py | 3 +- .../datasets/_builtin/stanford_cars.py | 78 ++++++++++++------- 3 files changed, 59 insertions(+), 35 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index bc117072df3..b33dc1450e3 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1473,18 +1473,19 @@ def pcam(root, config): return num_images -# @register_mock -def stanford_cars(info, root, config): +@register_mock(name="stanford-cars", configs=combinations_grid(split=("train", "test"))) +def stanford_cars(root, config): import scipy.io as io from numpy.core.records import fromarrays - num_samples = {"train": 5, "test": 7}[config["split"]] + split = config["split"] + num_samples = {"train": 5, "test": 7}[split] num_categories = 3 devkit = root / "devkit" devkit.mkdir(parents=True) - if config["split"] == "train": + if split == "train": images_folder_name = "cars_train" annotations_mat_path = devkit / "cars_train_annos.mat" else: @@ -1498,7 +1499,7 @@ def stanford_cars(info, root, config): num_examples=num_samples, ) - make_tar(root, f"cars_{config.split}.tgz", images_folder_name) + make_tar(root, f"cars_{split}.tgz", images_folder_name) 
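The QMNIST start_and_stop values above slice the decoded test stream into the documented subsets; a quick sketch using the sizes from the new __len__:

    slices = {"test": (None, None), "test10k": (0, 10_000), "test50k": (10_000, None)}
    full_test = range(60_000)                     # stand-in for the decoded test images
    for split, (start, stop) in slices.items():
        print(split, len(full_test[slice(start, stop)]))
    # test 60000, test10k 10000, test50k 50000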
bbox = np.random.randint(1, 200, num_samples, dtype=np.uint8) classes = np.random.randint(1, num_categories + 1, num_samples, dtype=np.uint8) fnames = [f"{i:5d}.jpg" for i in range(num_samples)] @@ -1508,7 +1509,7 @@ def stanford_cars(info, root, config): ) io.savemat(annotations_mat_path, {"annotations": rec_array}) - if config.split == "train": + if split == "train": make_tar(root, "car_devkit.tgz", devkit, compression="gz") return num_samples diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py index a5de1359e4e..dcec6d0e716 100644 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -41,8 +41,9 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) class DTD(Dataset2): """DTD Dataset. - homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/", + homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/", """ + def __init__( self, root: Union[str, pathlib.Path], diff --git a/torchvision/prototype/datasets/_builtin/stanford_cars.py b/torchvision/prototype/datasets/_builtin/stanford_cars.py index 51c0b6152e6..85098eb34e5 100644 --- a/torchvision/prototype/datasets/_builtin/stanford_cars.py +++ b/torchvision/prototype/datasets/_builtin/stanford_cars.py @@ -1,11 +1,19 @@ import pathlib -from typing import Any, Dict, List, Tuple, Iterator, BinaryIO +from typing import Any, Dict, List, Tuple, Iterator, BinaryIO, Union from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper -from torchvision.prototype.datasets.utils import Dataset, DatasetConfig, DatasetInfo, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling, path_comparator, read_mat +from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import ( + hint_sharding, + hint_shuffling, + path_comparator, + read_mat, + BUILTIN_DIR, +) from torchvision.prototype.features import BoundingBox, EncodedImage, Label +from .._api import register_dataset, register_info + class StanfordCarsLabelReader(IterDataPipe[Tuple[int, int, int, int, int, str]]): def __init__(self, datapipe: IterDataPipe[Dict[str, Any]]) -> None: @@ -18,16 +26,33 @@ def __iter__(self) -> Iterator[Tuple[int, int, int, int, int, str]]: yield tuple(ann) # type: ignore[misc] -class StanfordCars(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - name="stanford-cars", - homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html", - dependencies=("scipy",), - valid_options=dict( - split=("test", "train"), - ), - ) +NAME = "stanford-cars" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories") + categories = [c[0] for c in categories] + return dict(categories=categories) + + +@register_dataset(NAME) +class StanfordCars(Dataset2): + """Stanford Cars dataset. 
+ homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html", + dependencies=scipy + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "test"}) + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",)) _URL_ROOT = "https://ai.stanford.edu/~jkrause/" _URLS = { @@ -44,9 +69,9 @@ def _make_info(self) -> DatasetInfo: "car_devkit": "512b227b30e2f0a8aab9e09485786ab4479582073a144998da74d64b801fd288", } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - resources: List[OnlineResource] = [HttpResource(self._URLS[config.split], sha256=self._CHECKSUM[config.split])] - if config.split == "train": + def _resources(self) -> List[OnlineResource]: + resources: List[OnlineResource] = [HttpResource(self._URLS[self._split], sha256=self._CHECKSUM[self._split])] + if self._split == "train": resources.append(HttpResource(url=self._URLS["car_devkit"], sha256=self._CHECKSUM["car_devkit"])) else: @@ -65,19 +90,14 @@ def _prepare_sample(self, data: Tuple[Tuple[str, BinaryIO], Tuple[int, int, int, return dict( path=path, image=image, - label=Label(target[4] - 1, categories=self.categories), + label=Label(target[4] - 1, categories=self._categories), bounding_box=BoundingBox(target[:4], format="xyxy", image_size=image.image_size), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: images_dp, targets_dp = resource_dps - if config.split == "train": + if self._split == "train": targets_dp = Filter(targets_dp, path_comparator("name", "cars_train_annos.mat")) targets_dp = StanfordCarsLabelReader(targets_dp) dp = Zipper(images_dp, targets_dp) @@ -85,12 +105,14 @@ def _make_datapipe( dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) - def _generate_categories(self, root: pathlib.Path) -> List[str]: - config = self.info.make_config(split="train") - resources = self.resources(config) + def _generate_categories(self) -> List[str]: + resources = self._resources() - devkit_dp = resources[1].load(root) + devkit_dp = resources[1].load(self._root) meta_dp = Filter(devkit_dp, path_comparator("name", "cars_meta.mat")) _, meta_file = next(iter(meta_dp)) return list(read_mat(meta_file, squeeze_me=True)["class_names"]) + + def __len__(self) -> int: + return 8_144 if self._split == "train" else 8_041 From 3b10147d09151db28d1387fb35d2b2279d820ae5 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 6 Apr 2022 15:36:40 +0200 Subject: [PATCH 21/30] fix category file generation (#5770) * fix category file generation * revert unrelated change * revert unrelated change --- torchvision/prototype/datasets/_builtin/country211.py | 4 ++-- torchvision/prototype/datasets/_builtin/dtd.py | 2 +- torchvision/prototype/datasets/_builtin/food101.py | 2 +- torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py | 5 ++--- torchvision/prototype/datasets/_builtin/sbd.py | 2 +- torchvision/prototype/datasets/_builtin/voc.py | 2 +- 6 files changed, 8 insertions(+), 9 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py index ae0564b224b..461cd71568f 100644 --- a/torchvision/prototype/datasets/_builtin/country211.py +++ 
b/torchvision/prototype/datasets/_builtin/country211.py @@ -73,6 +73,6 @@ def __len__(self) -> int: }[self._split] def _generate_categories(self) -> List[str]: - resources = self.resources() - dp = resources[0].load(self.root) + resources = self._resources() + dp = resources[0].load(self._root) return sorted({pathlib.Path(path).parent.name for path, _ in dp}) diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py index dcec6d0e716..d7f07dc8b30 100644 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -135,7 +135,7 @@ def _filter_images(self, data: Tuple[str, Any]) -> bool: return self._classify_archive(data) == DTDDemux.IMAGES def _generate_categories(self) -> List[str]: - resources = self.resources() + resources = self._resources() dp = resources[0].load(self._root) dp = Filter(dp, self._filter_images) diff --git a/torchvision/prototype/datasets/_builtin/food101.py b/torchvision/prototype/datasets/_builtin/food101.py index 36b2acca4d0..c86b9aaea84 100644 --- a/torchvision/prototype/datasets/_builtin/food101.py +++ b/torchvision/prototype/datasets/_builtin/food101.py @@ -96,7 +96,7 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, return Mapper(dp, self._prepare_sample) def _generate_categories(self) -> List[str]: - resources = self.resources() + resources = self._resources() dp = resources[0].load(self._root) dp = Filter(dp, path_comparator("name", "classes.txt")) dp = LineReader(dp, decode=True, return_path=False) diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py index 714360c24f6..0ea336a1421 100644 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py @@ -136,12 +136,11 @@ def _filter_split_and_classification_anns(self, data: Tuple[str, Any]) -> bool: return self._classify_anns(data) == OxfordIIITPetDemux.SPLIT_AND_CLASSIFICATION def _generate_categories(self) -> List[str]: - config = self.default_config - resources = self.resources(config) + resources = self._resources() dp = resources[1].load(self._root) dp = Filter(dp, self._filter_split_and_classification_anns) - dp = Filter(dp, path_comparator("name", f"{config.split}.txt")) + dp = Filter(dp, path_comparator("name", "trainval.txt")) dp = CSVDictParser(dp, fieldnames=("image_id", "label"), delimiter=" ") raw_categories_and_labels = {(data["image_id"].rsplit("_", 1)[0], data["label"]) for data in dp} diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index 7fd47b6c991..bcacaea2d24 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -121,7 +121,7 @@ def _make_datapipe( return Mapper(dp, self._prepare_sample) def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]: - resources = self.resources(self.default_config) + resources = self._resources(self.default_config) dp = resources[0].load(root) dp = Filter(dp, path_comparator("name", "category_names.m")) diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index 91b82794e27..1f5980bdc72 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -218,7 +218,7 @@ def _generate_categories(self) -> List[str]: resources = self._resources() 
archive_dp = resources[0].load(self._root) - dp = Filter(archive_dp, self._filter_detection_anns) + dp = Filter(archive_dp, self._filter_anns) dp = Mapper(dp, self._parse_detection_ann, input_col=1) return sorted({instance["name"] for _, anns in dp for instance in anns["object"]}) From 1691e7224c59c3b4d918fdd346f9e0192556cc8b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 6 Apr 2022 15:36:53 +0200 Subject: [PATCH 22/30] migrate cub200 prototype dataset (#5765) * migrate cub200 prototype dataset * address comments * fix category-file-generation --- test/builtin_dataset_mocks.py | 8 +- .../prototype/datasets/_builtin/cub200.py | 95 ++++++++++++------- 2 files changed, 66 insertions(+), 37 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index b33dc1450e3..20606424319 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1402,10 +1402,10 @@ def generate(cls, root): return num_samples_map -# @register_mock -def cub200(info, root, config): - num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root) - return num_samples_map[config.split] +@register_mock(configs=combinations_grid(split=("train", "test"), year=("2010", "2011"))) +def cub200(root, config): + num_samples_map = (CUB2002011MockData if config["year"] == "2011" else CUB2002010MockData).generate(root) + return num_samples_map[config["split"]] @register_mock(configs=[dict()]) diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index 1b90b476aa7..073a790092c 100644 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ b/torchvision/prototype/datasets/_builtin/cub200.py @@ -1,7 +1,7 @@ import csv import functools import pathlib -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Callable, Union from torchdata.datapipes.iter import ( IterDataPipe, @@ -14,8 +14,7 @@ CSVDictParser, ) from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, + Dataset2, DatasetInfo, HttpResource, OnlineResource, @@ -28,26 +27,53 @@ getitem, path_comparator, path_accessor, + BUILTIN_DIR, ) from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage +from .._api import register_dataset, register_info + csv.register_dialect("cub200", delimiter=" ") -class CUB200(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "cub200", - homepage="http://www.vision.caltech.edu/visipedia/CUB-200-2011.html", - dependencies=("scipy",), - valid_options=dict( - split=("train", "test"), - year=("2011", "2010"), - ), +NAME = "cub200" + +CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=CATEGORIES) + + +@register_dataset(NAME) +class CUB200(Dataset2): + """ + - **homepage**: http://www.vision.caltech.edu/visipedia/CUB-200.html + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + year: str = "2011", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "test")) + self._year = self._verify_str_arg(year, "year", ("2010", "2011")) + + self._categories = _info()["categories"] + + super().__init__( + root, + # TODO: this will only be available after https://github.com/pytorch/vision/pull/5473 + # dependencies=("scipy",), + 
skip_integrity_check=skip_integrity_check, ) - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - if config.year == "2011": + def _resources(self) -> List[OnlineResource]: + if self._year == "2011": archive = HttpResource( "http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz", sha256="0c685df5597a8b24909f6a7c9db6d11e008733779a671760afef78feb49bf081", @@ -59,7 +85,7 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: preprocess="decompress", ) return [archive, segmentations] - else: # config.year == "2010" + else: # self._year == "2010" split = HttpResource( "http://www.vision.caltech.edu/visipedia-data/CUB-200/lists.tgz", sha256="aeacbd5e3539ae84ea726e8a266a9a119c18f055cd80f3836d5eb4500b005428", @@ -90,12 +116,12 @@ def _2011_classify_archive(self, data: Tuple[str, Any]) -> Optional[int]: else: return None - def _2011_filter_split(self, row: List[str], *, split: str) -> bool: + def _2011_filter_split(self, row: List[str]) -> bool: _, split_id = row return { "0": "test", "1": "train", - }[split_id] == split + }[split_id] == self._split def _2011_segmentation_key(self, data: Tuple[str, Any]) -> str: path = pathlib.Path(data[0]) @@ -149,17 +175,12 @@ def _prepare_sample( return dict( prepare_ann_fn(anns_data, image.image_size), image=image, - label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self.categories), + label=Label(int(pathlib.Path(path).parent.name.rsplit(".", 1)[0]), categories=self._categories), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: prepare_ann_fn: Callable - if config.year == "2011": + if self._year == "2011": archive_dp, segmentations_dp = resource_dps images_dp, split_dp, image_files_dp, bounding_boxes_dp = Demultiplexer( archive_dp, 4, self._2011_classify_archive, drop_none=True, buffer_size=INFINITE_BUFFER_SIZE @@ -171,7 +192,7 @@ def _make_datapipe( ) split_dp = CSVParser(split_dp, dialect="cub200") - split_dp = Filter(split_dp, functools.partial(self._2011_filter_split, split=config.split)) + split_dp = Filter(split_dp, self._2011_filter_split) split_dp = Mapper(split_dp, getitem(0)) split_dp = Mapper(split_dp, image_files_map.get) @@ -188,10 +209,10 @@ def _make_datapipe( ) prepare_ann_fn = self._2011_prepare_ann - else: # config.year == "2010" + else: # self._year == "2010" split_dp, images_dp, anns_dp = resource_dps - split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt")) + split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt")) split_dp = LineReader(split_dp, decode=True, return_path=False) split_dp = Mapper(split_dp, self._2010_split_key) @@ -217,11 +238,19 @@ def _make_datapipe( ) return Mapper(dp, functools.partial(self._prepare_sample, prepare_ann_fn=prepare_ann_fn)) - def _generate_categories(self, root: pathlib.Path) -> List[str]: - config = self.info.make_config(year="2011") - resources = self.resources(config) + def __len__(self) -> int: + return { + ("train", "2010"): 3_000, + ("test", "2010"): 3_033, + ("train", "2011"): 5_994, + ("test", "2011"): 5_794, + }[(self._split, self._year)] + + def _generate_categories(self) -> List[str]: + self._year = "2011" + resources = self._resources() - dp = resources[0].load(root) + dp = resources[0].load(self._root) dp = Filter(dp, path_comparator("name", "classes.txt")) dp = CSVDictParser(dp, fieldnames=("label", 
"category"), dialect="cub200") From 2a212b89ebd29d7ed5e91d4f45a774c857bbc05b Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 14:56:18 +0100 Subject: [PATCH 23/30] Migrate USPS prototype dataset (#5771) --- test/builtin_dataset_mocks.py | 8 +-- .../prototype/datasets/_builtin/usps.py | 57 ++++++++++++------- 2 files changed, 40 insertions(+), 25 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 20606424319..316894589c8 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1515,11 +1515,11 @@ def stanford_cars(root, config): return num_samples -# @register_mock -def usps(info, root, config): - num_samples = {"train": 15, "test": 7}[config.split] +@register_mock(configs=combinations_grid(split=("train", "test"))) +def usps(root, config): + num_samples = {"train": 15, "test": 7}[config["split"]] - with bz2.open(root / f"usps{'.t' if not config.split == 'train' else ''}.bz2", "wb") as fh: + with bz2.open(root / f"usps{'.t' if not config['split'] == 'train' else ''}.bz2", "wb") as fh: lines = [] for _ in range(num_samples): label = make_tensor(1, low=1, high=11, dtype=torch.int) diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py index 155fbff5dbb..e1c9940ed86 100644 --- a/torchvision/prototype/datasets/_builtin/usps.py +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -1,22 +1,39 @@ -from typing import Any, Dict, List +import pathlib +from typing import Any, Dict, List, Union import torch from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource, HttpResource +from torchvision.prototype.datasets.utils import Dataset2, OnlineResource, HttpResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import Image, Label +from .._api import register_dataset, register_info -class USPS(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "usps", - homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps", - valid_options=dict( - split=("train", "test"), - ), - categories=10, - ) +NAME = "usps" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=[str(c) for c in range(10)]) + + +@register_dataset(NAME) +class USPS(Dataset2): + """USPS Dataset + homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps", + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "test"}) + + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check) _URL = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass" @@ -29,8 +46,8 @@ def _make_info(self) -> DatasetInfo: ), } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - return [USPS._RESOURCES[config.split]] + def _resources(self) -> List[OnlineResource]: + return [USPS._RESOURCES[self._split]] def _prepare_sample(self, line: str) -> Dict[str, Any]: label, *values = line.strip().split(" ") @@ -38,17 +55,15 @@ def _prepare_sample(self, line: str) -> Dict[str, Any]: pixels = torch.tensor(values).add_(1).div_(2) return dict( image=Image(pixels.reshape(16, 16)), - 
label=Label(int(label) - 1, categories=self.categories), + label=Label(int(label) - 1, categories=self._categories), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = Decompressor(resource_dps[0]) dp = LineReader(dp, decode=True, return_path=False) dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return 7_291 if self._split == "train" else 2_007 From 0b66ed6b22246a76eb971917ce3a9b4908025b4c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 6 Apr 2022 15:58:58 +0200 Subject: [PATCH 24/30] migrate SBD prototype dataset (#5772) * migrate SBD prototype dataset * reuse categories --- test/builtin_dataset_mocks.py | 6 +- .../prototype/datasets/_builtin/sbd.py | 78 ++++++++++++------- 2 files changed, 51 insertions(+), 33 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 316894589c8..b1265440b0e 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -695,9 +695,9 @@ def generate(cls, root): return num_samples_map -# @register_mock -def sbd(info, root, config): - return SBDMockData.generate(root)[config.split] +@register_mock(configs=combinations_grid(split=("train", "val", "train_noval"))) +def sbd(root, config): + return SBDMockData.generate(root)[config["split"]] @register_mock(configs=[dict()]) diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index bcacaea2d24..d062d78fe0a 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -1,6 +1,6 @@ import pathlib import re -from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO +from typing import Any, Dict, List, Optional, Tuple, cast, BinaryIO, Union import numpy as np from torchdata.datapipes.iter import ( @@ -11,13 +11,7 @@ IterKeyZipper, LineReader, ) -from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, - DatasetInfo, - HttpResource, - OnlineResource, -) +from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, read_mat, @@ -26,22 +20,44 @@ path_comparator, hint_sharding, hint_shuffling, + BUILTIN_DIR, ) from torchvision.prototype.features import _Feature, EncodedImage +from .._api import register_dataset, register_info + +NAME = "sbd" + +CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) -class SBD(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "sbd", - dependencies=("scipy",), - homepage="http://home.bharathh.info/pubs/codes/SBD/download.html", - valid_options=dict( - split=("train", "val", "train_noval"), - ), - ) - def resources(self, config: DatasetConfig) -> List[OnlineResource]: +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=CATEGORIES) + + +@register_dataset(NAME) +class SBD(Dataset2): + """ + - **homepage**: http://home.bharathh.info/pubs/codes/SBD/download.html + - **dependencies**: + - _ + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", ("train", "val", "train_noval")) + + self._categories = 
CATEGORIES + + super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check) + + def _resources(self) -> List[OnlineResource]: archive = HttpResource( "https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/semantic_contours/benchmark.tgz", sha256="6a5a2918d5c73ce032fdeba876574d150d9d04113ab87540a1304cbcc715be53", @@ -85,12 +101,7 @@ def _prepare_sample(self, data: Tuple[Tuple[Any, Tuple[str, BinaryIO]], Tuple[st segmentation=_Feature(anns["Segmentation"].item()), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: archive_dp, extra_split_dp = resource_dps archive_dp = resource_dps[0] @@ -101,10 +112,10 @@ def _make_datapipe( buffer_size=INFINITE_BUFFER_SIZE, drop_none=True, ) - if config.split == "train_noval": + if self._split == "train_noval": split_dp = extra_split_dp - split_dp = Filter(split_dp, path_comparator("name", f"{config.split}.txt")) + split_dp = Filter(split_dp, path_comparator("name", f"{self._split}.txt")) split_dp = LineReader(split_dp, decode=True) split_dp = hint_shuffling(split_dp) split_dp = hint_sharding(split_dp) @@ -120,10 +131,17 @@ def _make_datapipe( ) return Mapper(dp, self._prepare_sample) - def _generate_categories(self, root: pathlib.Path) -> Tuple[str, ...]: - resources = self._resources(self.default_config) + def __len__(self) -> int: + return { + "train": 8_498, + "val": 2_857, + "train_noval": 5_623, + }[self._split] + + def _generate_categories(self) -> Tuple[str, ...]: + resources = self._resources() - dp = resources[0].load(root) + dp = resources[0].load(self._root) dp = Filter(dp, path_comparator("name", "category_names.m")) dp = LineReader(dp) dp = Mapper(dp, bytes.decode, input_col=1) From b3c8384d49ab4286f2816b3409010f5631ccc6a6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 6 Apr 2022 15:02:48 +0100 Subject: [PATCH 25/30] Migrate SVHN prototype dataset (#5769) --- test/builtin_dataset_mocks.py | 8 +-- .../prototype/datasets/_builtin/svhn.py | 64 ++++++++++++------- 2 files changed, 45 insertions(+), 27 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index b1265440b0e..3a1aac71e4f 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -1426,18 +1426,18 @@ def eurosat(root, config): return len(categories) * num_examples_per_class -# @register_mock -def svhn(info, root, config): +@register_mock(configs=combinations_grid(split=("train", "test", "extra"))) +def svhn(root, config): import scipy.io as sio num_samples = { "train": 2, "test": 3, "extra": 4, - }[config.split] + }[config["split"]] sio.savemat( - root / f"{config.split}_32x32.mat", + root / f"{config['split']}_32x32.mat", { "X": np.random.randint(256, size=(32, 32, 3, num_samples), dtype=np.uint8), "y": np.random.randint(10, size=(num_samples,), dtype=np.uint8), diff --git a/torchvision/prototype/datasets/_builtin/svhn.py b/torchvision/prototype/datasets/_builtin/svhn.py index 70daece4f86..80c769f6377 100644 --- a/torchvision/prototype/datasets/_builtin/svhn.py +++ b/torchvision/prototype/datasets/_builtin/svhn.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, List, Tuple, BinaryIO +import pathlib +from typing import Any, Dict, List, Tuple, BinaryIO, Union import numpy as np from torchdata.datapipes.iter import ( @@ -7,9 +8,7 @@ UnBatcher, ) from torchvision.prototype.datasets.utils import ( - Dataset, - 
DatasetConfig, - DatasetInfo, + Dataset2, HttpResource, OnlineResource, ) @@ -20,16 +19,33 @@ ) from torchvision.prototype.features import Label, Image +from .._api import register_dataset, register_info -class SVHN(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "svhn", - dependencies=("scipy",), - categories=10, - homepage="http://ufldl.stanford.edu/housenumbers/", - valid_options=dict(split=("train", "test", "extra")), - ) +NAME = "svhn" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict(categories=[str(c) for c in range(10)]) + + +@register_dataset(NAME) +class SVHN(Dataset2): + """SVHN Dataset. + homepage="http://ufldl.stanford.edu/housenumbers/", + dependencies = scipy + """ + + def __init__( + self, + root: Union[str, pathlib.Path], + *, + split: str = "train", + skip_integrity_check: bool = False, + ) -> None: + self._split = self._verify_str_arg(split, "split", {"train", "test", "extra"}) + self._categories = _info()["categories"] + super().__init__(root, skip_integrity_check=skip_integrity_check, dependencies=("scipy",)) _CHECKSUMS = { "train": "435e94d69a87fde4fd4d7f3dd208dfc32cb6ae8af2240d066de1df7508d083b8", @@ -37,10 +53,10 @@ def _make_info(self) -> DatasetInfo: "extra": "a133a4beb38a00fcdda90c9489e0c04f900b660ce8a316a5e854838379a71eb3", } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: data = HttpResource( - f"http://ufldl.stanford.edu/housenumbers/{config.split}_32x32.mat", - sha256=self._CHECKSUMS[config.split], + f"http://ufldl.stanford.edu/housenumbers/{self._split}_32x32.mat", + sha256=self._CHECKSUMS[self._split], ) return [data] @@ -60,18 +76,20 @@ def _prepare_sample(self, data: Tuple[np.ndarray, np.ndarray]) -> Dict[str, Any] return dict( image=Image(image_array.transpose((2, 0, 1))), - label=Label(int(label_array) % 10, categories=self.categories), + label=Label(int(label_array) % 10, categories=self._categories), ) - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: dp = resource_dps[0] dp = Mapper(dp, self._read_images_and_labels) dp = UnBatcher(dp) dp = hint_shuffling(dp) dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) + + def __len__(self) -> int: + return { + "train": 73_257, + "test": 26_032, + "extra": 531_131, + }[self._split] From 11991449966b9df5eb40587a1454e92a67d7eeac Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 6 Apr 2022 16:05:20 +0200 Subject: [PATCH 26/30] add test to enforce __len__ is working on prototype datasets (#5742) --- test/test_prototype_builtin_datasets.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index d9ad4885b57..1f7ebf34826 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -175,6 +175,13 @@ def test_infinite_buffer_size(self, test_home, dataset_mock, config): # resolved assert dp.buffer_size == INFINITE_BUFFER_SIZE + @parametrize_dataset_mocks(DATASET_MOCKS) + def test_has_length(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) + dataset = datasets.load(dataset_mock.name, **config) + + assert len(dataset) > 0 + @parametrize_dataset_mocks(DATASET_MOCKS["qmnist"]) class TestQMNIST: From 8e7987a36d7ec8c7c0e3247a499fd6b3d50b0a52 Mon Sep 17 00:00:00 2001 
From: Philip Meier Date: Wed, 6 Apr 2022 16:09:19 +0200 Subject: [PATCH 27/30] reactivate special dataset tests --- test/test_prototype_builtin_datasets.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 1f7ebf34826..552b4f0a74b 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -203,8 +203,7 @@ def test_extra_label(self, test_home, dataset_mock, config): assert key in sample and isinstance(sample[key], type) -# FIXME: DATASET_MOCKS["gtsrb"] -@parametrize_dataset_mocks({}) +@parametrize_dataset_mocks(DATASET_MOCKS["gtsrb"]) class TestGTSRB: def test_label_matches_path(self, test_home, dataset_mock, config): # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path. @@ -221,8 +220,7 @@ def test_label_matches_path(self, test_home, dataset_mock, config): assert sample["label"] == label_from_path -# FIXME: DATASET_MOCKS["usps"] -@parametrize_dataset_mocks({}) +@parametrize_dataset_mocks(DATASET_MOCKS["usps"]) class TestUSPS: def test_sample_content(self, test_home, dataset_mock, config): dataset_mock.prepare(test_home, config) From 5062a32e1ec5c3281f30d022c0e69a9ceb0eac4d Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 6 Apr 2022 16:12:42 +0200 Subject: [PATCH 28/30] add missing annotation --- torchvision/prototype/datasets/_builtin/pcam.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py index 1ae94da5665..e375f6ab3c5 100644 --- a/torchvision/prototype/datasets/_builtin/pcam.py +++ b/torchvision/prototype/datasets/_builtin/pcam.py @@ -2,7 +2,6 @@ import pathlib from collections import namedtuple from typing import Any, Dict, List, Optional, Tuple, Iterator, Union -from unicodedata import category from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper from torchvision.prototype import features @@ -130,5 +129,5 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, dp = hint_sharding(dp) return Mapper(dp, self._prepare_sample) - def __len__(self): + def __len__(self) -> int: return 262_144 if self._split == "train" else 32_768 From 3be12c78105260a4949e35ccabbd538eaf9aa2c0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 7 Apr 2022 09:03:34 +0100 Subject: [PATCH 29/30] Cleanup prototype dataset implementation (#5774) * Remove Dataset2 class * Move read_categories_file out of DatasetInfo * Remove FrozenBunch and FrozenMapping * Remove test_prototype_datasets_api.py and move missing dep test somewhere else * ufmt * Let read_categories_file accept names instead of paths * Mypy * flake8 * fix category file reading Co-authored-by: Philip Meier --- test/builtin_dataset_mocks.py | 2 +- test/test_prototype_builtin_datasets.py | 2 +- test/test_prototype_datasets_api.py | 231 ------------------ test/test_prototype_datasets_utils.py | 20 +- torchvision/prototype/datasets/_api.py | 6 +- .../prototype/datasets/_builtin/caltech.py | 18 +- .../prototype/datasets/_builtin/celeba.py | 4 +- .../prototype/datasets/_builtin/cifar.py | 23 +- .../prototype/datasets/_builtin/clevr.py | 4 +- .../prototype/datasets/_builtin/coco.py | 9 +- .../prototype/datasets/_builtin/country211.py | 15 +- .../prototype/datasets/_builtin/cub200.py | 11 +- .../prototype/datasets/_builtin/dtd.py | 11 +- .../prototype/datasets/_builtin/eurosat.py | 4 +- 
.../prototype/datasets/_builtin/fer2013.py | 6 +- .../prototype/datasets/_builtin/food101.py | 10 +- .../prototype/datasets/_builtin/gtsrb.py | 4 +- .../prototype/datasets/_builtin/imagenet.py | 9 +- .../prototype/datasets/_builtin/mnist.py | 4 +- .../datasets/_builtin/oxford_iiit_pet.py | 11 +- .../prototype/datasets/_builtin/pcam.py | 4 +- .../prototype/datasets/_builtin/sbd.py | 12 +- .../prototype/datasets/_builtin/semeion.py | 5 +- .../datasets/_builtin/stanford_cars.py | 10 +- .../prototype/datasets/_builtin/svhn.py | 4 +- .../prototype/datasets/_builtin/usps.py | 4 +- .../prototype/datasets/_builtin/voc.py | 10 +- .../prototype/datasets/utils/__init__.py | 2 +- .../prototype/datasets/utils/_dataset.py | 180 +------------- .../prototype/datasets/utils/_internal.py | 10 + torchvision/prototype/utils/_internal.py | 83 ------- 31 files changed, 121 insertions(+), 607 deletions(-) delete mode 100644 test/test_prototype_datasets_api.py diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 3a1aac71e4f..768d286e890 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -68,7 +68,7 @@ def prepare(self, home, config): mock_info = self._parse_mock_info(self.mock_data_fn(root, config)) - with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"): + with unittest.mock.patch.object(datasets.utils.Dataset, "__init__"): required_file_names = { resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources() } diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 552b4f0a74b..fc2ebd9aa38 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -59,7 +59,7 @@ def test_smoke(self, test_home, dataset_mock, config): dataset = datasets.load(dataset_mock.name, **config) - if not isinstance(dataset, datasets.utils.Dataset2): + if not isinstance(dataset, datasets.utils.Dataset): raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.") @parametrize_dataset_mocks(DATASET_MOCKS) diff --git a/test/test_prototype_datasets_api.py b/test/test_prototype_datasets_api.py deleted file mode 100644 index 70a2707d050..00000000000 --- a/test/test_prototype_datasets_api.py +++ /dev/null @@ -1,231 +0,0 @@ -import unittest.mock - -import pytest -from torchvision.prototype import datasets -from torchvision.prototype.utils._internal import FrozenMapping, FrozenBunch - - -def make_minimal_dataset_info(name="name", categories=None, **kwargs): - return datasets.utils.DatasetInfo(name, categories=categories or [], **kwargs) - - -class TestFrozenMapping: - @pytest.mark.parametrize( - ("args", "kwargs"), - [ - pytest.param((dict(foo="bar", baz=1),), dict(), id="from_dict"), - pytest.param((), dict(foo="bar", baz=1), id="from_kwargs"), - pytest.param((dict(foo="bar"),), dict(baz=1), id="mixed"), - ], - ) - def test_instantiation(self, args, kwargs): - FrozenMapping(*args, **kwargs) - - def test_unhashable_items(self): - with pytest.raises(TypeError, match="unhashable type"): - FrozenMapping(foo=[]) - - def test_getitem(self): - options = dict(foo="bar", baz=1) - config = FrozenMapping(options) - - for key, value in options.items(): - assert config[key] == value - - def test_getitem_unknown(self): - with pytest.raises(KeyError): - FrozenMapping()["unknown"] - - def test_iter(self): - options = dict(foo="bar", baz=1) - assert set(iter(FrozenMapping(options))) == set(options.keys()) - - def test_len(self): - 
options = dict(foo="bar", baz=1) - assert len(FrozenMapping(options)) == len(options) - - def test_immutable_setitem(self): - frozen_mapping = FrozenMapping() - - with pytest.raises(RuntimeError, match="immutable"): - frozen_mapping["foo"] = "bar" - - def test_immutable_delitem( - self, - ): - frozen_mapping = FrozenMapping(foo="bar") - - with pytest.raises(RuntimeError, match="immutable"): - del frozen_mapping["foo"] - - def test_eq(self): - options = dict(foo="bar", baz=1) - assert FrozenMapping(options) == FrozenMapping(options) - - def test_ne(self): - options1 = dict(foo="bar", baz=1) - options2 = options1.copy() - options2["baz"] += 1 - - assert FrozenMapping(options1) != FrozenMapping(options2) - - def test_repr(self): - options = dict(foo="bar", baz=1) - output = repr(FrozenMapping(options)) - - assert isinstance(output, str) - for key, value in options.items(): - assert str(key) in output and str(value) in output - - -class TestFrozenBunch: - def test_getattr(self): - options = dict(foo="bar", baz=1) - config = FrozenBunch(options) - - for key, value in options.items(): - assert getattr(config, key) == value - - def test_getattr_unknown(self): - with pytest.raises(AttributeError, match="no attribute 'unknown'"): - datasets.utils.DatasetConfig().unknown - - def test_immutable_setattr(self): - frozen_bunch = FrozenBunch() - - with pytest.raises(RuntimeError, match="immutable"): - frozen_bunch.foo = "bar" - - def test_immutable_delattr( - self, - ): - frozen_bunch = FrozenBunch(foo="bar") - - with pytest.raises(RuntimeError, match="immutable"): - del frozen_bunch.foo - - def test_repr(self): - options = dict(foo="bar", baz=1) - output = repr(FrozenBunch(options)) - - assert isinstance(output, str) - assert output.startswith("FrozenBunch") - for key, value in options.items(): - assert f"{key}={value}" in output - - -class TestDatasetInfo: - @pytest.fixture - def info(self): - return make_minimal_dataset_info(valid_options=dict(split=("train", "test"), foo=("bar", "baz"))) - - def test_default_config(self, info): - valid_options = info._valid_options - default_config = datasets.utils.DatasetConfig({key: values[0] for key, values in valid_options.items()}) - - assert info.default_config == default_config - - @pytest.mark.parametrize( - ("valid_options", "options", "expected_error_msg"), - [ - (dict(), dict(any_option=None), "does not take any options"), - (dict(split="train"), dict(unknown_option=None), "Unknown option 'unknown_option'"), - (dict(split="train"), dict(split="invalid_argument"), "Invalid argument 'invalid_argument'"), - ], - ) - def test_make_config_invalid_inputs(self, info, valid_options, options, expected_error_msg): - info = make_minimal_dataset_info(valid_options=valid_options) - - with pytest.raises(ValueError, match=expected_error_msg): - info.make_config(**options) - - def test_check_dependencies(self): - dependency = "fake_dependency" - info = make_minimal_dataset_info(dependencies=(dependency,)) - with pytest.raises(ModuleNotFoundError, match=dependency): - info.check_dependencies() - - def test_repr(self, info): - output = repr(info) - - assert isinstance(output, str) - assert "DatasetInfo" in output - for key, value in info._valid_options.items(): - assert f"{key}={str(value)[1:-1]}" in output - - @pytest.mark.parametrize("optional_info", ("citation", "homepage", "license")) - def test_repr_optional_info(self, optional_info): - sentinel = "sentinel" - info = make_minimal_dataset_info(**{optional_info: sentinel}) - - assert f"{optional_info}={sentinel}" in 
repr(info) - - -class TestDataset: - class DatasetMock(datasets.utils.Dataset): - def __init__(self, info=None, *, resources=None): - self._info = info or make_minimal_dataset_info(valid_options=dict(split=("train", "test"))) - self.resources = unittest.mock.Mock(return_value=[]) if resources is None else lambda config: resources - self._make_datapipe = unittest.mock.Mock() - super().__init__() - - def _make_info(self): - return self._info - - def resources(self, config): - # This method is just defined to appease the ABC, but will be overwritten at instantiation - pass - - def _make_datapipe(self, resource_dps, *, config): - # This method is just defined to appease the ABC, but will be overwritten at instantiation - pass - - def test_name(self): - name = "sentinel" - dataset = self.DatasetMock(make_minimal_dataset_info(name=name)) - - assert dataset.name == name - - def test_default_config(self): - sentinel = "sentinel" - dataset = self.DatasetMock(info=make_minimal_dataset_info(valid_options=dict(split=(sentinel, "train")))) - - assert dataset.default_config == datasets.utils.DatasetConfig(split=sentinel) - - @pytest.mark.parametrize( - ("config", "kwarg"), - [ - pytest.param(*(datasets.utils.DatasetConfig(split="test"),) * 2, id="specific"), - pytest.param(DatasetMock().default_config, None, id="default"), - ], - ) - def test_load_config(self, config, kwarg): - dataset = self.DatasetMock() - - dataset.load("", config=kwarg) - - dataset.resources.assert_called_with(config) - - _, call_kwargs = dataset._make_datapipe.call_args - assert call_kwargs["config"] == config - - def test_missing_dependencies(self): - dependency = "fake_dependency" - dataset = self.DatasetMock(make_minimal_dataset_info(dependencies=(dependency,))) - with pytest.raises(ModuleNotFoundError, match=dependency): - dataset.load("root") - - def test_resources(self, mocker): - resource_mock = mocker.Mock(spec=["load"]) - sentinel = object() - resource_mock.load.return_value = sentinel - dataset = self.DatasetMock(resources=[resource_mock]) - - root = "root" - dataset.load(root) - - (call_args, _) = resource_mock.load.call_args - assert call_args[0] == root - - (call_args, _) = dataset._make_datapipe.call_args - assert call_args[0][0] is sentinel diff --git a/test/test_prototype_datasets_utils.py b/test/test_prototype_datasets_utils.py index bd857abf02f..b1c95844574 100644 --- a/test/test_prototype_datasets_utils.py +++ b/test/test_prototype_datasets_utils.py @@ -5,7 +5,7 @@ import torch from datasets_utils import make_fake_flo_file from torchvision.datasets._optical_flow import _read_flo as read_flo_ref -from torchvision.prototype.datasets.utils import HttpResource, GDriveResource +from torchvision.prototype.datasets.utils import HttpResource, GDriveResource, Dataset from torchvision.prototype.datasets.utils._internal import read_flo, fromfile @@ -101,3 +101,21 @@ def preprocess_sentinel(path): assert redirected_resource.file_name == file_name assert redirected_resource.sha256 == sha256_sentinel assert redirected_resource._preprocess is preprocess_sentinel + + +def test_missing_dependency_error(): + class DummyDataset(Dataset): + def __init__(self): + super().__init__(root="root", dependencies=("fake_dependency",)) + + def _resources(self): + pass + + def _datapipe(self, resource_dps): + pass + + def __len__(self): + pass + + with pytest.raises(ModuleNotFoundError, match="depends on the third-party package 'fake_dependency'"): + DummyDataset() diff --git a/torchvision/prototype/datasets/_api.py 
b/torchvision/prototype/datasets/_api.py index 8f8bb53deb4..407dc23f64b 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -2,12 +2,12 @@ from typing import Any, Dict, List, Callable, Type, Optional, Union, TypeVar from torchvision.prototype.datasets import home -from torchvision.prototype.datasets.utils import Dataset2 +from torchvision.prototype.datasets.utils import Dataset from torchvision.prototype.utils._internal import add_suggestion T = TypeVar("T") -D = TypeVar("D", bound=Type[Dataset2]) +D = TypeVar("D", bound=Type[Dataset]) BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {} @@ -56,7 +56,7 @@ def info(name: str) -> Dict[str, Any]: return find(BUILTIN_INFOS, name) -def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset2: +def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset: dataset_cls = find(BUILTIN_DATASETS, name) if root is None: diff --git a/torchvision/prototype/datasets/_builtin/caltech.py b/torchvision/prototype/datasets/_builtin/caltech.py index 3701063504f..7010ab9503d 100644 --- a/torchvision/prototype/datasets/_builtin/caltech.py +++ b/torchvision/prototype/datasets/_builtin/caltech.py @@ -9,29 +9,26 @@ Filter, IterKeyZipper, ) -from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, read_mat, hint_sharding, hint_shuffling, - BUILTIN_DIR, + read_categories_file, ) from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage from .._api import register_dataset, register_info -CALTECH101_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech101.categories")) - - @register_info("caltech101") def _caltech101_info() -> Dict[str, Any]: - return dict(categories=CALTECH101_CATEGORIES) + return dict(categories=read_categories_file("caltech101")) @register_dataset("caltech101") -class Caltech101(Dataset2): +class Caltech101(Dataset): """ - **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech101 - **dependencies**: @@ -152,16 +149,13 @@ def _generate_categories(self) -> List[str]: return sorted({pathlib.Path(path).parent.name for path, _ in dp}) -CALTECH256_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "caltech256.categories")) - - @register_info("caltech256") def _caltech256_info() -> Dict[str, Any]: - return dict(categories=CALTECH256_CATEGORIES) + return dict(categories=read_categories_file("caltech256")) @register_dataset("caltech256") -class Caltech256(Dataset2): +class Caltech256(Dataset): """ - **homepage**: http://www.vision.caltech.edu/Image_Datasets/Caltech256 """ diff --git a/torchvision/prototype/datasets/_builtin/celeba.py b/torchvision/prototype/datasets/_builtin/celeba.py index 17a42082f3f..46ccf8de6f7 100644 --- a/torchvision/prototype/datasets/_builtin/celeba.py +++ b/torchvision/prototype/datasets/_builtin/celeba.py @@ -10,7 +10,7 @@ IterKeyZipper, ) from torchvision.prototype.datasets.utils import ( - Dataset2, + Dataset, GDriveResource, OnlineResource, ) @@ -68,7 +68,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class CelebA(Dataset2): +class CelebA(Dataset): """ - **homepage**: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html """ diff --git a/torchvision/prototype/datasets/_builtin/cifar.py 
b/torchvision/prototype/datasets/_builtin/cifar.py index 9274aa543d4..514938d6e5f 100644 --- a/torchvision/prototype/datasets/_builtin/cifar.py +++ b/torchvision/prototype/datasets/_builtin/cifar.py @@ -10,8 +10,13 @@ Filter, Mapper, ) -from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import hint_shuffling, path_comparator, hint_sharding, BUILTIN_DIR +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import ( + hint_shuffling, + path_comparator, + hint_sharding, + read_categories_file, +) from torchvision.prototype.features import Label, Image from .._api import register_dataset, register_info @@ -29,13 +34,13 @@ def __iter__(self) -> Iterator[Tuple[np.ndarray, int]]: yield from iter(zip(image_arrays, category_idcs)) -class _CifarBase(Dataset2): +class _CifarBase(Dataset): _FILE_NAME: str _SHA256: str _LABELS_KEY: str _META_FILE_NAME: str _CATEGORIES_KEY: str - # _categories: List[str] + _categories: List[str] def __init__( self, @@ -92,12 +97,9 @@ def _generate_categories(self) -> List[str]: return cast(List[str], next(iter(dp))[self._CATEGORIES_KEY]) -CIFAR10_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "cifar10.categories")) - - @register_info("cifar10") def _cifar10_info() -> Dict[str, Any]: - return dict(categories=CIFAR10_CATEGORIES) + return dict(categories=read_categories_file("cifar10")) @register_dataset("cifar10") @@ -118,12 +120,9 @@ def _is_data_file(self, data: Tuple[str, Any]) -> bool: return path.name.startswith("data" if self._split == "train" else "test") -CIFAR100_CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / "cifar100.categories")) - - @register_info("cifar100") def _cifar100_info() -> Dict[str, Any]: - return dict(categories=CIFAR10_CATEGORIES) + return dict(categories=read_categories_file("cifar100")) @register_dataset("cifar100") diff --git a/torchvision/prototype/datasets/_builtin/clevr.py b/torchvision/prototype/datasets/_builtin/clevr.py index 9d322de084c..3a139787c6f 100644 --- a/torchvision/prototype/datasets/_builtin/clevr.py +++ b/torchvision/prototype/datasets/_builtin/clevr.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, JsonParser, UnBatcher -from torchvision.prototype.datasets.utils import Dataset2, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, hint_sharding, @@ -24,7 +24,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class CLEVR(Dataset2): +class CLEVR(Dataset): """ - **homepage**: https://cs.stanford.edu/people/jcjohns/clevr/ """ diff --git a/torchvision/prototype/datasets/_builtin/coco.py b/torchvision/prototype/datasets/_builtin/coco.py index 75896a8db08..ff3b5f37c96 100644 --- a/torchvision/prototype/datasets/_builtin/coco.py +++ b/torchvision/prototype/datasets/_builtin/coco.py @@ -16,16 +16,15 @@ UnBatcher, ) from torchvision.prototype.datasets.utils import ( - DatasetInfo, HttpResource, OnlineResource, - Dataset2, + Dataset, ) from torchvision.prototype.datasets.utils._internal import ( MappingIterator, INFINITE_BUFFER_SIZE, - BUILTIN_DIR, getitem, + read_categories_file, path_accessor, hint_sharding, 
hint_shuffling, @@ -40,12 +39,12 @@ @register_info(NAME) def _info() -> Dict[str, Any]: - categories, super_categories = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) + categories, super_categories = zip(*read_categories_file(NAME)) return dict(categories=categories, super_categories=super_categories) @register_dataset(NAME) -class Coco(Dataset2): +class Coco(Dataset): """ - **homepage**: https://cocodataset.org/ - **dependencies**: diff --git a/torchvision/prototype/datasets/_builtin/country211.py b/torchvision/prototype/datasets/_builtin/country211.py index 461cd71568f..012ecae19e2 100644 --- a/torchvision/prototype/datasets/_builtin/country211.py +++ b/torchvision/prototype/datasets/_builtin/country211.py @@ -2,24 +2,27 @@ from typing import Any, Dict, List, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter -from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource -from torchvision.prototype.datasets.utils._internal import path_comparator, hint_sharding, hint_shuffling, BUILTIN_DIR +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils._internal import ( + path_comparator, + hint_sharding, + hint_shuffling, + read_categories_file, +) from torchvision.prototype.features import EncodedImage, Label from .._api import register_dataset, register_info NAME = "country211" -CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) - @register_info(NAME) def _info() -> Dict[str, Any]: - return dict(categories=CATEGORIES) + return dict(categories=read_categories_file(NAME)) @register_dataset(NAME) -class Country211(Dataset2): +class Country211(Dataset): """ - **homepage**: https://github.com/openai/CLIP/blob/main/data/country211.md """ diff --git a/torchvision/prototype/datasets/_builtin/cub200.py b/torchvision/prototype/datasets/_builtin/cub200.py index 073a790092c..1e4db7cef73 100644 --- a/torchvision/prototype/datasets/_builtin/cub200.py +++ b/torchvision/prototype/datasets/_builtin/cub200.py @@ -14,8 +14,7 @@ CSVDictParser, ) from torchvision.prototype.datasets.utils import ( - Dataset2, - DatasetInfo, + Dataset, HttpResource, OnlineResource, ) @@ -26,8 +25,8 @@ hint_shuffling, getitem, path_comparator, + read_categories_file, path_accessor, - BUILTIN_DIR, ) from torchvision.prototype.features import Label, BoundingBox, _Feature, EncodedImage @@ -38,16 +37,14 @@ NAME = "cub200" -CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) - @register_info(NAME) def _info() -> Dict[str, Any]: - return dict(categories=CATEGORIES) + return dict(categories=read_categories_file(NAME)) @register_dataset(NAME) -class CUB200(Dataset2): +class CUB200(Dataset): """ - **homepage**: http://www.vision.caltech.edu/visipedia/CUB-200.html """ diff --git a/torchvision/prototype/datasets/_builtin/dtd.py b/torchvision/prototype/datasets/_builtin/dtd.py index d7f07dc8b30..b082ada19ce 100644 --- a/torchvision/prototype/datasets/_builtin/dtd.py +++ b/torchvision/prototype/datasets/_builtin/dtd.py @@ -4,8 +4,7 @@ from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, LineReader, CSVParser from torchvision.prototype.datasets.utils import ( - Dataset2, - DatasetInfo, + Dataset, HttpResource, OnlineResource, ) @@ -13,8 +12,8 @@ INFINITE_BUFFER_SIZE, hint_sharding, path_comparator, - BUILTIN_DIR, getitem, + read_categories_file, 
hint_shuffling, ) from torchvision.prototype.features import Label, EncodedImage @@ -33,13 +32,11 @@ class DTDDemux(enum.IntEnum): @register_info(NAME) def _info() -> Dict[str, Any]: - categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories") - categories = [c[0] for c in categories] - return dict(categories=categories) + return dict(categories=read_categories_file(NAME)) @register_dataset(NAME) -class DTD(Dataset2): +class DTD(Dataset): """DTD Dataset. homepage="https://www.robots.ox.ac.uk/~vgg/data/dtd/", """ diff --git a/torchvision/prototype/datasets/_builtin/eurosat.py b/torchvision/prototype/datasets/_builtin/eurosat.py index 00d6a04f320..ab31aaf6f42 100644 --- a/torchvision/prototype/datasets/_builtin/eurosat.py +++ b/torchvision/prototype/datasets/_builtin/eurosat.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Tuple, Union from torchdata.datapipes.iter import IterDataPipe, Mapper -from torchvision.prototype.datasets.utils import Dataset2, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import EncodedImage, Label @@ -29,7 +29,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class EuroSAT(Dataset2): +class EuroSAT(Dataset): """EuroSAT Dataset. homepage="https://github.com/phelber/eurosat", """ diff --git a/torchvision/prototype/datasets/_builtin/fer2013.py b/torchvision/prototype/datasets/_builtin/fer2013.py index ca30b78e609..c1a914c6f63 100644 --- a/torchvision/prototype/datasets/_builtin/fer2013.py +++ b/torchvision/prototype/datasets/_builtin/fer2013.py @@ -1,10 +1,10 @@ import pathlib -from typing import Any, Dict, List, cast, Union +from typing import Any, Dict, List, Union import torch from torchdata.datapipes.iter import IterDataPipe, Mapper, CSVDictParser from torchvision.prototype.datasets.utils import ( - Dataset2, + Dataset, OnlineResource, KaggleDownloadResource, ) @@ -25,7 +25,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class FER2013(Dataset2): +class FER2013(Dataset): """FER 2013 Dataset homepage="https://www.kaggle.com/c/challenges-in-representation-learning-facial-expression-recognition-challenge" """ diff --git a/torchvision/prototype/datasets/_builtin/food101.py b/torchvision/prototype/datasets/_builtin/food101.py index c86b9aaea84..5100e5d8c74 100644 --- a/torchvision/prototype/datasets/_builtin/food101.py +++ b/torchvision/prototype/datasets/_builtin/food101.py @@ -9,14 +9,14 @@ Demultiplexer, IterKeyZipper, ) -from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_shuffling, - BUILTIN_DIR, hint_sharding, path_comparator, getitem, INFINITE_BUFFER_SIZE, + read_categories_file, ) from torchvision.prototype.features import Label, EncodedImage @@ -28,13 +28,11 @@ @register_info(NAME) def _info() -> Dict[str, Any]: - categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories") - categories = [c[0] for c in categories] - return dict(categories=categories) + return dict(categories=read_categories_file(NAME)) @register_dataset(NAME) -class Food101(Dataset2): +class Food101(Dataset): """Food 101 dataset homepage="https://data.vision.ee.ethz.ch/cvl/datasets_extra/food-101", """ diff --git 
a/torchvision/prototype/datasets/_builtin/gtsrb.py b/torchvision/prototype/datasets/_builtin/gtsrb.py index fa29f3be780..01f754208e2 100644 --- a/torchvision/prototype/datasets/_builtin/gtsrb.py +++ b/torchvision/prototype/datasets/_builtin/gtsrb.py @@ -3,7 +3,7 @@ from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, CSVDictParser, Zipper, Demultiplexer from torchvision.prototype.datasets.utils import ( - Dataset2, + Dataset, OnlineResource, HttpResource, ) @@ -28,7 +28,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class GTSRB(Dataset2): +class GTSRB(Dataset): """GTSRB Dataset homepage="https://benchmark.ini.rub.de" diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 56accca02b4..1307757cef6 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -15,18 +15,17 @@ Enumerator, ) from torchvision.prototype.datasets.utils import ( - DatasetInfo, OnlineResource, ManualDownloadResource, - Dataset2, + Dataset, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, - BUILTIN_DIR, getitem, read_mat, hint_sharding, hint_shuffling, + read_categories_file, path_accessor, ) from torchvision.prototype.features import Label, EncodedImage @@ -38,7 +37,7 @@ @register_info(NAME) def _info() -> Dict[str, Any]: - categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) + categories, wnids = zip(*read_categories_file(NAME)) return dict(categories=categories, wnids=wnids) @@ -53,7 +52,7 @@ class ImageNetDemux(enum.IntEnum): @register_dataset(NAME) -class ImageNet(Dataset2): +class ImageNet(Dataset): """ - **homepage**: https://www.image-net.org/ """ diff --git a/torchvision/prototype/datasets/_builtin/mnist.py b/torchvision/prototype/datasets/_builtin/mnist.py index 907faed49bd..e5537a1ef66 100644 --- a/torchvision/prototype/datasets/_builtin/mnist.py +++ b/torchvision/prototype/datasets/_builtin/mnist.py @@ -7,7 +7,7 @@ import torch from torchdata.datapipes.iter import IterDataPipe, Demultiplexer, Mapper, Zipper, Decompressor -from torchvision.prototype.datasets.utils import Dataset2, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling from torchvision.prototype.features import Image, Label from torchvision.prototype.utils._internal import fromfile @@ -58,7 +58,7 @@ def __iter__(self) -> Iterator[torch.Tensor]: yield read(dtype=dtype, count=count).reshape(shape) -class _MNISTBase(Dataset2): +class _MNISTBase(Dataset): _URL_BASE: Union[str, Sequence[str]] @abc.abstractmethod diff --git a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py index 0ea336a1421..f7da02a4765 100644 --- a/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py +++ b/torchvision/prototype/datasets/_builtin/oxford_iiit_pet.py @@ -4,8 +4,7 @@ from torchdata.datapipes.iter import IterDataPipe, Mapper, Filter, IterKeyZipper, Demultiplexer, CSVDictParser from torchvision.prototype.datasets.utils import ( - Dataset2, - DatasetInfo, + Dataset, HttpResource, OnlineResource, ) @@ -13,9 +12,9 @@ INFINITE_BUFFER_SIZE, hint_sharding, hint_shuffling, - BUILTIN_DIR, getitem, path_accessor, + read_categories_file, path_comparator, ) from torchvision.prototype.features 
import Label, EncodedImage @@ -33,13 +32,11 @@ class OxfordIIITPetDemux(enum.IntEnum): @register_info(NAME) def _info() -> Dict[str, Any]: - categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories") - categories = [c[0] for c in categories] - return dict(categories=categories) + return dict(categories=read_categories_file(NAME)) @register_dataset(NAME) -class OxfordIIITPet(Dataset2): +class OxfordIIITPet(Dataset): """Oxford IIIT Pet Dataset homepage="https://www.robots.ox.ac.uk/~vgg/data/pets/", """ diff --git a/torchvision/prototype/datasets/_builtin/pcam.py b/torchvision/prototype/datasets/_builtin/pcam.py index e375f6ab3c5..7cd31469139 100644 --- a/torchvision/prototype/datasets/_builtin/pcam.py +++ b/torchvision/prototype/datasets/_builtin/pcam.py @@ -6,7 +6,7 @@ from torchdata.datapipes.iter import IterDataPipe, Mapper, Zipper from torchvision.prototype import features from torchvision.prototype.datasets.utils import ( - Dataset2, + Dataset, OnlineResource, GDriveResource, ) @@ -50,7 +50,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class PCAM(Dataset2): +class PCAM(Dataset): # TODO write proper docstring """PCAM Dataset diff --git a/torchvision/prototype/datasets/_builtin/sbd.py b/torchvision/prototype/datasets/_builtin/sbd.py index d062d78fe0a..0c806fe098c 100644 --- a/torchvision/prototype/datasets/_builtin/sbd.py +++ b/torchvision/prototype/datasets/_builtin/sbd.py @@ -11,7 +11,7 @@ IterKeyZipper, LineReader, ) -from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, read_mat, @@ -20,7 +20,7 @@ path_comparator, hint_sharding, hint_shuffling, - BUILTIN_DIR, + read_categories_file, ) from torchvision.prototype.features import _Feature, EncodedImage @@ -28,16 +28,14 @@ NAME = "sbd" -CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) - @register_info(NAME) def _info() -> Dict[str, Any]: - return dict(categories=CATEGORIES) + return dict(categories=read_categories_file(NAME)) @register_dataset(NAME) -class SBD(Dataset2): +class SBD(Dataset): """ - **homepage**: http://home.bharathh.info/pubs/codes/SBD/download.html - **dependencies**: @@ -53,7 +51,7 @@ def __init__( ) -> None: self._split = self._verify_str_arg(split, "split", ("train", "val", "train_noval")) - self._categories = CATEGORIES + self._categories = _info()["categories"] super().__init__(root, dependencies=("scipy",), skip_integrity_check=skip_integrity_check) diff --git a/torchvision/prototype/datasets/_builtin/semeion.py b/torchvision/prototype/datasets/_builtin/semeion.py index e3a802d3cee..5051bde4047 100644 --- a/torchvision/prototype/datasets/_builtin/semeion.py +++ b/torchvision/prototype/datasets/_builtin/semeion.py @@ -2,14 +2,13 @@ from typing import Any, Dict, List, Tuple, Union import torch -from pytest import skip from torchdata.datapipes.iter import ( IterDataPipe, Mapper, CSVParser, ) from torchvision.prototype.datasets.utils import ( - Dataset2, + Dataset, HttpResource, OnlineResource, ) @@ -27,7 +26,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class SEMEION(Dataset2): +class SEMEION(Dataset): """Semeion dataset homepage="https://archive.ics.uci.edu/ml/datasets/Semeion+Handwritten+Digit", """ diff --git a/torchvision/prototype/datasets/_builtin/stanford_cars.py 
b/torchvision/prototype/datasets/_builtin/stanford_cars.py index 85098eb34e5..465d753c2e5 100644 --- a/torchvision/prototype/datasets/_builtin/stanford_cars.py +++ b/torchvision/prototype/datasets/_builtin/stanford_cars.py @@ -2,13 +2,13 @@ from typing import Any, Dict, List, Tuple, Iterator, BinaryIO, Union from torchdata.datapipes.iter import Filter, IterDataPipe, Mapper, Zipper -from torchvision.prototype.datasets.utils import Dataset2, DatasetInfo, HttpResource, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, HttpResource, OnlineResource from torchvision.prototype.datasets.utils._internal import ( hint_sharding, hint_shuffling, path_comparator, read_mat, - BUILTIN_DIR, + read_categories_file, ) from torchvision.prototype.features import BoundingBox, EncodedImage, Label @@ -31,13 +31,11 @@ def __iter__(self) -> Iterator[Tuple[int, int, int, int, int, str]]: @register_info(NAME) def _info() -> Dict[str, Any]: - categories = DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories") - categories = [c[0] for c in categories] - return dict(categories=categories) + return dict(categories=read_categories_file(NAME)) @register_dataset(NAME) -class StanfordCars(Dataset2): +class StanfordCars(Dataset): """Stanford Cars dataset. homepage="https://ai.stanford.edu/~jkrause/cars/car_dataset.html", dependencies=scipy diff --git a/torchvision/prototype/datasets/_builtin/svhn.py b/torchvision/prototype/datasets/_builtin/svhn.py index 80c769f6377..175aa6c0a51 100644 --- a/torchvision/prototype/datasets/_builtin/svhn.py +++ b/torchvision/prototype/datasets/_builtin/svhn.py @@ -8,7 +8,7 @@ UnBatcher, ) from torchvision.prototype.datasets.utils import ( - Dataset2, + Dataset, HttpResource, OnlineResource, ) @@ -30,7 +30,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class SVHN(Dataset2): +class SVHN(Dataset): """SVHN Dataset. 
homepage="http://ufldl.stanford.edu/housenumbers/", dependencies = scipy diff --git a/torchvision/prototype/datasets/_builtin/usps.py b/torchvision/prototype/datasets/_builtin/usps.py index e1c9940ed86..e732f3b788a 100644 --- a/torchvision/prototype/datasets/_builtin/usps.py +++ b/torchvision/prototype/datasets/_builtin/usps.py @@ -3,7 +3,7 @@ import torch from torchdata.datapipes.iter import IterDataPipe, LineReader, Mapper, Decompressor -from torchvision.prototype.datasets.utils import Dataset2, OnlineResource, HttpResource +from torchvision.prototype.datasets.utils import Dataset, OnlineResource, HttpResource from torchvision.prototype.datasets.utils._internal import hint_sharding, hint_shuffling from torchvision.prototype.features import Image, Label @@ -18,7 +18,7 @@ def _info() -> Dict[str, Any]: @register_dataset(NAME) -class USPS(Dataset2): +class USPS(Dataset): """USPS Dataset homepage="https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#usps", """ diff --git a/torchvision/prototype/datasets/_builtin/voc.py b/torchvision/prototype/datasets/_builtin/voc.py index 1f5980bdc72..05a3c2e8622 100644 --- a/torchvision/prototype/datasets/_builtin/voc.py +++ b/torchvision/prototype/datasets/_builtin/voc.py @@ -13,7 +13,7 @@ LineReader, ) from torchvision.datasets import VOCDetection -from torchvision.prototype.datasets.utils import DatasetInfo, OnlineResource, HttpResource, Dataset2 +from torchvision.prototype.datasets.utils import OnlineResource, HttpResource, Dataset from torchvision.prototype.datasets.utils._internal import ( path_accessor, getitem, @@ -21,7 +21,7 @@ path_comparator, hint_sharding, hint_shuffling, - BUILTIN_DIR, + read_categories_file, ) from torchvision.prototype.features import BoundingBox, Label, EncodedImage @@ -29,16 +29,14 @@ NAME = "voc" -CATEGORIES, *_ = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) - @register_info(NAME) def _info() -> Dict[str, Any]: - return dict(categories=CATEGORIES) + return dict(categories=read_categories_file(NAME)) @register_dataset(NAME) -class VOC(Dataset2): +class VOC(Dataset): """ - **homepage**: http://host.robots.ox.ac.uk/pascal/VOC/ """ diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index a16a839b594..e7ef72f28a9 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -1,4 +1,4 @@ from . 
import _internal # usort: skip -from ._dataset import DatasetConfig, DatasetInfo, Dataset, Dataset2 +from ._dataset import Dataset from ._query import SampleQuery from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index a6ec05c3ff4..528d0a0f25f 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -1,191 +1,15 @@ import abc -import csv import importlib -import itertools -import os import pathlib -from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection, Iterator +from typing import Any, Dict, List, Optional, Sequence, Union, Collection, Iterator from torch.utils.data import IterDataPipe -from torchvision._utils import sequence_to_str from torchvision.datasets.utils import verify_str_arg -from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion -from .._home import use_sharded_dataset -from ._internal import BUILTIN_DIR, _make_sharded_datapipe from ._resource import OnlineResource -class DatasetConfig(FrozenBunch): - # This needs to be Frozen because we often pass configs as partial(func, config=config) - # and partial() requires the parameters to be hashable. - pass - - -class DatasetInfo: - def __init__( - self, - name: str, - *, - dependencies: Collection[str] = (), - categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None, - citation: Optional[str] = None, - homepage: Optional[str] = None, - license: Optional[str] = None, - valid_options: Optional[Dict[str, Sequence[Any]]] = None, - extra: Optional[Dict[str, Any]] = None, - ) -> None: - self.name = name.lower() - - self.dependecies = dependencies - - if categories is None: - path = BUILTIN_DIR / f"{self.name}.categories" - categories = path if path.exists() else [] - if isinstance(categories, int): - categories = [str(label) for label in range(categories)] - elif isinstance(categories, (str, pathlib.Path)): - path = pathlib.Path(categories).expanduser().resolve() - categories, *_ = zip(*self.read_categories_file(path)) - self.categories = tuple(categories) - - self.citation = citation - self.homepage = homepage - self.license = license - - self._valid_options = valid_options or dict() - self._configs = tuple( - DatasetConfig(**dict(zip(self._valid_options.keys(), combination))) - for combination in itertools.product(*self._valid_options.values()) - ) - - self.extra = FrozenBunch(extra or dict()) - - @property - def default_config(self) -> DatasetConfig: - return self._configs[0] - - @staticmethod - def read_categories_file(path: pathlib.Path) -> List[List[str]]: - with open(path, newline="") as file: - return [row for row in csv.reader(file)] - - def make_config(self, **options: Any) -> DatasetConfig: - if not self._valid_options and options: - raise ValueError( - f"Dataset {self.name} does not take any options, " - f"but got {sequence_to_str(list(options), separate_last=' and')}." 
- ) - - for name, arg in options.items(): - if name not in self._valid_options: - raise ValueError( - add_suggestion( - f"Unknown option '{name}' of dataset {self.name}.", - word=name, - possibilities=sorted(self._valid_options.keys()), - ) - ) - - valid_args = self._valid_options[name] - - if arg not in valid_args: - raise ValueError( - add_suggestion( - f"Invalid argument '{arg}' for option '{name}' of dataset {self.name}.", - word=arg, - possibilities=valid_args, - ) - ) - - return DatasetConfig(self.default_config, **options) - - def check_dependencies(self) -> None: - for dependency in self.dependecies: - try: - importlib.import_module(dependency) - except ModuleNotFoundError as error: - raise ModuleNotFoundError( - f"Dataset '{self.name}' depends on the third-party package '{dependency}'. " - f"Please install it, for example with `pip install {dependency}`." - ) from error - - def __repr__(self) -> str: - items = [("name", self.name)] - for key in ("citation", "homepage", "license"): - value = getattr(self, key) - if value is not None: - items.append((key, value)) - items.extend(sorted((key, sequence_to_str(value)) for key, value in self._valid_options.items())) - return make_repr(type(self).__name__, items) - - -class Dataset(abc.ABC): - def __init__(self) -> None: - self._info = self._make_info() - - @abc.abstractmethod - def _make_info(self) -> DatasetInfo: - pass - - @property - def info(self) -> DatasetInfo: - return self._info - - @property - def name(self) -> str: - return self.info.name - - @property - def default_config(self) -> DatasetConfig: - return self.info.default_config - - @property - def categories(self) -> Tuple[str, ...]: - return self.info.categories - - @abc.abstractmethod - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - pass - - @abc.abstractmethod - def _make_datapipe( - self, - resource_dps: List[IterDataPipe], - *, - config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: - pass - - def supports_sharded(self) -> bool: - return False - - def load( - self, - root: Union[str, pathlib.Path], - *, - config: Optional[DatasetConfig] = None, - skip_integrity_check: bool = False, - ) -> IterDataPipe[Dict[str, Any]]: - if not config: - config = self.info.default_config - - if use_sharded_dataset() and self.supports_sharded(): - root = os.path.join(root, *config.values()) - dataset_size = self.info.extra["sizes"][config] - return _make_sharded_datapipe(root, dataset_size) # type: ignore[no-any-return] - - self.info.check_dependencies() - resource_dps = [ - resource.load(root, skip_integrity_check=skip_integrity_check) for resource in self.resources(config) - ] - return self._make_datapipe(resource_dps, config=config) - - def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]: - raise NotImplementedError - - -class Dataset2(IterDataPipe[Dict[str, Any]], abc.ABC): +class Dataset(IterDataPipe[Dict[str, Any]], abc.ABC): @staticmethod def _verify_str_arg( value: str, diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py index fa48218fe02..007e91eb657 100644 --- a/torchvision/prototype/datasets/utils/_internal.py +++ b/torchvision/prototype/datasets/utils/_internal.py @@ -1,3 +1,4 @@ +import csv import functools import pathlib import pickle @@ -9,6 +10,7 @@ Any, Tuple, TypeVar, + List, Iterator, Dict, IO, @@ -198,3 +200,11 @@ def hint_sharding(datapipe: IterDataPipe) -> ShardingFilter: def hint_shuffling(datapipe: IterDataPipe[D]) -> Shuffler[D]: return 
Shuffler(datapipe, buffer_size=INFINITE_BUFFER_SIZE).set_shuffle(False) + + +def read_categories_file(name: str) -> List[Union[str, Sequence[str]]]: + path = BUILTIN_DIR / f"{name}.categories" + with open(path, newline="") as file: + rows = list(csv.reader(file)) + rows = [row[0] if len(row) == 1 else row for row in rows] + return rows diff --git a/torchvision/prototype/utils/_internal.py b/torchvision/prototype/utils/_internal.py index fe5284394cb..233128880e3 100644 --- a/torchvision/prototype/utils/_internal.py +++ b/torchvision/prototype/utils/_internal.py @@ -2,20 +2,13 @@ import difflib import io import mmap -import os -import os.path import platform -import textwrap from typing import ( Any, BinaryIO, Callable, - cast, Collection, - Iterable, Iterator, - Mapping, - NoReturn, Sequence, Tuple, TypeVar, @@ -30,9 +23,6 @@ __all__ = [ "add_suggestion", - "FrozenMapping", - "make_repr", - "FrozenBunch", "fromfile", "ReadOnlyTensorBuffer", "apply_recursively", @@ -60,82 +50,9 @@ def add_suggestion( return f"{msg.strip()} {hint}" -K = TypeVar("K") D = TypeVar("D") -class FrozenMapping(Mapping[K, D]): - def __init__(self, *args: Any, **kwargs: Any) -> None: - data = dict(*args, **kwargs) - self.__dict__["__data__"] = data - self.__dict__["__final_hash__"] = hash(tuple(data.items())) - - def __getitem__(self, item: K) -> D: - return cast(Mapping[K, D], self.__dict__["__data__"])[item] - - def __iter__(self) -> Iterator[K]: - return iter(self.__dict__["__data__"].keys()) - - def __len__(self) -> int: - return len(self.__dict__["__data__"]) - - def __immutable__(self) -> NoReturn: - raise RuntimeError(f"'{type(self).__name__}' object is immutable") - - def __setitem__(self, key: K, value: Any) -> NoReturn: - self.__immutable__() - - def __delitem__(self, key: K) -> NoReturn: - self.__immutable__() - - def __hash__(self) -> int: - return cast(int, self.__dict__["__final_hash__"]) - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, FrozenMapping): - return NotImplemented - - return hash(self) == hash(other) - - def __repr__(self) -> str: - return repr(self.__dict__["__data__"]) - - -def make_repr(name: str, items: Iterable[Tuple[str, Any]]) -> str: - def to_str(sep: str) -> str: - return sep.join([f"{key}={value}" for key, value in items]) - - prefix = f"{name}(" - postfix = ")" - body = to_str(", ") - - line_length = int(os.environ.get("COLUMNS", 80)) - body_too_long = (len(prefix) + len(body) + len(postfix)) > line_length - multiline_body = len(str(body).splitlines()) > 1 - if not (body_too_long or multiline_body): - return prefix + body + postfix - - body = textwrap.indent(to_str(",\n"), " " * 2) - return f"{prefix}\n{body}\n{postfix}" - - -class FrozenBunch(FrozenMapping): - def __getattr__(self, name: str) -> Any: - try: - return self[name] - except KeyError as error: - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from error - - def __setattr__(self, key: Any, value: Any) -> NoReturn: - self.__immutable__() - - def __delattr__(self, item: Any) -> NoReturn: - self.__immutable__() - - def __repr__(self) -> str: - return make_repr(type(self).__name__, self.items()) - - def _read_mutable_buffer_fallback(file: BinaryIO, count: int, item_size: int) -> bytearray: # A plain file.read() will give a read-only bytes, so we convert it to bytearray to make it mutable return bytearray(file.read(-1 if count == -1 else count * item_size)) From cd36d06a9db23909fcc0d0fae57a8597ffcd2641 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 7 Apr 2022 
10:49:44 +0200 Subject: [PATCH 30/30] update prototype dataset README (#5777) * update prototype dataset README * fix header level * Apply suggestions from code review Co-authored-by: Nicolas Hug Co-authored-by: Nicolas Hug --- .../prototype/datasets/_builtin/README.md | 235 +++++++++++++----- 1 file changed, 169 insertions(+), 66 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/README.md b/torchvision/prototype/datasets/_builtin/README.md index c20c0241fac..05d61c6870e 100644 --- a/torchvision/prototype/datasets/_builtin/README.md +++ b/torchvision/prototype/datasets/_builtin/README.md @@ -12,51 +12,66 @@ Finally, `from torchvision.prototype import datasets` is implied below. Before we start with the actual implementation, you should create a module in `torchvision/prototype/datasets/_builtin` that hints at the dataset you are going to add. For example `caltech.py` for `caltech101` and `caltech256`. In that -module create a class that inherits from `datasets.utils.Dataset` and overwrites at minimum three methods that will be -discussed in detail below: +module create a class that inherits from `datasets.utils.Dataset` and overwrites four methods that will be discussed in +detail below: ```python -from typing import Any, Dict, List +import pathlib +from typing import Any, BinaryIO, Dict, List, Tuple, Union from torchdata.datapipes.iter import IterDataPipe -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, DatasetConfig, OnlineResource +from torchvision.prototype.datasets.utils import Dataset, OnlineResource +from .._api import register_dataset, register_info + +NAME = "my-dataset" + +@register_info(NAME) +def _info() -> Dict[str, Any]: + return dict( + ... + ) + +@register_dataset(NAME) class MyDataset(Dataset): - def _make_info(self) -> DatasetInfo: + def __init__(self, root: Union[str, pathlib.Path], *, ..., skip_integrity_check: bool = False) -> None: ... + super().__init__(root, skip_integrity_check=skip_integrity_check) - def resources(self, config: DatasetConfig) -> List[OnlineResource]: + def _resources(self) -> List[OnlineResource]: ... - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig, - ) -> IterDataPipe[Dict[str, Any]]: + def _datapipe(self, resource_dps: List[IterDataPipe[Tuple[str, BinaryIO]]]) -> IterDataPipe[Dict[str, Any]]: + ... + + def __len__(self) -> int: ... ``` -### `_make_info(self)` +In addition to the dataset, you also need to implement an `_info()` function that takes no arguments and returns a +dictionary of static information. The most common use case is to provide human-readable categories. +[See below](#how-do-i-handle-a-dataset-that-defines-many-categories) how to handle cases with many categories. -The `DatasetInfo` carries static information about the dataset. There are two required fields: +Finally, both the dataset class and the info function need to be registered on the API with the respective decorators. +With that they are loadable through `datasets.load("my-dataset")` and `datasets.info("my-dataset")`, respectively. -- `name`: Name of the dataset. This will be used to load the dataset with `datasets.load(name)`. Should only contain - lowercase characters. +### `__init__(self, root, *, ..., skip_integrity_check = False)` -There are more optional parameters that can be passed: +Constructor of the dataset that will be called when the dataset is instantiated. In addition to the parameters of the +base class, it can take arbitrary keyword-only parameters with defaults. 
The checking of these parameters as well as +setting them as instance attributes has to happen before the call of `super().__init__(...)`, because that will invoke +the other methods, which possibly depend on the parameters. All instance attributes must be private, i.e. prefixed with +an underscore. -- `dependencies`: Collection of third-party dependencies that are needed to load the dataset, e.g. `("scipy",)`. Their - availability will be automatically checked if a user tries to load the dataset. Within the implementation, import - these packages lazily to avoid missing dependencies at import time. -- `categories`: Sequence of human-readable category names for each label. The index of each category has to match the - corresponding label returned in the dataset samples. - [See below](#how-do-i-handle-a-dataset-that-defines-many-categories) how to handle cases with many categories. -- `valid_options`: Configures valid options that can be passed to the dataset. It should be `Dict[str, Sequence[Any]]`. - The options are accessible through the `config` namespace in the other two functions. First value of the sequence is - taken as default if the user passes no option to `torchvision.prototype.datasets.load()`. +If the implementation of the dataset depends on third-party packages, pass them as a collection of strings to the base +class constructor, e.g. `super().__init__(..., dependencies=("scipy",))`. Their availability will be automatically +checked if a user tries to load the dataset. Within the implementation of the dataset, import these packages lazily to +avoid missing dependencies at import time. -## `resources(self, config)` +### `_resources(self)` -Returns `List[datasets.utils.OnlineResource]` of all the files that need to be present locally before the dataset with a -specific `config` can be build. The download will happen automatically. +Returns `List[datasets.utils.OnlineResource]` of all the files that need to be present locally before the dataset can be +build. The download will happen automatically. Currently, the following `OnlineResource`'s are supported: @@ -81,7 +96,7 @@ def sha256sum(path, chunk_size=1024 * 1024): print(checksum.hexdigest()) ``` -### `_make_datapipe(resource_dps, *, config)` +### `_datapipe(self, resource_dps)` This method is the heart of the dataset, where we transform the raw data into a usable form. A major difference compared to the current stable datasets is that everything is performed through `IterDataPipe`'s. From the perspective of someone @@ -99,60 +114,112 @@ All of them can be imported `from torchdata.datapipes.iter`. In addition, use `f needs extra arguments. If the provided `IterDataPipe`'s are not sufficient for the use case, it is also not complicated to add one. See the MNIST or CelebA datasets for example. -`make_datapipe()` receives `resource_dps`, which is a list of datapipes that has a 1-to-1 correspondence with the return -value of `resources()`. In case of archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will contain -tuples comprised of the path and the handle for every file in the archive. Otherwise the datapipe will only contain one +`_datapipe()` receives `resource_dps`, which is a list of datapipes that has a 1-to-1 correspondence with the return +value of `_resources()`. In case of archives with regular suffixes (`.tar`, `.zip`, ...), the datapipe will contain +tuples comprised of the path and the handle for every file in the archive. 
Otherwise, the datapipe will only contain one of such tuples for the file specified by the resource. Since the datapipes are iterable in nature, some datapipes feature an in-memory buffer, e.g. `IterKeyZipper` and -`Grouper`. There are two issues with that: 1. If not used carefully, this can easily overflow the host memory, since -most datasets will not fit in completely. 2. This can lead to unnecessarily long warm-up times when data is buffered -that is only needed at runtime. +`Grouper`. There are two issues with that: + +1. If not used carefully, this can easily overflow the host memory, since most datasets will not fit in completely. +2. This can lead to unnecessarily long warm-up times when data is buffered that is only needed at runtime. Thus, all buffered datapipes should be used as early as possible, e.g. zipping two datapipes of file handles rather than trying to zip already loaded images. There are two special datapipes that are not used through their class, but through the functions `hint_shuffling` and -`hint_sharding`. As the name implies they only hint part in the datapipe graph where shuffling and sharding should take -place, but are no-ops by default. They can be imported from `torchvision.prototype.datasets.utils._internal` and are -required in each dataset. `hint_shuffling` has to be placed before `hint_sharding`. +`hint_sharding`. As the name implies they only hint at a location in the datapipe graph where shuffling and sharding +should take place, but are no-ops by default. They can be imported from `torchvision.prototype.datasets.utils._internal` +and are required in each dataset. `hint_shuffling` has to be placed before `hint_sharding`. Finally, each item in the final datapipe should be a dictionary with `str` keys. There is no standardization of the names (yet!). +### `__len__` + +This returns an integer denoting the number of samples that can be drawn from the dataset. Please use +[underscores](https://peps.python.org/pep-0515/) after every three digits starting from the right to enhance the +readability. For example, `1_281_167` vs. `1281167`. + +If there are only two different numbers, a simple `if` / `else` is fine: + +```py +def __len__(self): + return 12_345 if self._split == "train" else 6_789 +``` + +If there are more options, using a dictionary usually is the most readable option: + +```py +def __len__(self): + return { + "train": 3, + "val": 2, + "test": 1, + }[self._split] +``` + +If the number of samples depends on more than one parameter, you can use tuples as dictionary keys: + +```py +def __len__(self): + return { + ("train", "bar"): 4, + ("train", "baz"): 3, + ("test", "bar"): 2, + ("test", "baz"): 1, + }[(self._split, self._foo)] +``` + +The length of the datapipe is only an annotation for subsequent processing of the datapipe and not needed during the +development process. Since it is an `@abstractmethod` you still have to implement it from the start. The canonical way +is to define a dummy method like + +```py +def __len__(self): + return 1 +``` + +and only fill it with the correct data if the implementation is otherwise finished. +[See below](#how-do-i-compute-the-number-of-samples) for a possible way to compute the number of samples. + ## Tests To test the dataset implementation, you usually don't need to add any tests, but need to provide a mock-up of the data. This mock-up should resemble the original data as close as necessary, while containing only few examples. 
To do this, add a new function in [`test/builtin_dataset_mocks.py`](../../../../test/builtin_dataset_mocks.py) with the -same name as you have defined in `_make_config()` (if the name includes hyphens `-`, replace them with underscores `_`) -and decorate it with `@register_mock`: +same name as you have used in `@register_info` and `@register_dataset`. This function is called "mock data function". +Decorate it with `@register_mock(configs=[dict(...), ...])`. Each dictionary denotes one configuration that the dataset +will be loaded with, e.g. `datasets.load("my-dataset", **config)`. For the most common case of a product of all options, +you can use the `combinations_grid()` helper function, e.g. +`configs=combinations_grid(split=("train", "test"), foo=("bar", "baz"))`. + +In case the name of the dataset includes hyphens `-`, replace them with underscores `_` in the function name and pass +the `name` parameter to `@register_mock` ```py # this is defined in torchvision/prototype/datasets/_builtin +@register_dataset("my-dataset") class MyDataset(Dataset): - def _make_info(self) -> DatasetInfo: - return DatasetInfo( - "my-dataset", - ... - ) - -@register_mock -def my_dataset(info, root, config): + ... + +@register_mock(name="my-dataset", configs=...) +def my_dataset(root, config): ... ``` -The function receives three arguments: +The mock data function receives two arguments: -- `info`: The return value of `_make_info()`. - `root`: A [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path) of a folder, in which the data needs to be placed. -- `config`: The configuration to generate the data for. This is the same value that `_make_datapipe()` receives. +- `config`: The configuration to generate the data for. This is one of the dictionaries defined in + `@register_mock(configs=...)` The function should generate all files that are needed for the current `config`. Each file should be complete, e.g. if -the dataset only has a single archive that contains multiple splits, you need to generate all regardless of the current -`config`. Although this seems odd at first, this is important. Consider the following original data setup: +the dataset only has a single archive that contains multiple splits, you need to generate the full archive regardless of +the current `config`. Although this seems odd at first, this is important. Consider the following original data setup: ``` root @@ -167,9 +234,8 @@ root For map-style datasets (like the one currently in `torchvision.datasets`), one explicitly selects the files they want to load. For example, something like `(root / split).iterdir()` works fine even if only the specific split folder is present. With iterable-style datasets though, we get something like `root.iterdir()` from `resource_dps` in -`_make_datapipe()` and need to manually `Filter` it to only keep the files we want. If we would only generate the data -for the current `config`, the test would also pass if the dataset is missing the filtering, but would fail on the real -data. +`_datapipe()` and need to manually `Filter` it to only keep the files we want. If we would only generate the data for +the current `config`, the test would also pass if the dataset is missing the filtering, but would fail on the real data. For datasets that are ported from the old API, we already have some mock data in [`test/test_datasets.py`](../../../../test/test_datasets.py). You can find the test case corresponding test case there @@ -178,8 +244,6 @@ and have a look at the `inject_fake_data` function. 
There are a few differences - `tmp_dir` corresponds to `root`, but is a `str` rather than a [`pathlib.Path`](https://docs.python.org/3/library/pathlib.html#pathlib.Path). Thus, you often see something like `folder = pathlib.Path(tmp_dir)`. This is not needed. -- Although both parameters are called `config`, the value in the new tests is a namespace. Thus, please use `config.foo` - over `config["foo"]` to enhance readability. - The data generated by `inject_fake_data` was supposed to be in an extracted state. This is no longer the case for the new mock-ups. Thus, you need to use helper functions like `make_zip` or `make_tar` to actually generate the files specified in the dataset. @@ -196,9 +260,9 @@ Finally, you can run the tests with `pytest test/test_prototype_builtin_datasets ### How do I start? -Get the skeleton of your dataset class ready with all 3 methods. For `_make_datapipe()`, you can just do +Get the skeleton of your dataset class ready with all 4 methods. For `_datapipe()`, you can just do `return resources_dp[0]` to get started. Then import the dataset class in -`torchvision/prototype/datasets/_builtin/__init__.py`: this will automatically register the dataset and it will be +`torchvision/prototype/datasets/_builtin/__init__.py`: this will automatically register the dataset, and it will be instantiable via `datasets.load("mydataset")`. On a separate script, try something like ```py @@ -206,7 +270,7 @@ from torchvision.prototype import datasets dataset = datasets.load("mydataset") for sample in dataset: - print(sample) # this is the content of an item in datapipe returned by _make_datapipe() + print(sample) # this is the content of an item in datapipe returned by _datapipe() break # Or you can also inspect the sample in a debugger ``` @@ -217,15 +281,24 @@ datapipes and return the appropriate dictionary format. ### How do I handle a dataset that defines many categories? -As a rule of thumb, `datasets.utils.DatasetInfo(..., categories=)` should only be set directly for ten categories or -fewer. If more categories are needed, you can add a `$NAME.categories` file to the `_builtin` folder in which each line -specifies a category. If `$NAME` matches the name of the dataset (which it definitively should!) it will be -automatically loaded if `categories=` is not set. +As a rule of thumb, `categories` in the info dictionary should only be set manually for ten categories or fewer. If more +categories are needed, you can add a `$NAME.categories` file to the `_builtin` folder in which each line specifies a +category. To load such a file, use the `from torchvision.prototype.datasets.utils._internal import read_categories_file` +function and pass it `$NAME`. In case the categories can be generated from the dataset files, e.g. the dataset follows an image folder approach where -each folder denotes the name of the category, the dataset can overwrite the `_generate_categories` method. It gets -passed the `root` path to the resources, but they have to be manually loaded, e.g. -`self.resources(config)[0].load(root)`. The method should return a sequence of strings representing the category names. +each folder denotes the name of the category, the dataset can overwrite the `_generate_categories` method. The method +should return a sequence of strings representing the category names. In the method body, you'll have to manually load +the resources, e.g. 
+ +```py +resources = self._resources() +dp = resources[0].load(self._root) +``` + +Note that it is not necessary here to keep a datapipe until the final step. Stick with datapipes as long as it makes +sense and afterwards materialize the data with `next(iter(dp))` or `list(dp)` and proceed with that. + To generate the `$NAME.categories` file, run `python -m torchvision.prototype.datasets.generate_category_files $NAME`. ### What if a resource file forms an I/O bottleneck? @@ -235,3 +308,33 @@ the performance hit becomes significant, the archives can still be preprocessed. `preprocess` parameter that can be a `Callable[[pathlib.Path], pathlib.Path]` where the input points to the file to be preprocessed and the return value should be the result of the preprocessing to load. For convenience, `preprocess` also accepts `"decompress"` and `"extract"` to handle these common scenarios. + +### How do I compute the number of samples? + +Unless the authors of the dataset published the exact numbers (even in this case we should check), there is no other way +than to iterate over the dataset and count the number of samples: + +```py +import itertools +from torchvision.prototype import datasets + + +def combinations_grid(**kwargs): + return [dict(zip(kwargs.keys(), values)) for values in itertools.product(*kwargs.values())] + + +# If you have implemented the mock data function for the dataset tests, you can simply copy-paste from there +configs = combinations_grid(split=("train", "test"), foo=("bar", "baz")) + +for config in configs: + dataset = datasets.load("my-dataset", **config) + + num_samples = 0 + for _ in dataset: + num_samples += 1 + + print(", ".join(f"{key}={value}" for key, value in config.items()), num_samples) +``` + +To speed this up, it is useful to temporarily comment out all unnecessary I/O, such as loading of images or annotation +files.
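+
+For example, if the samples are assembled in a hypothetical `_prepare_sample()` method, the expensive decoding can be
+disabled while counting (the method and the sample keys below are only an illustration):
+
+```py
+def _prepare_sample(self, data: Tuple[str, BinaryIO]) -> Dict[str, Any]:
+    path, buffer = data
+    return dict(
+        path=path,
+        # image=EncodedImage.from_file(buffer),  # temporarily disabled while counting samples
+    )
+```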