diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py
index 7b6a22600e1..d8e07314e00 100644
--- a/test/test_prototype_builtin_datasets.py
+++ b/test/test_prototype_builtin_datasets.py
@@ -18,7 +18,7 @@ def test_home(mocker, tmp_path):


 def test_coverage():
-    untested_datasets = set(datasets.list()) - DATASET_MOCKS.keys()
+    untested_datasets = set(datasets.list_datasets()) - DATASET_MOCKS.keys()
     if untested_datasets:
         raise AssertionError(
             f"The dataset(s) {sequence_to_str(sorted(untested_datasets), separate_last='and ')} "
diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py
index 1945b5a5d9e..48bae0b65f5 100644
--- a/torchvision/prototype/datasets/__init__.py
+++ b/torchvision/prototype/datasets/__init__.py
@@ -11,5 +11,5 @@
 from ._home import home

 # Load this last, since some parts depend on the above being loaded first
-from ._api import register, _list as list, info, load, find  # usort: skip
+from ._api import register, list_datasets, info, load, find  # usort: skip
 from ._folder import from_data_folder, from_image_folder
diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py
index e9240eb46ce..f3c398d5552 100644
--- a/torchvision/prototype/datasets/_api.py
+++ b/torchvision/prototype/datasets/_api.py
@@ -23,8 +23,7 @@ def register(dataset: Dataset) -> None:
         register(obj())


-# This is exposed as 'list', but we avoid that here to not shadow the built-in 'list'
-def _list() -> List[str]:
+def list_datasets() -> List[str]:
     return sorted(DATASETS.keys())


@@ -39,7 +38,7 @@ def find(name: str) -> Dataset:
                 word=name,
                 possibilities=DATASETS.keys(),
                 alternative_hint=lambda _: (
-                    "You can use torchvision.datasets.list() to get a list of all available datasets."
+                    "You can use torchvision.datasets.list_datasets() to get a list of all available datasets."
                 ),
             )
         ) from error
diff --git a/torchvision/prototype/datasets/generate_category_files.py b/torchvision/prototype/datasets/generate_category_files.py
index 40843ecf50b..3c2bf7e73cb 100644
--- a/torchvision/prototype/datasets/generate_category_files.py
+++ b/torchvision/prototype/datasets/generate_category_files.py
@@ -49,7 +49,7 @@ def parse_args(argv=None):
     args = parser.parse_args(argv or sys.argv[1:])

     if not args.names:
-        args.names = datasets.list()
+        args.names = datasets.list_datasets()

     return args
diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py
index 38c991fe7a1..04fb4312728 100644
--- a/torchvision/prototype/datasets/utils/_dataset.py
+++ b/torchvision/prototype/datasets/utils/_dataset.py
@@ -24,6 +24,8 @@ class DatasetType(enum.Enum):


 class DatasetConfig(FrozenBunch):
+    # This needs to be Frozen because we often pass configs as partial(func, config=config)
+    # and partial() requires the parameters to be hashable.
     pass


diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py
index e21e8ffd25f..1b437d50b85 100644
--- a/torchvision/prototype/datasets/utils/_internal.py
+++ b/torchvision/prototype/datasets/utils/_internal.py
@@ -39,7 +39,6 @@
     "BUILTIN_DIR",
     "read_mat",
     "image_buffer_from_array",
-    "SequenceIterator",
     "MappingIterator",
     "Enumerator",
     "getitem",
@@ -80,15 +79,6 @@ def image_buffer_from_array(array: np.ndarray, *, format: str = "png") -> io.Byt
     return buffer


-class SequenceIterator(IterDataPipe[D]):
-    def __init__(self, datapipe: IterDataPipe[Sequence[D]]):
-        self.datapipe = datapipe
-
-    def __iter__(self) -> Iterator[D]:
-        for sequence in self.datapipe:
-            yield from iter(sequence)
-
-
 class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
     def __init__(self, datapipe: IterDataPipe[Dict[K, D]], *, drop_key: bool = False) -> None:
         self.datapipe = datapipe
diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py
index 2fc6bbca49e..c7fde65468a 100644
--- a/torchvision/prototype/datasets/utils/_resource.py
+++ b/torchvision/prototype/datasets/utils/_resource.py
@@ -2,7 +2,6 @@
 import hashlib
 import itertools
 import pathlib
-import warnings
 from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn
 from urllib.parse import urlparse

@@ -32,23 +31,17 @@ def __init__(
         sha256: Optional[str] = None,
         decompress: bool = False,
         extract: bool = False,
-        preprocess: Optional[Callable[[pathlib.Path], pathlib.Path]] = None,
-        loader: Optional[Callable[[pathlib.Path], IterDataPipe[Tuple[str, IO]]]] = None,
     ) -> None:
         self.file_name = file_name
         self.sha256 = sha256

-        if preprocess and (decompress or extract):
-            warnings.warn("The parameters 'decompress' and 'extract' are ignored when 'preprocess' is passed.")
-        elif extract:
-            preprocess = self._extract
+        self._preprocess: Optional[Callable[[pathlib.Path], pathlib.Path]]
+        if extract:
+            self._preprocess = self._extract
         elif decompress:
-            preprocess = self._decompress
-        self._preprocess = preprocess
-
-        if loader is None:
-            loader = self._default_loader
-        self._loader = loader
+            self._preprocess = self._decompress
+        else:
+            self._preprocess = None

     @staticmethod
     def _extract(file: pathlib.Path) -> pathlib.Path:
@@ -60,7 +53,7 @@ def _extract(file: pathlib.Path) -> pathlib.Path:
     @staticmethod
     def _decompress(file: pathlib.Path) -> pathlib.Path:
         return pathlib.Path(_decompress(str(file), remove_finished=True))

-    def _default_loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
+    def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
         if path.is_dir():
             return FileOpener(FileLister(str(path), recursive=True), mode="rb")
@@ -101,7 +94,7 @@ def load(
         path_candidates = {self.download(root, skip_integrity_check=skip_integrity_check)}
         # If the only thing we find is the raw file, we use it and optionally perform some preprocessing steps.
         if path_candidates == {path}:
-            if self._preprocess:
+            if self._preprocess is not None:
                 path = self._preprocess(path)
         # Otherwise we use the path with the fewest suffixes. This gives us the extracted > decompressed > raw priority
         # that we want.
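
Usage sketch for the rename above: `datasets.list()` (implemented as `_list` and aliased on import to avoid shadowing the built-in `list`) becomes the plainly named `datasets.list_datasets()`. A minimal example, assuming torchvision is built from this branch so the prototype package is importable; the dataset name "mnist-typo" is hypothetical and only meant to trigger the suggestion path in `find()`:

    from torchvision.prototype import datasets

    # list_datasets() returns the sorted keys of the dataset registry,
    # exactly what _list() returned before the rename.
    for name in datasets.list_datasets():
        print(name)

    # find() on an unknown name raises with the hint updated in this diff.
    # The exception type is not visible in the hunk, so we catch broadly here.
    try:
        datasets.find("mnist-typo")
    except Exception as error:
        print(error)  # "... You can use torchvision.datasets.list_datasets() ..."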
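
The `load()` hunk keeps the comment about preferring "the path with the fewest suffixes" to get the extracted > decompressed > raw priority. A standalone sketch of that rule, with hypothetical file names, using `pathlib.Path.suffixes` to do the counting:

    import pathlib

    # Among candidates for the same resource, fewer suffixes means more
    # processed: an extracted directory has none, a decompressed archive
    # one (.tar), and the raw download two (.tar.gz).
    candidates = [
        pathlib.Path("cifar-10-python.tar.gz"),  # raw download
        pathlib.Path("cifar-10-python.tar"),     # decompressed
        pathlib.Path("cifar-10-python"),         # extracted
    ]
    preferred = min(candidates, key=lambda path: len(path.suffixes))
    print(preferred)  # cifar-10-python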