From f66cf9ec36ae4498885e9f2b255b2baf39354ffc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 21 Feb 2022 10:21:29 +0100 Subject: [PATCH 1/8] refactor prototype datasets to inherit from IterDataPipe --- test/test_prototype_builtin_datasets.py | 40 ++++++------ torchvision/prototype/datasets/__init__.py | 2 +- torchvision/prototype/datasets/_api.py | 63 ++++++++++++++++--- .../prototype/datasets/_builtin/imagenet.py | 50 ++++++++++++++- .../prototype/datasets/utils/__init__.py | 2 +- .../prototype/datasets/utils/_dataset.py | 26 +++++++- 6 files changed, 149 insertions(+), 34 deletions(-) diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index eaa92094ad7..1ea9c1aabab 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -34,6 +34,15 @@ def test_coverage(): ) +# TODO: replace this with a simple call to datasets.load() as soon all datasets are migrated and thus datasets.load2() +# can be merged into datasets.load() +def _load_dataset(name, **options): + try: + return datasets.load2(name, **options) + except ValueError: + return datasets.load(name, **options) + + class TestCommon: @parametrize_dataset_mocks(DATASET_MOCKS) def test_smoke(self, test_home, dataset_mock, config): @@ -41,6 +50,7 @@ def test_smoke(self, test_home, dataset_mock, config): dataset = datasets.load(dataset_mock.name, **config) + # TODO: check for Dataset2 after all datasets are migrated if not isinstance(dataset, IterDataPipe): raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.") @@ -48,7 +58,7 @@ def test_smoke(self, test_home, dataset_mock, config): def test_sample(self, test_home, dataset_mock, config): dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **config) + dataset = _load_dataset(dataset_mock.name, **config) try: sample = next(iter(dataset)) @@ -65,7 +75,7 @@ def test_sample(self, test_home, dataset_mock, config): def test_num_samples(self, test_home, dataset_mock, config): mock_info = dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **config) + dataset = _load_dataset(dataset_mock.name, **config) num_samples = 0 for _ in dataset: @@ -73,24 +83,11 @@ def test_num_samples(self, test_home, dataset_mock, config): assert num_samples == mock_info["num_samples"] - @parametrize_dataset_mocks(DATASET_MOCKS) - def test_decoding(self, test_home, dataset_mock, config): - dataset_mock.prepare(test_home, config) - - dataset = datasets.load(dataset_mock.name, **config) - - undecoded_features = {key for key, value in next(iter(dataset)).items() if isinstance(value, io.IOBase)} - if undecoded_features: - raise AssertionError( - f"The values of key(s) " - f"{sequence_to_str(sorted(undecoded_features), separate_last='and ')} were not decoded." 
- ) - @parametrize_dataset_mocks(DATASET_MOCKS) def test_no_vanilla_tensors(self, test_home, dataset_mock, config): dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **config) + dataset = _load_dataset(dataset_mock.name, **config) vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor} if vanilla_tensors: @@ -103,7 +100,7 @@ def test_no_vanilla_tensors(self, test_home, dataset_mock, config): def test_transformable(self, test_home, dataset_mock, config): dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **config) + dataset = _load_dataset(dataset_mock.name, **config) next(iter(dataset.map(transforms.Identity()))) @@ -138,8 +135,7 @@ def scan(graph): yield from scan(sub_graph) dataset_mock.prepare(test_home, config) - - dataset = datasets.load(dataset_mock.name, **config) + dataset = _load_dataset(dataset_mock.name, **config) if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))): raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.") @@ -147,7 +143,7 @@ def scan(graph): @parametrize_dataset_mocks(DATASET_MOCKS) def test_save_load(self, test_home, dataset_mock, config): dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **config) + dataset = _load_dataset(dataset_mock.name, **config) sample = next(iter(dataset)) with io.BytesIO() as buffer: @@ -161,7 +157,7 @@ class TestQMNIST: def test_extra_label(self, test_home, dataset_mock, config): dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **config) + dataset = _load_dataset(dataset_mock.name, **config) sample = next(iter(dataset)) for key, type in ( @@ -186,7 +182,7 @@ def test_label_matches_path(self, test_home, dataset_mock, config): dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **config) + dataset = _load_dataset(dataset_mock.name, **config) for sample in dataset: label_from_path = int(Path(sample["path"]).parent.name) diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py index bf99e175d36..a74517af504 100644 --- a/torchvision/prototype/datasets/__init__.py +++ b/torchvision/prototype/datasets/__init__.py @@ -11,5 +11,5 @@ from ._home import home # Load this last, since some parts depend on the above being loaded first -from ._api import list_datasets, info, load # usort: skip +from ._api import list_datasets, info, load, register_info, register_dataset, load2 # usort: skip from ._folder import from_data_folder, from_image_folder diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index 13ee920cea2..b6b072f37c5 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -1,12 +1,12 @@ import os -from typing import Any, Dict, List +import pathlib +from typing import Any, Dict, List, Callable, Type, Optional, Union from torch.utils.data import IterDataPipe from torchvision.prototype.datasets import home -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo +from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, Dataset2 from torchvision.prototype.utils._internal import add_suggestion -from . 
import _builtin DATASETS: Dict[str, Dataset] = {} @@ -15,11 +15,6 @@ def register(dataset: Dataset) -> None: DATASETS[dataset.name] = dataset -for name, obj in _builtin.__dict__.items(): - if not name.startswith("_") and isinstance(obj, type) and issubclass(obj, Dataset) and obj is not Dataset: - register(obj()) - - def list_datasets() -> List[str]: return sorted(DATASETS.keys()) @@ -57,3 +52,55 @@ def load( root = os.path.join(home(), dataset.name) return dataset.load(root, config=config, skip_integrity_check=skip_integrity_check) + + +BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {} + + +def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: + def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: + BUILTIN_INFOS[name] = fn() + return fn + + return wrapper + + +def info2(name: str) -> Dict[str, Any]: + try: + return BUILTIN_INFOS[name] + except KeyError: + raise ValueError + + +BUILTIN_DATASETS = {} + + +def register_dataset(name: str) -> Callable[[Type], Type]: + def wrapper(dataset_cls: Type) -> Type: + if not issubclass(dataset_cls, Dataset2): + raise TypeError + + BUILTIN_DATASETS[name] = dataset_cls + + return dataset_cls + + return wrapper + + +def load2(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **options: Any) -> Dataset2: + try: + dataset_cls = BUILTIN_DATASETS[name] + except KeyError: + raise ValueError + + if root is None: + root = pathlib.Path(home()) / name + + return dataset_cls(root, **options) + + +from . import _builtin + +for name, obj in _builtin.__dict__.items(): + if not name.startswith("_") and isinstance(obj, type) and issubclass(obj, Dataset) and obj is not Dataset: + register(obj()) diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 0d11b642c13..6f17fa5c365 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -1,7 +1,8 @@ import functools import pathlib import re -from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast +from types import SimpleNamespace +from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast, Union from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, Filter, Demultiplexer from torchdata.datapipes.iter import TarArchiveReader @@ -11,6 +12,7 @@ DatasetInfo, OnlineResource, ManualDownloadResource, + Dataset2, ) from torchvision.prototype.datasets.utils._internal import ( INFINITE_BUFFER_SIZE, @@ -25,6 +27,8 @@ from torchvision.prototype.features import Label, EncodedImage from torchvision.prototype.utils._internal import FrozenMapping +from .._api import register_dataset, register_info + class ImageNetResource(ManualDownloadResource): def __init__(self, **kwargs: Any) -> None: @@ -201,3 +205,47 @@ def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]: categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp))) categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1]) return categories_and_wnids + + +NAME = "imagenet" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) + return dict(categories=categories, wnids=wnids) + + +@register_dataset(NAME) +class ImageNet2(Dataset2): + def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None: + if split not in {"train", "val", 
"test"}: + raise ValueError + self._split = split + + info = _info() + categories, wnids = info["categories"], info["wnids"] + + self._old_style_dataset = ImageNet() + self._old_style_config = self._old_style_dataset.info.make_config(split=self._split) + + self.categories = categories + self.info = SimpleNamespace( + wnid_to_category=zip(wnids, categories), + category_to_wnid=zip(categories, wnids), + ) + + super().__init__(root) + + def _resources(self) -> List[OnlineResource]: + return self._old_style_dataset.resources(self._old_style_config) + + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: + return self._old_style_dataset._make_datapipe(resource_dps, config=self._old_style_config) + + def __len__(self) -> int: + return { + "train": 1_281_167, + "val": 50_000, + "test": 100_000, + }[self._split] diff --git a/torchvision/prototype/datasets/utils/__init__.py b/torchvision/prototype/datasets/utils/__init__.py index 9423b65a8ee..a16a839b594 100644 --- a/torchvision/prototype/datasets/utils/__init__.py +++ b/torchvision/prototype/datasets/utils/__init__.py @@ -1,4 +1,4 @@ from . import _internal # usort: skip -from ._dataset import DatasetConfig, DatasetInfo, Dataset +from ._dataset import DatasetConfig, DatasetInfo, Dataset, Dataset2 from ._query import SampleQuery from ._resource import OnlineResource, HttpResource, GDriveResource, ManualDownloadResource, KaggleDownloadResource diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index 5ee7c5ccc60..c8f6ca6733e 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -4,7 +4,7 @@ import itertools import os import pathlib -from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection +from typing import Any, Dict, List, Optional, Sequence, Union, Tuple, Collection, Iterator from torch.utils.data import IterDataPipe from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str @@ -181,3 +181,27 @@ def load( def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]: raise NotImplementedError + + +class Dataset2(IterDataPipe[Dict[str, Any]], abc.ABC): + def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None: + self._root = pathlib.Path(root).expanduser().resolve() + resources = [ + resource.load(self._root, skip_integrity_check=skip_integrity_check) for resource in self._resources() + ] + self._dp = self._datapipe(resources) + + def __iter__(self) -> Iterator[Dict[str, Any]]: + yield from self._dp + + @abc.abstractmethod + def _resources(self) -> List[OnlineResource]: + pass + + @abc.abstractmethod + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: + pass + + @abc.abstractmethod + def __len__(self) -> int: + pass From 1c3bb2375e391504ca0e603a6d8a1c68dee6cac6 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 23 Feb 2022 09:04:43 +0100 Subject: [PATCH 2/8] depend on new architecture --- .pre-commit-config.yaml | 2 +- test/builtin_dataset_mocks.py | 87 +++++----- test/test_prototype_builtin_datasets.py | 101 ++++++----- torchvision/prototype/datasets/__init__.py | 3 +- torchvision/prototype/datasets/_api.py | 103 ++++------- .../prototype/datasets/_builtin/imagenet.py | 163 ++++++------------ .../prototype/datasets/utils/_dataset.py | 15 ++ 7 files changed, 200 insertions(+), 274 deletions(-) diff --git 
a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d4e9d2e337a..217083862ad 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,7 @@ repos: # name: Upgrade code - repo: https://github.com/omnilib/ufmt - rev: v1.3.0 + rev: v1.3.2 hooks: - id: ufmt additional_dependencies: diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 123d8f29d3f..96ca78a1c9a 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -16,11 +16,10 @@ import PIL.Image import pytest import torch -from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file +from datasets_utils import make_zip, make_tar, create_image_folder, create_image_file, combinations_grid from torch.nn.functional import one_hot from torch.testing import make_tensor as _make_tensor -from torchvision.prototype.datasets._api import find -from torchvision.prototype.utils._internal import sequence_to_str +from torchvision.prototype import datasets make_tensor = functools.partial(_make_tensor, device="cpu") make_scalar = functools.partial(make_tensor, ()) @@ -30,13 +29,11 @@ class DatasetMock: - def __init__(self, name, mock_data_fn): - self.dataset = find(name) - self.info = self.dataset.info - self.name = self.info.name - + def __init__(self, name, *, mock_data_fn, configs): + # FIXME: error handling for unknown names + self.name = name self.mock_data_fn = mock_data_fn - self.configs = self.info._configs + self.configs = configs def _parse_mock_info(self, mock_info): if mock_info is None: @@ -61,27 +58,30 @@ def _parse_mock_info(self, mock_info): return mock_info - def prepare(self, home, config): + def prepare(self, home, **options): root = home / self.name root.mkdir(exist_ok=True) - mock_info = self._parse_mock_info(self.mock_data_fn(self.info, root, config)) + mock_info = self._parse_mock_info(self.mock_data_fn(datasets.info(self.name), root, **options)) - available_file_names = {path.name for path in root.glob("*")} - required_file_names = {resource.file_name for resource in self.dataset.resources(config)} - missing_file_names = required_file_names - available_file_names - if missing_file_names: - raise pytest.UsageError( - f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} " - f"for {config}, but they were not created by the mock data function." - ) + # # FIXME: We need to handle missing files here + # dataset = datasets.load2(self.name, **options) + # + # available_file_names = {path.name for path in root.glob("*")} + # required_file_names = {resource.file_name for resource in self.dataset.resources(config)} + # missing_file_names = required_file_names - available_file_names + # if missing_file_names: + # raise pytest.UsageError( + # f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} " + # f"for {config}, but they were not created by the mock data function." 
+ # ) return mock_info -def config_id(name, config): +def config_id(name, options): parts = [name] - for name, value in config.items(): + for name, value in options.items(): if isinstance(value, bool): part = ("" if value else "no_") + name else: @@ -95,6 +95,8 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None): for mock in dataset_mocks: if isinstance(mock, DatasetMock): mocks[mock.name] = mock + elif isinstance(mock, collections.abc.Sequence): + mocks.update({mock_.name: mock_ for mock_ in mock}) elif isinstance(mock, collections.abc.Mapping): mocks.update(mock) else: @@ -111,21 +113,20 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None): raise pytest.UsageError() return pytest.mark.parametrize( - ("dataset_mock", "config"), + ("dataset_mock", "options"), [ - pytest.param(dataset_mock, config, id=config_id(name, config), marks=marks.get(name, ())) + pytest.param(dataset_mock, options, id=config_id(name, options), marks=marks.get(name, ())) for name, dataset_mock in dataset_mocks.items() - for config in dataset_mock.configs + for options in dataset_mock.configs ], ) -DATASET_MOCKS = {} +DATASET_MOCKS = [] def register_mock(fn): - name = fn.__name__.replace("_", "-") - DATASET_MOCKS[name] = DatasetMock(name, fn) + # TODO: remove this decorator after all datasets have been migrated return fn @@ -217,7 +218,7 @@ def mnist(info, root, config): ) -DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]}) +# DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]}) @register_mock @@ -430,18 +431,15 @@ def caltech256(info, root, config): return num_images_per_category * len(info.categories) -@register_mock -def imagenet(info, root, config): +def imagenet_mock_data_fn(info, root, **options): from scipy.io import savemat - categories = info.categories - wnids = [info.extra.category_to_wnid[category] for category in categories] - if config.split == "train": - num_samples = len(wnids) + if options["split"] == "train": + num_samples = len(info["wnids"]) archive_name = "ILSVRC2012_img_train.tar" files = [] - for wnid in wnids: + for wnid in info["wnids"]: create_image_folder( root=root, name=wnid, @@ -449,7 +447,7 @@ def imagenet(info, root, config): num_examples=1, ) files.append(make_tar(root, f"{wnid}.tar")) - elif config.split == "val": + elif options["split"] == "val": num_samples = 3 archive_name = "ILSVRC2012_img_val.tar" files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)] @@ -459,20 +457,20 @@ def imagenet(info, root, config): data_root.mkdir(parents=True) with open(data_root / "ILSVRC2012_validation_ground_truth.txt", "w") as file: - for label in torch.randint(0, len(wnids), (num_samples,)).tolist(): + for label in torch.randint(0, len(info["wnids"]), (num_samples,)).tolist(): file.write(f"{label}\n") num_children = 0 synsets = [ (idx, wnid, category, "", num_children, [], 0, 0) - for idx, (category, wnid) in enumerate(zip(categories, wnids), 1) + for idx, (category, wnid) in enumerate(zip(info["categories"], info["wnids"]), 1) ] num_children = 1 synsets.extend((0, "", "", "", num_children, [], 0, 0) for _ in range(5)) savemat(data_root / "meta.mat", dict(synsets=synsets)) make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz") - else: # config.split == "test" + else: # options["split"] == "test" num_samples = 5 archive_name = "ILSVRC2012_img_test_v10102019.tar" files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for 
idx in range(num_samples)] @@ -482,6 +480,17 @@ def imagenet(info, root, config): return num_samples +DATASET_MOCKS.append( + DatasetMock( + "imagenet", + mock_data_fn=imagenet_mock_data_fn, + configs=combinations_grid( + split=("train", "val", "test"), + ), + ) +) + + class CocoMockData: @classmethod def _make_images_archive(cls, root, name, *, num_samples): diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 1ea9c1aabab..fab1e5f921c 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -6,9 +6,8 @@ import torch from builtin_dataset_mocks import parametrize_dataset_mocks, DATASET_MOCKS from torch.testing._comparison import assert_equal, TensorLikePair, ObjectPair -from torch.utils.data.datapipes.iter.grouping import ShardingFilterIterDataPipe as ShardingFilter from torch.utils.data.graph import traverse -from torchdata.datapipes.iter import IterDataPipe, Shuffler +from torchdata.datapipes.iter import Shuffler, ShardingFilter from torchvision.prototype import transforms, datasets from torchvision.prototype.utils._internal import sequence_to_str @@ -25,7 +24,7 @@ def test_home(mocker, tmp_path): def test_coverage(): - untested_datasets = set(datasets.list_datasets()) - DATASET_MOCKS.keys() + untested_datasets = set(datasets.list_datasets()) - {mock.name for mock in DATASET_MOCKS} if untested_datasets: raise AssertionError( f"The dataset(s) {sequence_to_str(sorted(untested_datasets), separate_last='and ')} " @@ -34,31 +33,31 @@ def test_coverage(): ) -# TODO: replace this with a simple call to datasets.load() as soon all datasets are migrated and thus datasets.load2() -# can be merged into datasets.load() -def _load_dataset(name, **options): - try: - return datasets.load2(name, **options) - except ValueError: - return datasets.load(name, **options) +class TestCommon: + @pytest.mark.parametrize("name", datasets.list_datasets()) + def test_info(self, name): + try: + info = datasets.info(name) + except ValueError: + raise AssertionError("No info available.") from None + if not (isinstance(info, dict) and all(isinstance(key, str) for key in info.keys())): + raise AssertionError("Info should be a dictionary with string keys.") -class TestCommon: @parametrize_dataset_mocks(DATASET_MOCKS) - def test_smoke(self, test_home, dataset_mock, config): - dataset_mock.prepare(test_home, config) + def test_smoke(self, test_home, dataset_mock, options): + dataset_mock.prepare(test_home, **options) - dataset = datasets.load(dataset_mock.name, **config) + dataset = datasets.load(dataset_mock.name, **options) - # TODO: check for Dataset2 after all datasets are migrated - if not isinstance(dataset, IterDataPipe): - raise AssertionError(f"Loading the dataset should return an IterDataPipe, but got {type(dataset)} instead.") + if not isinstance(dataset, datasets.utils.Dataset2): + raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.") @parametrize_dataset_mocks(DATASET_MOCKS) - def test_sample(self, test_home, dataset_mock, config): - dataset_mock.prepare(test_home, config) + def test_sample(self, test_home, dataset_mock, options): + dataset_mock.prepare(test_home, **options) - dataset = _load_dataset(dataset_mock.name, **config) + dataset = datasets.load(dataset_mock.name, **options) try: sample = next(iter(dataset)) @@ -72,22 +71,18 @@ def test_sample(self, test_home, dataset_mock, config): raise AssertionError("Sample dictionary is empty.") 
@parametrize_dataset_mocks(DATASET_MOCKS) - def test_num_samples(self, test_home, dataset_mock, config): - mock_info = dataset_mock.prepare(test_home, config) - - dataset = _load_dataset(dataset_mock.name, **config) + def test_num_samples(self, test_home, dataset_mock, options): + mock_info = dataset_mock.prepare(test_home, **options) - num_samples = 0 - for _ in dataset: - num_samples += 1 + dataset = datasets.load(dataset_mock.name, **options) - assert num_samples == mock_info["num_samples"] + assert len(list(dataset)) == mock_info["num_samples"] @parametrize_dataset_mocks(DATASET_MOCKS) - def test_no_vanilla_tensors(self, test_home, dataset_mock, config): - dataset_mock.prepare(test_home, config) + def test_no_vanilla_tensors(self, test_home, dataset_mock, options): + dataset_mock.prepare(test_home, **options) - dataset = _load_dataset(dataset_mock.name, **config) + dataset = datasets.load(dataset_mock.name, **options) vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor} if vanilla_tensors: @@ -97,13 +92,14 @@ def test_no_vanilla_tensors(self, test_home, dataset_mock, config): ) @parametrize_dataset_mocks(DATASET_MOCKS) - def test_transformable(self, test_home, dataset_mock, config): - dataset_mock.prepare(test_home, config) + def test_transformable(self, test_home, dataset_mock, options): + dataset_mock.prepare(test_home, **options) - dataset = _load_dataset(dataset_mock.name, **config) + dataset = datasets.load(dataset_mock.name, **options) next(iter(dataset.map(transforms.Identity()))) + @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237") @parametrize_dataset_mocks( DATASET_MOCKS, marks={ @@ -112,13 +108,14 @@ def test_transformable(self, test_home, dataset_mock, config): ) }, ) - def test_traversable(self, test_home, dataset_mock, config): - dataset_mock.prepare(test_home, config) + def test_traversable(self, test_home, dataset_mock, options): + dataset_mock.prepare(test_home, **options) - dataset = datasets.load(dataset_mock.name, **config) + dataset = datasets.load(dataset_mock.name, **options) traverse(dataset) + @pytest.mark.xfail(reason="See https://github.com/pytorch/data/issues/237") @parametrize_dataset_mocks( DATASET_MOCKS, marks={ @@ -128,22 +125,22 @@ def test_traversable(self, test_home, dataset_mock, config): }, ) @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter)) - def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type): + def test_has_annotations(self, test_home, dataset_mock, options, annotation_dp_type): def scan(graph): for node, sub_graph in graph.items(): yield node yield from scan(sub_graph) - dataset_mock.prepare(test_home, config) - dataset = _load_dataset(dataset_mock.name, **config) + dataset_mock.prepare(test_home, **options) + dataset = datasets.load(dataset_mock.name, **options) if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))): raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.") @parametrize_dataset_mocks(DATASET_MOCKS) - def test_save_load(self, test_home, dataset_mock, config): - dataset_mock.prepare(test_home, config) - dataset = _load_dataset(dataset_mock.name, **config) + def test_save_load(self, test_home, dataset_mock, options): + dataset_mock.prepare(test_home, **options) + dataset = datasets.load(dataset_mock.name, **options) sample = next(iter(dataset)) with io.BytesIO() as buffer: @@ -152,12 +149,12 @@ def test_save_load(self, test_home, dataset_mock, 
config): assert_samples_equal(torch.load(buffer), sample) -@parametrize_dataset_mocks(DATASET_MOCKS["qmnist"]) +@parametrize_dataset_mocks([mock for mock in DATASET_MOCKS if mock.name == "qmnist"]) class TestQMNIST: - def test_extra_label(self, test_home, dataset_mock, config): - dataset_mock.prepare(test_home, config) + def test_extra_label(self, test_home, dataset_mock, options): + dataset_mock.prepare(test_home, **options) - dataset = _load_dataset(dataset_mock.name, **config) + dataset = datasets.load(dataset_mock.name, **options) sample = next(iter(dataset)) for key, type in ( @@ -172,17 +169,17 @@ def test_extra_label(self, test_home, dataset_mock, config): assert key in sample and isinstance(sample[key], type) -@parametrize_dataset_mocks(DATASET_MOCKS["gtsrb"]) +@parametrize_dataset_mocks([mock for mock in DATASET_MOCKS if mock.name == "gtsrb"]) class TestGTSRB: - def test_label_matches_path(self, test_home, dataset_mock, config): + def test_label_matches_path(self, test_home, dataset_mock, options): # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path. # This test makes sure that they're both the same - if config.split != "train": + if options["split"] != "train": return - dataset_mock.prepare(test_home, config) + dataset_mock.prepare(test_home, **options) - dataset = _load_dataset(dataset_mock.name, **config) + dataset = datasets.load(dataset_mock.name, **options) for sample in dataset: label_from_path = int(Path(sample["path"]).parent.name) diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py index a74517af504..44c66e422f2 100644 --- a/torchvision/prototype/datasets/__init__.py +++ b/torchvision/prototype/datasets/__init__.py @@ -11,5 +11,6 @@ from ._home import home # Load this last, since some parts depend on the above being loaded first -from ._api import list_datasets, info, load, register_info, register_dataset, load2 # usort: skip +from ._api import list_datasets, info, load, register_info, register_dataset # usort: skip from ._folder import from_data_folder, from_image_folder +from ._builtin import * diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index b6b072f37c5..c0e3f1ad1cb 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -1,34 +1,50 @@ -import os import pathlib -from typing import Any, Dict, List, Callable, Type, Optional, Union +from typing import Any, Dict, List, Callable, Type, Optional, Union, TypeVar -from torch.utils.data import IterDataPipe from torchvision.prototype.datasets import home -from torchvision.prototype.datasets.utils import Dataset, DatasetInfo, Dataset2 +from torchvision.prototype.datasets.utils import Dataset2 from torchvision.prototype.utils._internal import add_suggestion -DATASETS: Dict[str, Dataset] = {} +T = TypeVar("T") +D = TypeVar("D", bound=Type[Dataset2]) +BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {} -def register(dataset: Dataset) -> None: - DATASETS[dataset.name] = dataset + +def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: + def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: + BUILTIN_INFOS[name] = fn() + return fn + + return wrapper + + +BUILTIN_DATASETS = {} + + +def register_dataset(name: str) -> Callable[[D], D]: + def wrapper(dataset_cls: D) -> D: + BUILTIN_DATASETS[name] = dataset_cls + return dataset_cls + + return wrapper def list_datasets() -> 
List[str]: - return sorted(DATASETS.keys()) + return sorted(BUILTIN_DATASETS.keys()) -def find(name: str) -> Dataset: +def find(dct: Dict[str, T], name: str) -> T: name = name.lower() try: - return DATASETS[name] + return dct[name] except KeyError as error: raise ValueError( add_suggestion( f"Unknown dataset '{name}'.", word=name, - possibilities=DATASETS.keys(), + possibilities=dct.keys(), alternative_hint=lambda _: ( "You can use torchvision.datasets.list_datasets() to get a list of all available datasets." ), @@ -36,71 +52,14 @@ def find(name: str) -> Dataset: ) from error -def info(name: str) -> DatasetInfo: - return find(name).info - - -def load( - name: str, - *, - skip_integrity_check: bool = False, - **options: Any, -) -> IterDataPipe[Dict[str, Any]]: - dataset = find(name) - - config = dataset.info.make_config(**options) - root = os.path.join(home(), dataset.name) - - return dataset.load(root, config=config, skip_integrity_check=skip_integrity_check) - - -BUILTIN_INFOS: Dict[str, Dict[str, Any]] = {} - - -def register_info(name: str) -> Callable[[Callable[[], Dict[str, Any]]], Callable[[], Dict[str, Any]]]: - def wrapper(fn: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]: - BUILTIN_INFOS[name] = fn() - return fn - - return wrapper - - -def info2(name: str) -> Dict[str, Any]: - try: - return BUILTIN_INFOS[name] - except KeyError: - raise ValueError - - -BUILTIN_DATASETS = {} +def info(name: str) -> Dict[str, Any]: + return find(BUILTIN_INFOS, name) -def register_dataset(name: str) -> Callable[[Type], Type]: - def wrapper(dataset_cls: Type) -> Type: - if not issubclass(dataset_cls, Dataset2): - raise TypeError - - BUILTIN_DATASETS[name] = dataset_cls - - return dataset_cls - - return wrapper - - -def load2(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **options: Any) -> Dataset2: - try: - dataset_cls = BUILTIN_DATASETS[name] - except KeyError: - raise ValueError +def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **options: Any) -> Dataset2: + dataset_cls = find(BUILTIN_DATASETS, name) if root is None: root = pathlib.Path(home()) / name return dataset_cls(root, **options) - - -from . 
import _builtin - -for name, obj in _builtin.__dict__.items(): - if not name.startswith("_") and isinstance(obj, type) and issubclass(obj, Dataset) and obj is not Dataset: - register(obj()) diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 6f17fa5c365..34f672c2fc1 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -1,14 +1,10 @@ -import functools import pathlib import re -from types import SimpleNamespace from typing import Any, Dict, List, Optional, Tuple, BinaryIO, Match, cast, Union from torchdata.datapipes.iter import IterDataPipe, LineReader, IterKeyZipper, Mapper, Filter, Demultiplexer from torchdata.datapipes.iter import TarArchiveReader from torchvision.prototype.datasets.utils import ( - Dataset, - DatasetConfig, DatasetInfo, OnlineResource, ManualDownloadResource, @@ -23,44 +19,39 @@ read_mat, hint_sharding, hint_shuffling, + path_accessor, ) from torchvision.prototype.features import Label, EncodedImage -from torchvision.prototype.utils._internal import FrozenMapping from .._api import register_dataset, register_info +NAME = "imagenet" + + +@register_info(NAME) +def _info() -> Dict[str, Any]: + categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) + return dict(categories=categories, wnids=wnids) + + class ImageNetResource(ManualDownloadResource): def __init__(self, **kwargs: Any) -> None: super().__init__("Register on https://image-net.org/ and follow the instructions there.", **kwargs) -class ImageNet(Dataset): - def _make_info(self) -> DatasetInfo: - name = "imagenet" - categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{name}.categories")) - - return DatasetInfo( - name, - dependencies=("scipy",), - categories=categories, - homepage="https://www.image-net.org/", - valid_options=dict(split=("train", "val", "test")), - extra=dict( - wnid_to_category=FrozenMapping(zip(wnids, categories)), - category_to_wnid=FrozenMapping(zip(categories, wnids)), - sizes=FrozenMapping( - [ - (DatasetConfig(split="train"), 1_281_167), - (DatasetConfig(split="val"), 50_000), - (DatasetConfig(split="test"), 100_000), - ] - ), - ), - ) +@register_dataset(NAME) +class ImageNet(Dataset2): + def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None: + self._split = self._verify_str_arg(split, "split", {"train", "val", "test"}) - def supports_sharded(self) -> bool: - return True + info = _info() + categories, wnids = info["categories"], info["wnids"] + self._categories: List[str] = categories + self._wnids: List[str] = wnids + self._wnid_to_category = dict(zip(wnids, categories)) + + super().__init__(root) _IMAGES_CHECKSUMS = { "train": "b08200a27a8e34218a0e58fde36b0fe8f73bc377f4acea2d91602057c3ca45bb", @@ -68,15 +59,15 @@ def supports_sharded(self) -> bool: "test_v10102019": "9cf7f8249639510f17d3d8a0deb47cd22a435886ba8e29e2b3223e65a4079eb4", } - def resources(self, config: DatasetConfig) -> List[OnlineResource]: - name = "test_v10102019" if config.split == "test" else config.split + def _resources(self) -> List[OnlineResource]: + name = "test_v10102019" if self._split == "test" else self._split images = ImageNetResource( file_name=f"ILSVRC2012_img_{name}.tar", sha256=self._IMAGES_CHECKSUMS[name], ) resources: List[OnlineResource] = [images] - if config.split == "val": + if self._split == "val": devkit = ImageNetResource( file_name="ILSVRC2012_devkit_t12.tar.gz", 
sha256="b59243268c0d266621fd587d2018f69e906fb22875aca0e295b48cafaa927953", @@ -85,19 +76,12 @@ def resources(self, config: DatasetConfig) -> List[OnlineResource]: return resources - def num_samples(self, config: DatasetConfig) -> int: - return { - "train": 1_281_167, - "val": 50_000, - "test": 100_000, - }[config.split] - _TRAIN_IMAGE_NAME_PATTERN = re.compile(r"(?Pn\d{8})_\d+[.]JPEG") def _prepare_train_data(self, data: Tuple[str, BinaryIO]) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]: path = pathlib.Path(data[0]) wnid = cast(Match[str], self._TRAIN_IMAGE_NAME_PATTERN.match(path.name))["wnid"] - label = Label.from_category(self.info.extra.wnid_to_category[wnid], categories=self.categories) + label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories) return (label, wnid), data def _prepare_test_data(self, data: Tuple[str, BinaryIO]) -> Tuple[None, Tuple[str, BinaryIO]]: @@ -109,6 +93,13 @@ def _classifiy_devkit(self, data: Tuple[str, BinaryIO]) -> Optional[int]: "ILSVRC2012_validation_ground_truth.txt": 1, }.get(pathlib.Path(data[0]).name) + # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849 + # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment + _WNID_MAP = { + "n03126707": "construction crane", + "n03710721": "tank suit", + } + def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tuple[str, str]]: synsets = read_mat(data[1], squeeze_me=True)["synsets"] return [ @@ -118,21 +109,20 @@ def _extract_categories_and_wnids(self, data: Tuple[str, BinaryIO]) -> List[Tupl if num_children == 0 ] - def _imagenet_label_to_wnid(self, imagenet_label: str, *, wnids: List[str]) -> str: - return wnids[int(imagenet_label) - 1] + def _imagenet_label_to_wnid(self, imagenet_label: str) -> str: + return self._wnids[int(imagenet_label) - 1] _VAL_TEST_IMAGE_NAME_PATTERN = re.compile(r"ILSVRC2012_(val|test)_(?P\d{8})[.]JPEG") - def _val_test_image_key(self, data: Tuple[str, Any]) -> int: - path = pathlib.Path(data[0]) - return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name).group("id")) # type: ignore[union-attr] + def _val_test_image_key(self, path: pathlib.Path) -> int: + return int(self._VAL_TEST_IMAGE_NAME_PATTERN.match(path.name)["id"]) # type: ignore[index] def _prepare_val_data( self, data: Tuple[Tuple[int, str], Tuple[str, BinaryIO]] ) -> Tuple[Tuple[Label, str], Tuple[str, BinaryIO]]: label_data, image_data = data _, wnid = label_data - label = Label.from_category(self.info.extra.wnid_to_category[wnid], categories=self.categories) + label = Label.from_category(self._wnid_to_category[wnid], categories=self._categories) return (label, wnid), image_data def _prepare_sample( @@ -147,19 +137,17 @@ def _prepare_sample( image=EncodedImage.from_file(buffer), ) - def _make_datapipe( - self, resource_dps: List[IterDataPipe], *, config: DatasetConfig - ) -> IterDataPipe[Dict[str, Any]]: - if config.split in {"train", "test"}: + def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: + if self._split in {"train", "test"}: dp = resource_dps[0] # the train archive is a tar of tars - if config.split == "train": + if self._split == "train": dp = TarArchiveReader(dp) dp = hint_sharding(dp) dp = hint_shuffling(dp) - dp = Mapper(dp, self._prepare_train_data if config.split == "train" else self._prepare_test_data) + dp = Mapper(dp, self._prepare_train_data if self._split == "train" else 
self._prepare_test_data) else: # config.split == "val": images_dp, devkit_dp = resource_dps @@ -171,7 +159,7 @@ def _make_datapipe( _, wnids = zip(*next(iter(meta_dp))) label_dp = LineReader(label_dp, decode=True, return_path=False) - label_dp = Mapper(label_dp, functools.partial(self._imagenet_label_to_wnid, wnids=wnids)) + label_dp = Mapper(label_dp, self._imagenet_label_to_wnid) label_dp: IterDataPipe[Tuple[int, str]] = Enumerator(label_dp, 1) label_dp = hint_sharding(label_dp) label_dp = hint_shuffling(label_dp) @@ -180,72 +168,29 @@ def _make_datapipe( label_dp, images_dp, key_fn=getitem(0), - ref_key_fn=self._val_test_image_key, + ref_key_fn=path_accessor(self._val_test_image_key), buffer_size=INFINITE_BUFFER_SIZE, ) dp = Mapper(dp, self._prepare_val_data) return Mapper(dp, self._prepare_sample) - # Although the WordNet IDs (wnids) are unique, the corresponding categories are not. For example, both n02012849 - # and n03126707 are labeled 'crane' while the first means the bird and the latter means the construction equipment - _WNID_MAP = { - "n03126707": "construction crane", - "n03710721": "tank suit", - } + def __len__(self) -> int: + return { + "train": 1_281_167, + "val": 50_000, + "test": 100_000, + }[self._split] - def _generate_categories(self, root: pathlib.Path) -> List[Tuple[str, ...]]: - config = self.info.make_config(split="val") - resources = self.resources(config) + @classmethod + def _generate_categories(cls, root: pathlib.Path) -> List[Tuple[str, ...]]: + dataset = cls(root, split="val") + resources = dataset._resources() devkit_dp = resources[1].load(root) meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat")) - meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids) + meta_dp = Mapper(meta_dp, dataset._extract_categories_and_wnids) categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp))) categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1]) return categories_and_wnids - - -NAME = "imagenet" - - -@register_info(NAME) -def _info() -> Dict[str, Any]: - categories, wnids = zip(*DatasetInfo.read_categories_file(BUILTIN_DIR / f"{NAME}.categories")) - return dict(categories=categories, wnids=wnids) - - -@register_dataset(NAME) -class ImageNet2(Dataset2): - def __init__(self, root: Union[str, pathlib.Path], *, split: str = "train") -> None: - if split not in {"train", "val", "test"}: - raise ValueError - self._split = split - - info = _info() - categories, wnids = info["categories"], info["wnids"] - - self._old_style_dataset = ImageNet() - self._old_style_config = self._old_style_dataset.info.make_config(split=self._split) - - self.categories = categories - self.info = SimpleNamespace( - wnid_to_category=zip(wnids, categories), - category_to_wnid=zip(categories, wnids), - ) - - super().__init__(root) - - def _resources(self) -> List[OnlineResource]: - return self._old_style_dataset.resources(self._old_style_config) - - def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]: - return self._old_style_dataset._make_datapipe(resource_dps, config=self._old_style_config) - - def __len__(self) -> int: - return { - "train": 1_281_167, - "val": 50_000, - "test": 100_000, - }[self._split] diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index c8f6ca6733e..ef5c1c29e0b 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -7,6 +7,7 @@ from typing import Any, Dict, List, 
Optional, Sequence, Union, Tuple, Collection, Iterator from torch.utils.data import IterDataPipe +from torchvision.datasets.utils import verify_str_arg from torchvision.prototype.utils._internal import FrozenBunch, make_repr, add_suggestion, sequence_to_str from .._home import use_sharded_dataset @@ -184,6 +185,16 @@ def _generate_categories(self, root: pathlib.Path) -> Sequence[Union[str, Sequen class Dataset2(IterDataPipe[Dict[str, Any]], abc.ABC): + @staticmethod + def _verify_str_arg( + value: str, + arg: Optional[str] = None, + valid_values: Optional[Collection[str]] = None, + *, + custom_msg: Optional[str] = None, + ) -> str: + return verify_str_arg(value, arg, valid_values, custom_msg=custom_msg) + def __init__(self, root: Union[str, pathlib.Path], *, skip_integrity_check: bool = False) -> None: self._root = pathlib.Path(root).expanduser().resolve() resources = [ @@ -205,3 +216,7 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, @abc.abstractmethod def __len__(self) -> int: pass + + @classmethod + def _generate_categories(cls, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]: + raise NotImplementedError From bf79e2f8f782d6ab784cb6fc1c22cc9ef5d971a9 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 23 Feb 2022 09:50:33 +0100 Subject: [PATCH 3/8] fix missing file detection --- main.py | 17 +++++++++++++++++ test/builtin_dataset_mocks.py | 24 +++++++++++++----------- 2 files changed, 30 insertions(+), 11 deletions(-) create mode 100644 main.py diff --git a/main.py b/main.py new file mode 100644 index 00000000000..9db391a3d31 --- /dev/null +++ b/main.py @@ -0,0 +1,17 @@ +# from torchvision import transforms +# from torchvision.transforms import functional as F +# import PIL.Image +# +# image = F.pil_to_tensor(PIL.Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg")) +# +# print(image.shape) +# +# transform = transforms.AugMix() +# timage = transform(image) +# +# print(timage.shape) + +# from torchvision.prototype import datasets +# +# for sample in datasets.load("cityscapes"): +# break diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 96ca78a1c9a..2bce8578fa7 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -9,6 +9,7 @@ import pathlib import pickle import random +import unittest.mock import xml.etree.ElementTree as ET from collections import defaultdict, Counter @@ -20,6 +21,7 @@ from torch.nn.functional import one_hot from torch.testing import make_tensor as _make_tensor from torchvision.prototype import datasets +from torchvision.prototype.utils._internal import sequence_to_str make_tensor = functools.partial(_make_tensor, device="cpu") make_scalar = functools.partial(make_tensor, ()) @@ -64,17 +66,17 @@ def prepare(self, home, **options): mock_info = self._parse_mock_info(self.mock_data_fn(datasets.info(self.name), root, **options)) - # # FIXME: We need to handle missing files here - # dataset = datasets.load2(self.name, **options) - # - # available_file_names = {path.name for path in root.glob("*")} - # required_file_names = {resource.file_name for resource in self.dataset.resources(config)} - # missing_file_names = required_file_names - available_file_names - # if missing_file_names: - # raise pytest.UsageError( - # f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} " - # f"for {config}, but they were not created by the mock data function." 
- # ) + with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"): + required_file_names = { + resource.file_name for resource in datasets.load(self.name, root=root, **options)._resources() + } + available_file_names = {path.name for path in root.glob("*")} + missing_file_names = required_file_names - available_file_names + if missing_file_names: + raise pytest.UsageError( + f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} " + f"for {options}, but they were not created by the mock data function." + ) return mock_info From 3744e943cf70bbbaa93369352c1c3eecff44aaeb Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 23 Feb 2022 12:08:01 +0100 Subject: [PATCH 4/8] remove unrelated file --- main.py | 17 ----------------- 1 file changed, 17 deletions(-) delete mode 100644 main.py diff --git a/main.py b/main.py deleted file mode 100644 index 9db391a3d31..00000000000 --- a/main.py +++ /dev/null @@ -1,17 +0,0 @@ -# from torchvision import transforms -# from torchvision.transforms import functional as F -# import PIL.Image -# -# image = F.pil_to_tensor(PIL.Image.open("test/assets/encode_jpeg/grace_hopper_517x606.jpg")) -# -# print(image.shape) -# -# transform = transforms.AugMix() -# timage = transform(image) -# -# print(timage.shape) - -# from torchvision.prototype import datasets -# -# for sample in datasets.load("cityscapes"): -# break From fb365bdf811970840a34b98762249b7d99ed4ace Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 24 Feb 2022 16:22:16 +0100 Subject: [PATCH 5/8] reinstante decorator for mock registering --- test/builtin_dataset_mocks.py | 71 ++++++++++++------------- test/test_prototype_builtin_datasets.py | 8 +-- 2 files changed, 38 insertions(+), 41 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 2bce8578fa7..da097a2e15e 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -97,8 +97,6 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None): for mock in dataset_mocks: if isinstance(mock, DatasetMock): mocks[mock.name] = mock - elif isinstance(mock, collections.abc.Sequence): - mocks.update({mock_.name: mock_ for mock_ in mock}) elif isinstance(mock, collections.abc.Mapping): mocks.update(mock) else: @@ -124,12 +122,19 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None): ) -DATASET_MOCKS = [] +DATASET_MOCKS = {} -def register_mock(fn): - # TODO: remove this decorator after all datasets have been migrated - return fn +def register_mock(name=None, *, configs): + def wrapper(mock_data_fn): + nonlocal name + if name is None: + name = mock_data_fn.__name__ + DATASET_MOCKS[name] = DatasetMock(name, mock_data_fn=mock_data_fn, configs=configs) + + return mock_data_fn + + return wrapper class MNISTMockData: @@ -207,7 +212,7 @@ def generate( return num_samples -@register_mock +# # @register_mock def mnist(info, root, config): train = config.split == "train" images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz" @@ -223,7 +228,7 @@ def mnist(info, root, config): # DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]}) -@register_mock +# # @register_mock def emnist(info, root, config): # The image sets that merge some lower case letters in their respective upper case variant, still use dense # labels in the data files. Thus, num_categories != len(categories) there. 
@@ -250,7 +255,7 @@ def emnist(info, root, config): return num_samples_map[config] -@register_mock +# # @register_mock def qmnist(info, root, config): num_categories = len(info.categories) if config.split == "train": @@ -327,7 +332,7 @@ def generate( make_tar(root, name, folder, compression="gz") -@register_mock +# @register_mock def cifar10(info, root, config): train_files = [f"data_batch_{idx}" for idx in range(1, 6)] test_files = ["test_batch"] @@ -345,7 +350,7 @@ def cifar10(info, root, config): return len(train_files if config.split == "train" else test_files) -@register_mock +# @register_mock def cifar100(info, root, config): train_files = ["train"] test_files = ["test"] @@ -363,7 +368,7 @@ def cifar100(info, root, config): return len(train_files if config.split == "train" else test_files) -@register_mock +# @register_mock def caltech101(info, root, config): def create_ann_file(root, name): import scipy.io @@ -413,7 +418,7 @@ def create_ann_folder(root, name, file_name_fn, num_examples): return num_images_per_category * len(info.categories) -@register_mock +# @register_mock def caltech256(info, root, config): dir = root / "256_ObjectCategories" num_images_per_category = 2 @@ -433,7 +438,8 @@ def caltech256(info, root, config): return num_images_per_category * len(info.categories) -def imagenet_mock_data_fn(info, root, **options): +@register_mock(configs=combinations_grid(split=("train", "val", "test"))) +def imagenet(info, root, **options): from scipy.io import savemat if options["split"] == "train": @@ -482,17 +488,6 @@ def imagenet_mock_data_fn(info, root, **options): return num_samples -DATASET_MOCKS.append( - DatasetMock( - "imagenet", - mock_data_fn=imagenet_mock_data_fn, - configs=combinations_grid( - split=("train", "val", "test"), - ), - ) -) - - class CocoMockData: @classmethod def _make_images_archive(cls, root, name, *, num_samples): @@ -598,7 +593,7 @@ def generate( return num_samples -@register_mock +# @register_mock def coco(info, root, config): return CocoMockData.generate(root, year=config.year, num_samples=5) @@ -672,12 +667,12 @@ def generate(cls, root): return num_samples_map -@register_mock +# @register_mock def sbd(info, root, config): return SBDMockData.generate(root)[config.split] -@register_mock +# @register_mock def semeion(info, root, config): num_samples = 3 num_categories = len(info.categories) @@ -790,7 +785,7 @@ def generate(cls, root, *, year, trainval): return num_samples_map -@register_mock +# @register_mock def voc(info, root, config): trainval = config.split != "test" return VOCMockData.generate(root, year=config.year, trainval=trainval)[config.split] @@ -884,12 +879,12 @@ def generate(cls, root): return num_samples_map -@register_mock +# @register_mock def celeba(info, root, config): return CelebAMockData.generate(root)[config.split] -@register_mock +# @register_mock def dtd(info, root, config): data_folder = root / "dtd" @@ -937,7 +932,7 @@ def dtd(info, root, config): return num_samples_map[config] -@register_mock +# @register_mock def fer2013(info, root, config): num_samples = 5 if config.split == "train" else 3 @@ -962,7 +957,7 @@ def fer2013(info, root, config): return num_samples -@register_mock +# @register_mock def gtsrb(info, root, config): num_examples_per_class = 5 if config.split == "train" else 3 classes = ("00000", "00042", "00012") @@ -1032,7 +1027,7 @@ def _make_ann_file(path, num_examples, class_idx): return num_examples -@register_mock +# @register_mock def clevr(info, root, config): data_folder = root / "CLEVR_v1.0" @@ -1138,7 
+1133,7 @@ def generate(self, root): return num_samples_map -@register_mock +# @register_mock def oxford_iiit_pet(info, root, config): return OxfordIIITPetMockData.generate(root)[config.split] @@ -1304,13 +1299,13 @@ def generate(cls, root): return num_samples_map -@register_mock +# @register_mock def cub200(info, root, config): num_samples_map = (CUB2002011MockData if config.year == "2011" else CUB2002010MockData).generate(root) return num_samples_map[config.split] -@register_mock +# @register_mock def svhn(info, root, config): import scipy.io as sio @@ -1330,7 +1325,7 @@ def svhn(info, root, config): return num_samples -@register_mock +# @register_mock def pcam(info, root, config): import h5py diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index fab1e5f921c..7a90ba881b5 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -24,7 +24,7 @@ def test_home(mocker, tmp_path): def test_coverage(): - untested_datasets = set(datasets.list_datasets()) - {mock.name for mock in DATASET_MOCKS} + untested_datasets = set(datasets.list_datasets()) - DATASET_MOCKS.keys() if untested_datasets: raise AssertionError( f"The dataset(s) {sequence_to_str(sorted(untested_datasets), separate_last='and ')} " @@ -149,7 +149,8 @@ def test_save_load(self, test_home, dataset_mock, options): assert_samples_equal(torch.load(buffer), sample) -@parametrize_dataset_mocks([mock for mock in DATASET_MOCKS if mock.name == "qmnist"]) +# FIXME: DATASET_MOCKS["qmnist"] +@parametrize_dataset_mocks({}) class TestQMNIST: def test_extra_label(self, test_home, dataset_mock, options): dataset_mock.prepare(test_home, **options) @@ -169,7 +170,8 @@ def test_extra_label(self, test_home, dataset_mock, options): assert key in sample and isinstance(sample[key], type) -@parametrize_dataset_mocks([mock for mock in DATASET_MOCKS if mock.name == "gtsrb"]) +# FIXME: DATASET_MOCKS["gtsrb"] +@parametrize_dataset_mocks({}) class TestGTSRB: def test_label_matches_path(self, test_home, dataset_mock, options): # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path. 
From 69dd4185e84d51c8e6e41b07b5beb4e435309ed3 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 24 Feb 2022 16:27:07 +0100 Subject: [PATCH 6/8] options -> config --- test/builtin_dataset_mocks.py | 32 ++++++------- test/test_prototype_builtin_datasets.py | 62 ++++++++++++------------- torchvision/prototype/datasets/_api.py | 4 +- 3 files changed, 49 insertions(+), 49 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index da097a2e15e..7a9ec242337 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -60,30 +60,30 @@ def _parse_mock_info(self, mock_info): return mock_info - def prepare(self, home, **options): + def prepare(self, home, config): root = home / self.name root.mkdir(exist_ok=True) - mock_info = self._parse_mock_info(self.mock_data_fn(datasets.info(self.name), root, **options)) + mock_info = self._parse_mock_info(self.mock_data_fn(datasets.info(self.name), root, config)) with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"): required_file_names = { - resource.file_name for resource in datasets.load(self.name, root=root, **options)._resources() + resource.file_name for resource in datasets.load(self.name, root=root, **config)._resources() } available_file_names = {path.name for path in root.glob("*")} missing_file_names = required_file_names - available_file_names if missing_file_names: raise pytest.UsageError( f"Dataset '{self.name}' requires the files {sequence_to_str(sorted(missing_file_names))} " - f"for {options}, but they were not created by the mock data function." + f"for {config}, but they were not created by the mock data function." ) return mock_info -def config_id(name, options): +def config_id(name, config): parts = [name] - for name, value in options.items(): + for name, value in config.items(): if isinstance(value, bool): part = ("" if value else "no_") + name else: @@ -113,11 +113,11 @@ def parametrize_dataset_mocks(*dataset_mocks, marks=None): raise pytest.UsageError() return pytest.mark.parametrize( - ("dataset_mock", "options"), + ("dataset_mock", "config"), [ - pytest.param(dataset_mock, options, id=config_id(name, options), marks=marks.get(name, ())) + pytest.param(dataset_mock, config, id=config_id(name, config), marks=marks.get(name, ())) for name, dataset_mock in dataset_mocks.items() - for options in dataset_mock.configs + for config in dataset_mock.configs ], ) @@ -212,7 +212,7 @@ def generate( return num_samples -# # @register_mock +# @register_mock def mnist(info, root, config): train = config.split == "train" images_file = f"{'train' if train else 't10k'}-images-idx3-ubyte.gz" @@ -228,7 +228,7 @@ def mnist(info, root, config): # DATASET_MOCKS.update({name: DatasetMock(name, mnist) for name in ["fashionmnist", "kmnist"]}) -# # @register_mock +# @register_mock def emnist(info, root, config): # The image sets that merge some lower case letters in their respective upper case variant, still use dense # labels in the data files. Thus, num_categories != len(categories) there. 
@@ -255,7 +255,7 @@ def emnist(info, root, config): return num_samples_map[config] -# # @register_mock +# @register_mock def qmnist(info, root, config): num_categories = len(info.categories) if config.split == "train": @@ -439,10 +439,10 @@ def caltech256(info, root, config): @register_mock(configs=combinations_grid(split=("train", "val", "test"))) -def imagenet(info, root, **options): +def imagenet(info, root, config): from scipy.io import savemat - if options["split"] == "train": + if config["split"] == "train": num_samples = len(info["wnids"]) archive_name = "ILSVRC2012_img_train.tar" @@ -455,7 +455,7 @@ def imagenet(info, root, **options): num_examples=1, ) files.append(make_tar(root, f"{wnid}.tar")) - elif options["split"] == "val": + elif config["split"] == "val": num_samples = 3 archive_name = "ILSVRC2012_img_val.tar" files = [create_image_file(root, f"ILSVRC2012_val_{idx + 1:08d}.JPEG") for idx in range(num_samples)] @@ -478,7 +478,7 @@ def imagenet(info, root, **options): savemat(data_root / "meta.mat", dict(synsets=synsets)) make_tar(root, devkit_root.with_suffix(".tar.gz").name, compression="gz") - else: # options["split"] == "test" + else: # config["split"] == "test" num_samples = 5 archive_name = "ILSVRC2012_img_test_v10102019.tar" files = [create_image_file(root, f"ILSVRC2012_test_{idx + 1:08d}.JPEG") for idx in range(num_samples)] diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py index 7a90ba881b5..0ba042bcda5 100644 --- a/test/test_prototype_builtin_datasets.py +++ b/test/test_prototype_builtin_datasets.py @@ -45,19 +45,19 @@ def test_info(self, name): raise AssertionError("Info should be a dictionary with string keys.") @parametrize_dataset_mocks(DATASET_MOCKS) - def test_smoke(self, test_home, dataset_mock, options): - dataset_mock.prepare(test_home, **options) + def test_smoke(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **options) + dataset = datasets.load(dataset_mock.name, **config) if not isinstance(dataset, datasets.utils.Dataset2): raise AssertionError(f"Loading the dataset should return an Dataset, but got {type(dataset)} instead.") @parametrize_dataset_mocks(DATASET_MOCKS) - def test_sample(self, test_home, dataset_mock, options): - dataset_mock.prepare(test_home, **options) + def test_sample(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **options) + dataset = datasets.load(dataset_mock.name, **config) try: sample = next(iter(dataset)) @@ -71,18 +71,18 @@ def test_sample(self, test_home, dataset_mock, options): raise AssertionError("Sample dictionary is empty.") @parametrize_dataset_mocks(DATASET_MOCKS) - def test_num_samples(self, test_home, dataset_mock, options): - mock_info = dataset_mock.prepare(test_home, **options) + def test_num_samples(self, test_home, dataset_mock, config): + mock_info = dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **options) + dataset = datasets.load(dataset_mock.name, **config) assert len(list(dataset)) == mock_info["num_samples"] @parametrize_dataset_mocks(DATASET_MOCKS) - def test_no_vanilla_tensors(self, test_home, dataset_mock, options): - dataset_mock.prepare(test_home, **options) + def test_no_vanilla_tensors(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **options) + dataset = 
datasets.load(dataset_mock.name, **config) vanilla_tensors = {key for key, value in next(iter(dataset)).items() if type(value) is torch.Tensor} if vanilla_tensors: @@ -92,10 +92,10 @@ def test_no_vanilla_tensors(self, test_home, dataset_mock, options): ) @parametrize_dataset_mocks(DATASET_MOCKS) - def test_transformable(self, test_home, dataset_mock, options): - dataset_mock.prepare(test_home, **options) + def test_transformable(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **options) + dataset = datasets.load(dataset_mock.name, **config) next(iter(dataset.map(transforms.Identity()))) @@ -108,10 +108,10 @@ def test_transformable(self, test_home, dataset_mock, options): ) }, ) - def test_traversable(self, test_home, dataset_mock, options): - dataset_mock.prepare(test_home, **options) + def test_traversable(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **options) + dataset = datasets.load(dataset_mock.name, **config) traverse(dataset) @@ -125,22 +125,22 @@ def test_traversable(self, test_home, dataset_mock, options): }, ) @pytest.mark.parametrize("annotation_dp_type", (Shuffler, ShardingFilter)) - def test_has_annotations(self, test_home, dataset_mock, options, annotation_dp_type): + def test_has_annotations(self, test_home, dataset_mock, config, annotation_dp_type): def scan(graph): for node, sub_graph in graph.items(): yield node yield from scan(sub_graph) - dataset_mock.prepare(test_home, **options) - dataset = datasets.load(dataset_mock.name, **options) + dataset_mock.prepare(test_home, config) + dataset = datasets.load(dataset_mock.name, **config) if not any(type(dp) is annotation_dp_type for dp in scan(traverse(dataset))): raise AssertionError(f"The dataset doesn't contain a {annotation_dp_type.__name__}() datapipe.") @parametrize_dataset_mocks(DATASET_MOCKS) - def test_save_load(self, test_home, dataset_mock, options): - dataset_mock.prepare(test_home, **options) - dataset = datasets.load(dataset_mock.name, **options) + def test_save_load(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) + dataset = datasets.load(dataset_mock.name, **config) sample = next(iter(dataset)) with io.BytesIO() as buffer: @@ -152,10 +152,10 @@ def test_save_load(self, test_home, dataset_mock, options): # FIXME: DATASET_MOCKS["qmnist"] @parametrize_dataset_mocks({}) class TestQMNIST: - def test_extra_label(self, test_home, dataset_mock, options): - dataset_mock.prepare(test_home, **options) + def test_extra_label(self, test_home, dataset_mock, config): + dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **options) + dataset = datasets.load(dataset_mock.name, **config) sample = next(iter(dataset)) for key, type in ( @@ -173,15 +173,15 @@ def test_extra_label(self, test_home, dataset_mock, options): # FIXME: DATASET_MOCKS["gtsrb"] @parametrize_dataset_mocks({}) class TestGTSRB: - def test_label_matches_path(self, test_home, dataset_mock, options): + def test_label_matches_path(self, test_home, dataset_mock, config): # We read the labels from the csv files instead. But for the trainset, the labels are also part of the path. 
# This test makes sure that they're both the same - if options["split"] != "train": + if config["split"] != "train": return - dataset_mock.prepare(test_home, **options) + dataset_mock.prepare(test_home, config) - dataset = datasets.load(dataset_mock.name, **options) + dataset = datasets.load(dataset_mock.name, **config) for sample in dataset: label_from_path = int(Path(sample["path"]).parent.name) diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py index c0e3f1ad1cb..8f8bb53deb4 100644 --- a/torchvision/prototype/datasets/_api.py +++ b/torchvision/prototype/datasets/_api.py @@ -56,10 +56,10 @@ def info(name: str) -> Dict[str, Any]: return find(BUILTIN_INFOS, name) -def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **options: Any) -> Dataset2: +def load(name: str, *, root: Optional[Union[str, pathlib.Path]] = None, **config: Any) -> Dataset2: dataset_cls = find(BUILTIN_DATASETS, name) if root is None: root = pathlib.Path(home()) / name - return dataset_cls(root, **options) + return dataset_cls(root, **config) From 240bf53b163520948503a0705a64caf2585ef49c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 24 Feb 2022 16:39:34 +0100 Subject: [PATCH 7/8] remove passing of info to mock data functions --- test/builtin_dataset_mocks.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/builtin_dataset_mocks.py b/test/builtin_dataset_mocks.py index 7a9ec242337..1d988196190 100644 --- a/test/builtin_dataset_mocks.py +++ b/test/builtin_dataset_mocks.py @@ -64,7 +64,7 @@ def prepare(self, home, config): root = home / self.name root.mkdir(exist_ok=True) - mock_info = self._parse_mock_info(self.mock_data_fn(datasets.info(self.name), root, config)) + mock_info = self._parse_mock_info(self.mock_data_fn(root, config)) with unittest.mock.patch.object(datasets.utils.Dataset2, "__init__"): required_file_names = { @@ -439,9 +439,11 @@ def caltech256(info, root, config): @register_mock(configs=combinations_grid(split=("train", "val", "test"))) -def imagenet(info, root, config): +def imagenet(root, config): from scipy.io import savemat + info = datasets.info("imagenet") + if config["split"] == "train": num_samples = len(info["wnids"]) archive_name = "ILSVRC2012_img_train.tar" From 7d1e72bf134eff07cf72e03b5ab57066f73491dd Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Thu, 24 Feb 2022 16:45:24 +0100 Subject: [PATCH 8/8] refactor categories file generation --- torchvision/prototype/datasets/_builtin/imagenet.py | 11 +++++------ .../prototype/datasets/generate_category_files.py | 10 +++------- torchvision/prototype/datasets/utils/_dataset.py | 3 +-- 3 files changed, 9 insertions(+), 15 deletions(-) diff --git a/torchvision/prototype/datasets/_builtin/imagenet.py b/torchvision/prototype/datasets/_builtin/imagenet.py index 34f672c2fc1..6f91d4c4a8d 100644 --- a/torchvision/prototype/datasets/_builtin/imagenet.py +++ b/torchvision/prototype/datasets/_builtin/imagenet.py @@ -182,14 +182,13 @@ def __len__(self) -> int: "test": 100_000, }[self._split] - @classmethod - def _generate_categories(cls, root: pathlib.Path) -> List[Tuple[str, ...]]: - dataset = cls(root, split="val") - resources = dataset._resources() + def _generate_categories(self) -> List[Tuple[str, ...]]: + self._split = "val" + resources = self._resources() - devkit_dp = resources[1].load(root) + devkit_dp = resources[1].load(self._root) meta_dp = Filter(devkit_dp, path_comparator("name", "meta.mat")) - meta_dp = Mapper(meta_dp, 
dataset._extract_categories_and_wnids) + meta_dp = Mapper(meta_dp, self._extract_categories_and_wnids) categories_and_wnids = cast(List[Tuple[str, ...]], next(iter(meta_dp))) categories_and_wnids.sort(key=lambda category_and_wnid: category_and_wnid[1]) diff --git a/torchvision/prototype/datasets/generate_category_files.py b/torchvision/prototype/datasets/generate_category_files.py index 3c2bf7e73cb..ac35eddb28b 100644 --- a/torchvision/prototype/datasets/generate_category_files.py +++ b/torchvision/prototype/datasets/generate_category_files.py @@ -2,25 +2,21 @@ import argparse import csv -import pathlib import sys from torchvision.prototype import datasets -from torchvision.prototype.datasets._api import find from torchvision.prototype.datasets.utils._internal import BUILTIN_DIR def main(*names, force=False): - home = pathlib.Path(datasets.home()) - for name in names: path = BUILTIN_DIR / f"{name}.categories" if path.exists() and not force: continue - dataset = find(name) + dataset = datasets.load(name) try: - categories = dataset._generate_categories(home / name) + categories = dataset._generate_categories() except NotImplementedError: continue @@ -55,7 +51,7 @@ def parse_args(argv=None): if __name__ == "__main__": - args = parse_args() + args = parse_args(["-f", "imagenet"]) try: main(*args.names, force=args.force) diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py index ef5c1c29e0b..7200f00fd02 100644 --- a/torchvision/prototype/datasets/utils/_dataset.py +++ b/torchvision/prototype/datasets/utils/_dataset.py @@ -217,6 +217,5 @@ def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, def __len__(self) -> int: pass - @classmethod - def _generate_categories(cls, root: pathlib.Path) -> Sequence[Union[str, Sequence[str]]]: + def _generate_categories(self) -> Sequence[Union[str, Sequence[str]]]: raise NotImplementedError
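
Note: with the refactoring above, the category-file hook is an instance method that reuses the dataset's own root and resources, and generate_category_files.py obtains the dataset via `datasets.load(name)` before calling it. The sketch below shows what a dataset author now implements under that contract; `FooDataset` and its directory-per-category layout are hypothetical, and only the method names, the `Dataset2` base class, and the import paths come from the patches in this series.

    # Hypothetical sketch of the refactored _generate_categories() contract: an instance
    # method that reads from self._root / self._resources() instead of receiving a root path.
    import pathlib
    from typing import Any, Dict, List

    from torch.utils.data import IterDataPipe
    from torchvision.prototype.datasets.utils import Dataset2


    class FooDataset(Dataset2):  # hypothetical example, not a built-in dataset
        def _resources(self):
            return []  # a real dataset returns its download resource objects here

        def _datapipe(self, resource_dps: List[IterDataPipe]) -> IterDataPipe[Dict[str, Any]]:
            raise NotImplementedError

        def __len__(self) -> int:
            return 0

        def _generate_categories(self) -> List[str]:
            # Derive the category names from the directory layout under the dataset root.
            root = pathlib.Path(self._root)
            return sorted(path.name for path in root.iterdir() if path.is_dir())

    # Usage mirrors the refactored generate_category_files.main():
    #     from torchvision.prototype import datasets
    #     dataset = datasets.load("foo")  # "foo" is a hypothetical dataset name
    #     categories = dataset._generate_categories()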