Merged
test/test_prototype_builtin_datasets.py (2 changes: 1 addition & 1 deletion)

@@ -18,7 +18,7 @@ def test_home(mocker, tmp_path):
 
 
 def test_coverage():
-    untested_datasets = set(datasets.list()) - DATASET_MOCKS.keys()
+    untested_datasets = set(datasets.list_datasets()) - DATASET_MOCKS.keys()
     if untested_datasets:
         raise AssertionError(
             f"The dataset(s) {sequence_to_str(sorted(untested_datasets), separate_last='and ')} "
torchvision/prototype/datasets/__init__.py (2 changes: 1 addition & 1 deletion)

@@ -11,5 +11,5 @@
 from ._home import home
 
 # Load this last, since some parts depend on the above being loaded first
-from ._api import register, _list as list, info, load, find  # usort: skip
+from ._api import register, list_datasets, info, load, find  # usort: skip
 from ._folder import from_data_folder, from_image_folder
torchvision/prototype/datasets/_api.py (5 changes: 2 additions & 3 deletions)

@@ -23,8 +23,7 @@ def register(dataset: Dataset) -> None:
         register(obj())
 
 
-# This is exposed as 'list', but we avoid that here to not shadow the built-in 'list'
-def _list() -> List[str]:
+def list_datasets() -> List[str]:
     return sorted(DATASETS.keys())
 
 
@@ -39,7 +38,7 @@ def find(name: str) -> Dataset:
                 word=name,
                 possibilities=DATASETS.keys(),
                 alternative_hint=lambda _: (
-                    "You can use torchvision.datasets.list() to get a list of all available datasets."
+                    "You can use torchvision.datasets.list_datasets() to get a list of all available datasets."
                 ),
             )
         ) from error
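A quick usage sketch of the renamed API (the printed names are illustrative):

    from torchvision.prototype import datasets

    # Enumerate every registered builtin dataset, sorted by name.
    names = datasets.list_datasets()
    print(names[:3])  # e.g. ['caltech101', 'caltech256', 'celeba']

    # The built-in 'list' is no longer shadowed, so both coexist naturally:
    numbered = list(enumerate(names))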
torchvision/prototype/datasets/generate_category_files.py (2 changes: 1 addition & 1 deletion)

@@ -49,7 +49,7 @@ def parse_args(argv=None):
     args = parser.parse_args(argv or sys.argv[1:])
 
     if not args.names:
-        args.names = datasets.list()
+        args.names = datasets.list_datasets()
 
     return args
 
torchvision/prototype/datasets/utils/_dataset.py (2 changes: 2 additions & 0 deletions)

@@ -24,6 +24,8 @@ class DatasetType(enum.Enum):
 
 
 class DatasetConfig(FrozenBunch):
+    # This needs to be Frozen because we often pass configs as partial(func, config=config)
+    # and partial() requires the parameters to be hashable.
     pass
 
 
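The new comment is the crux of DatasetConfig's design: configs get baked into partial(...) callables, so they must be immutable, and immutability is what makes them hashable. A toy sketch of the idea (FrozenConfig is a hypothetical stand-in, not the FrozenBunch implementation):

    from functools import partial

    class FrozenConfig:
        # Hypothetical stand-in: fields live in a sorted tuple, making
        # instances immutable and hashable, like a frozen bunch.
        def __init__(self, **kwargs):
            object.__setattr__(self, "_items", tuple(sorted(kwargs.items())))

        def __getattr__(self, name):
            try:
                return dict(object.__getattribute__(self, "_items"))[name]
            except KeyError:
                raise AttributeError(name) from None

        def __setattr__(self, name, value):
            raise RuntimeError(f"{type(self).__name__} is immutable")

        def __hash__(self):
            return hash(self._items)

    def decode_sample(sample, *, config):
        return {**sample, "split": config.split}

    config = FrozenConfig(split="train")
    fn = partial(decode_sample, config=config)

    hash(config)                 # fine: the frozen config is hashable
    # hash({"split": "train"})   # TypeError: plain dicts are not
    print(fn({"image": "..."}))  # {'image': '...', 'split': 'train'}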
torchvision/prototype/datasets/utils/_internal.py (10 changes: 0 additions & 10 deletions)

@@ -39,7 +39,6 @@
     "BUILTIN_DIR",
     "read_mat",
     "image_buffer_from_array",
-    "SequenceIterator",
     "MappingIterator",
     "Enumerator",
     "getitem",
@@ -80,15 +79,6 @@ def image_buffer_from_array(array: np.ndarray, *, format: str = "png") -> io.BytesIO:
     return buffer
 
 
-class SequenceIterator(IterDataPipe[D]):
-    def __init__(self, datapipe: IterDataPipe[Sequence[D]]):
-        self.datapipe = datapipe
-
-    def __iter__(self) -> Iterator[D]:
-        for sequence in self.datapipe:
-            yield from iter(sequence)
-
-
 class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
     def __init__(self, datapipe: IterDataPipe[Dict[K, D]], *, drop_key: bool = False) -> None:
         self.datapipe = datapipe
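SequenceIterator is removed without a replacement (0 additions in this file). If a datapipe of sequences ever needs flattening again, a plain generator gives the same behavior; a sketch under that assumption, not part of this PR:

    from typing import Iterator, Sequence, TypeVar

    from torch.utils.data import IterDataPipe

    D = TypeVar("D")

    def flatten(datapipe: IterDataPipe[Sequence[D]]) -> Iterator[D]:
        # Equivalent to the removed SequenceIterator, as a free function.
        for sequence in datapipe:
            yield from sequence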
torchvision/prototype/datasets/utils/_resource.py (23 changes: 8 additions & 15 deletions)

@@ -2,7 +2,6 @@
 import hashlib
 import itertools
 import pathlib
-import warnings
 from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn
 from urllib.parse import urlparse
 
@@ -32,23 +31,17 @@ def __init__(
         sha256: Optional[str] = None,
         decompress: bool = False,
         extract: bool = False,
-        preprocess: Optional[Callable[[pathlib.Path], pathlib.Path]] = None,
-        loader: Optional[Callable[[pathlib.Path], IterDataPipe[Tuple[str, IO]]]] = None,
     ) -> None:
         self.file_name = file_name
         self.sha256 = sha256
 
-        if preprocess and (decompress or extract):
-            warnings.warn("The parameters 'decompress' and 'extract' are ignored when 'preprocess' is passed.")
-        elif extract:
-            preprocess = self._extract
+        self._preprocess: Optional[Callable[[pathlib.Path], pathlib.Path]]
+        if extract:
+            self._preprocess = self._extract
         elif decompress:
-            preprocess = self._decompress
-        self._preprocess = preprocess
-
-        if loader is None:
-            loader = self._default_loader
-        self._loader = loader
+            self._preprocess = self._decompress
+        else:
+            self._preprocess = None
 
     @staticmethod
     def _extract(file: pathlib.Path) -> pathlib.Path:
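The restructured constructor makes the flag precedence explicit: extract wins over decompress, and with neither flag no preprocessing happens; the injectable preprocess= and loader= escape hatches are gone. A toy restatement of the selection order (standalone sketch, not the torchvision class):

    def pick_preprocess(*, extract: bool = False, decompress: bool = False):
        # Same precedence as the new __init__: extract > decompress > nothing.
        if extract:
            return "extract"
        if decompress:
            return "decompress"
        return None

    assert pick_preprocess(extract=True, decompress=True) == "extract"
    assert pick_preprocess(decompress=True) == "decompress"
    assert pick_preprocess() is None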
@@ -60,7 +53,7 @@ def _extract(file: pathlib.Path) -> pathlib.Path:
     def _decompress(file: pathlib.Path) -> pathlib.Path:
         return pathlib.Path(_decompress(str(file), remove_finished=True))
 
-    def _default_loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
+    def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
         if path.is_dir():
             return FileOpener(FileLister(str(path), recursive=True), mode="rb")
@@ -101,7 +94,7 @@ def load(
             path_candidates = {self.download(root, skip_integrity_check=skip_integrity_check)}
         # If the only thing we find is the raw file, we use it and optionally perform some preprocessing steps.
         if path_candidates == {path}:
-            if self._preprocess:
+            if self._preprocess is not None:
                 path = self._preprocess(path)
         # Otherwise we use the path with the fewest suffixes. This gives us the extracted > decompressed > raw priority
         # that we want.
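The comment above encodes that priority as a suffix count: an extracted folder has no suffixes, a decompressed archive has one, and the raw download has the most. A small sketch of the heuristic (file names are hypothetical):

    import pathlib

    candidates = [
        pathlib.Path("imagenet.tar.gz"),  # raw download: 2 suffixes
        pathlib.Path("imagenet.tar"),     # decompressed: 1 suffix
        pathlib.Path("imagenet"),         # extracted folder: 0 suffixes
    ]
    # Fewest suffixes wins, yielding the extracted > decompressed > raw order.
    best = min(candidates, key=lambda path: len(path.suffixes))
    print(best)  # imagenet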