From 3e7141cc7f79463e3946858f09da53296bf458de Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Tue, 25 Jan 2022 17:28:53 +0000
Subject: [PATCH 1/8] Some Qs

---
 torchvision/prototype/datasets/_api.py       |  6 +++++-
 torchvision/prototype/datasets/_folder.py    |  4 +++-
 .../prototype/datasets/utils/_dataset.py     | 18 +++++++++++++++---
 .../prototype/datasets/utils/_internal.py    |  3 ++-
 .../prototype/datasets/utils/_resource.py    |  7 ++++++-
 5 files changed, 31 insertions(+), 7 deletions(-)

diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py
index e9240eb46ce..3f9c7014c6b 100644
--- a/torchvision/prototype/datasets/_api.py
+++ b/torchvision/prototype/datasets/_api.py
@@ -18,16 +18,19 @@ def register(dataset: Dataset) -> None:
     DATASETS[dataset.name] = dataset
 
 
+## Should we manually register datasets instead of relying on the content of _builtin/__init__.py?
 for name, obj in _builtin.__dict__.items():
     if not name.startswith("_") and isinstance(obj, type) and issubclass(obj, Dataset) and obj is not Dataset:
         register(obj())
 
 
 # This is exposed as 'list', but we avoid that here to not shadow the built-in 'list'
+## Maybe we could call it list_available_datasets()?
 def _list() -> List[str]:
     return sorted(DATASETS.keys())
 
 
+## Does this need to be public? Looks like users would only need load() and info()
 def find(name: str) -> Dataset:
     name = name.lower()
     try:
@@ -57,11 +60,12 @@ def info(name: str) -> DatasetInfo:
     }
 
 
+## Should DEFAULT_DECODER just be None? Or "auto"? Or "default"?
 def load(
     name: str,
     *,
     decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = DEFAULT_DECODER,  # type: ignore[assignment]
-    skip_integrity_check: bool = False,
+    skip_integrity_check: bool = False, ## When do we need it to be True?
     **options: Any,
 ) -> IterDataPipe[Dict[str, Any]]:
     dataset = find(name)
diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py
index fbca8b07b1a..d80129acf5f 100644
--- a/torchvision/prototype/datasets/_folder.py
+++ b/torchvision/prototype/datasets/_folder.py
@@ -12,6 +12,8 @@
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding
 
 
+## We don't seem to use these anywhere for now.
+## If we keep them, should we add tests?
 __all__ = ["from_data_folder", "from_image_folder"]
 
 
@@ -52,7 +54,7 @@ def from_data_folder(
     dp = FileLister(str(root), recursive=recursive, masks=masks)
     dp: IterDataPipe = Filter(dp, functools.partial(_is_not_top_level_file, root=root))
     dp = hint_sharding(dp)
-    dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
+    dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) ## Should this be hint_shuffling()?
     dp = FileOpener(dp, mode="rb")
     return (
         Mapper(dp, functools.partial(_collate_and_decode_data, root=root, categories=categories, decoder=decoder)),
diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py
index 38c991fe7a1..488b5e6ba05 100644
--- a/torchvision/prototype/datasets/utils/_dataset.py
+++ b/torchvision/prototype/datasets/utils/_dataset.py
@@ -19,11 +19,13 @@
 
 
 class DatasetType(enum.Enum):
+    ## Looks like this is only used to determine the decoder to use.
+    ## Should we then have a more direct way of setting the default decoder?
     RAW = enum.auto()
     IMAGE = enum.auto()
 
 
-class DatasetConfig(FrozenBunch):
+class DatasetConfig(FrozenBunch): ## Do we need this to be a FrozenBunch?
     pass
 
 
@@ -34,7 +36,7 @@ def __init__(
         *,
         type: Union[str, DatasetType],
         dependencies: Collection[str] = (),
-        categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None,
+        categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None, ## Nit: do we need to allow Paths?
         citation: Optional[str] = None,
         homepage: Optional[str] = None,
         license: Optional[str] = None,
@@ -60,12 +62,15 @@ def __init__(
         self.homepage = homepage
         self.license = license
 
+        ## What about mutually exclusive parameters?
+        ## How are we going to document the dataset parameters?
         self._valid_options = valid_options or dict()
         self._configs = tuple(
             DatasetConfig(**dict(zip(self._valid_options.keys(), combination)))
             for combination in itertools.product(*self._valid_options.values())
         )
 
+        ## What is this?
         self.extra = FrozenBunch(extra or dict())
 
     @property
@@ -75,7 +80,7 @@ def default_config(self) -> DatasetConfig:
     @staticmethod
     def read_categories_file(path: pathlib.Path) -> List[List[str]]:
         with open(path, newline="") as file:
-            return [row for row in csv.reader(file)]
+            return [row for row in csv.reader(file)] ## Nit: do we need a csv reader?
 
     def make_config(self, **options: Any) -> DatasetConfig:
         if not self._valid_options and options:
@@ -135,6 +140,9 @@ def __init__(self) -> None:
     def _make_info(self) -> DatasetInfo:
         pass
 
+    ## Could these @properties be simple attributes?
+    ## Also, do we actually need to expose these attributes in the Dataset class, or could we just access dataset.info.the_attr?
+    ## Dumb idea: what if we get rid of the `info` attribute and put everything in the Dataset class namespace?
     @property
     def info(self) -> DatasetInfo:
         return self._info
@@ -151,6 +159,7 @@ def default_config(self) -> DatasetConfig:
     def categories(self) -> Tuple[str, ...]:
         return self.info.categories
 
+    ## Q: why is resources() "public" while _make_datapipe() isn't?
     @abc.abstractmethod
     def resources(self, config: DatasetConfig) -> List[OnlineResource]:
         pass
@@ -165,6 +174,7 @@ def _make_datapipe(
     ) -> IterDataPipe[Dict[str, Any]]:
         pass
 
+    ## Is this related to Manifold support or does this have OSS user-facing impact?
     def supports_sharded(self) -> bool:
         return False
 
@@ -179,12 +189,14 @@ def load(
         if not config:
             config = self.info.default_config
 
+        ## TODO: need to understand this bit
         if use_sharded_dataset() and self.supports_sharded():
            root = os.path.join(root, *config.values())
            dataset_size = self.info.extra["sizes"][config]
            return _make_sharded_datapipe(root, dataset_size)  # type: ignore[no-any-return]
 
         self.info.check_dependencies()
+        ## Nit: a lot of methods are called load() and they all do somewhat different things on different objects
         resource_dps = [
             resource.load(root, skip_integrity_check=skip_integrity_check) for resource in self.resources(config)
         ]
diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py
index e21e8ffd25f..a304bea8c0b 100644
--- a/torchvision/prototype/datasets/utils/_internal.py
+++ b/torchvision/prototype/datasets/utils/_internal.py
@@ -80,7 +80,7 @@ def image_buffer_from_array(array: np.ndarray, *, format: str = "png") -> io.Byt
     return buffer
 
 
-class SequenceIterator(IterDataPipe[D]):
+class SequenceIterator(IterDataPipe[D]): ## Is this used?
     def __init__(self, datapipe: IterDataPipe[Sequence[D]]):
         self.datapipe = datapipe
 
@@ -152,6 +152,7 @@ class CompressionType(enum.Enum):
     LZMA = "lzma"
 
 
+## Is this somewhat redundant with the resources preprocessing logic?
 class Decompressor(IterDataPipe[Tuple[str, io.IOBase]]):
     types = CompressionType
 
diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py
index 2fc6bbca49e..54c304fe93c 100644
--- a/torchvision/prototype/datasets/utils/_resource.py
+++ b/torchvision/prototype/datasets/utils/_resource.py
@@ -24,6 +24,7 @@
 )
 
 
+## We don't seem to use extract, preprocess, or loader params. Maybe we can remove them for now?
 class OnlineResource(abc.ABC):
     def __init__(
         self,
@@ -31,7 +32,7 @@ def __init__(
         file_name: str,
         sha256: Optional[str] = None,
         decompress: bool = False,
-        extract: bool = False,
+        extract: bool = False, ## Do we ever want to decompress without extracting?
         preprocess: Optional[Callable[[pathlib.Path], pathlib.Path]] = None,
         loader: Optional[Callable[[pathlib.Path], IterDataPipe[Tuple[str, IO]]]] = None,
     ) -> None:
@@ -72,6 +73,7 @@ def _default_loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
 
         return dp
 
+    ## Overall, what's the plan regarding allowing users to implement their own datasets (and their own loaders etc.)?
     _ARCHIVE_LOADERS = {
         ".tar": TarArchiveReader,
         ".zip": ZipArchiveReader,
@@ -83,6 +85,8 @@ def _guess_archive_loader(
     ) -> Optional[Callable[[IterDataPipe[Tuple[str, IO]]], IterDataPipe[Tuple[str, IO]]]]:
         try:
             _, archive_type, _ = _detect_file_type(path.name)
+            ## Unrelated but it looks like it could be a ValueError. Or instead of
+            ## raising, _detect_file_type could just return an empty tuple or None?
         except RuntimeError:
             return None
         return self._ARCHIVE_LOADERS.get(archive_type)  # type: ignore[arg-type]
@@ -106,6 +110,7 @@ def load(
         # Otherwise we use the path with the fewest suffixes. This gives us the extracted > decompressed > raw priority
         # that we want.
         else:
+            ## Are there current examples where we reach this line?

From b9c64ab63ce1afe58ded4df9dcd37059a22a1b77 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 26 Jan 2022 12:52:11 +0000
Subject: [PATCH 2/8] Some modifications

---
 torchvision/prototype/datasets/__init__.py   |  2 +-
 torchvision/prototype/datasets/_api.py       |  9 ++-----
 torchvision/prototype/datasets/_folder.py    |  4 +--
 .../prototype/datasets/utils/_dataset.py     | 20 ++++----------
 .../prototype/datasets/utils/_internal.py    | 11 --------
 .../prototype/datasets/utils/_resource.py    | 26 ++++++-------------
 6 files changed, 17 insertions(+), 55 deletions(-)

diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py
index 1945b5a5d9e..ccfe184b140 100644
--- a/torchvision/prototype/datasets/__init__.py
+++ b/torchvision/prototype/datasets/__init__.py
@@ -11,5 +11,5 @@
 from ._home import home
 
 # Load this last, since some parts depend on the above being loaded first
-from ._api import register, _list as list, info, load, find  # usort: skip
+from ._api import register, list_names, info, load, find  # usort: skip
 from ._folder import from_data_folder, from_image_folder
diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py
index 3f9c7014c6b..f4910fc02ec 100644
--- a/torchvision/prototype/datasets/_api.py
+++ b/torchvision/prototype/datasets/_api.py
@@ -18,19 +18,15 @@ def register(dataset: Dataset) -> None:
     DATASETS[dataset.name] = dataset
 
 
-## Should we manually register datasets instead of relying on the content of _builtin/__init__.py?
 for name, obj in _builtin.__dict__.items():
     if not name.startswith("_") and isinstance(obj, type) and issubclass(obj, Dataset) and obj is not Dataset:
         register(obj())
 
 
-# This is exposed as 'list', but we avoid that here to not shadow the built-in 'list'
-## Maybe we could call it list_available_datasets()?
-def _list() -> List[str]:
+def list_names() -> List[str]:
     return sorted(DATASETS.keys())
 
 
-## Does this need to be public? Looks like users would only need load() and info()
 def find(name: str) -> Dataset:
     name = name.lower()
     try:
@@ -60,12 +56,11 @@ def info(name: str) -> DatasetInfo:
     }
 
 
-## Should DEFAULT_DECODER just be None? Or "auto"? Or "default"?
 def load(
     name: str,
     *,
     decoder: Optional[Callable[[io.IOBase], torch.Tensor]] = DEFAULT_DECODER,  # type: ignore[assignment]
-    skip_integrity_check: bool = False, ## When do we need it to be True?
+    skip_integrity_check: bool = False,
     **options: Any,
 ) -> IterDataPipe[Dict[str, Any]]:
     dataset = find(name)
diff --git a/torchvision/prototype/datasets/_folder.py b/torchvision/prototype/datasets/_folder.py
index d80129acf5f..fbca8b07b1a 100644
--- a/torchvision/prototype/datasets/_folder.py
+++ b/torchvision/prototype/datasets/_folder.py
@@ -12,8 +12,6 @@
 from torchvision.prototype.datasets.utils._internal import INFINITE_BUFFER_SIZE, hint_sharding
 
 
-## We don't seem to use these anywhere for now.
-## If we keep them, should we add tests?
 __all__ = ["from_data_folder", "from_image_folder"]
 
 
@@ -54,7 +52,7 @@ def from_data_folder(
     dp = FileLister(str(root), recursive=recursive, masks=masks)
     dp: IterDataPipe = Filter(dp, functools.partial(_is_not_top_level_file, root=root))
     dp = hint_sharding(dp)
-    dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE) ## Should this be hint_shuffling()?
+    dp = Shuffler(dp, buffer_size=INFINITE_BUFFER_SIZE)
     dp = FileOpener(dp, mode="rb")
     return (
         Mapper(dp, functools.partial(_collate_and_decode_data, root=root, categories=categories, decoder=decoder)),
diff --git a/torchvision/prototype/datasets/utils/_dataset.py b/torchvision/prototype/datasets/utils/_dataset.py
index 488b5e6ba05..04fb4312728 100644
--- a/torchvision/prototype/datasets/utils/_dataset.py
+++ b/torchvision/prototype/datasets/utils/_dataset.py
@@ -19,13 +19,13 @@
 
 
 class DatasetType(enum.Enum):
-    ## Looks like this is only used to determine the decoder to use.
-    ## Should we then have a more direct way of setting the default decoder?
     RAW = enum.auto()
     IMAGE = enum.auto()
 
 
-class DatasetConfig(FrozenBunch): ## Do we need this to be a FrozenBunch?
+class DatasetConfig(FrozenBunch):
+    # This needs to be Frozen because we often pass configs as partial(func, config=config)
+    # and partial() requires the parameters to be hashable.
     pass
 
 
@@ -36,7 +36,7 @@ def __init__(
         *,
         type: Union[str, DatasetType],
         dependencies: Collection[str] = (),
-        categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None, ## Nit: do we need to allow Paths?
+        categories: Optional[Union[int, Sequence[str], str, pathlib.Path]] = None,
         citation: Optional[str] = None,
         homepage: Optional[str] = None,
         license: Optional[str] = None,
@@ -62,15 +62,12 @@ def __init__(
         self.homepage = homepage
         self.license = license
 
-        ## What about mutually exclusive parameters?
-        ## How are we going to document the dataset parameters?
         self._valid_options = valid_options or dict()
         self._configs = tuple(
             DatasetConfig(**dict(zip(self._valid_options.keys(), combination)))
             for combination in itertools.product(*self._valid_options.values())
         )
 
-        ## What is this?
         self.extra = FrozenBunch(extra or dict())
 
     @property
@@ -80,7 +77,7 @@ def default_config(self) -> DatasetConfig:
     @staticmethod
     def read_categories_file(path: pathlib.Path) -> List[List[str]]:
         with open(path, newline="") as file:
-            return [row for row in csv.reader(file)] ## Nit: do we need a csv reader?
+            return [row for row in csv.reader(file)]
 
     def make_config(self, **options: Any) -> DatasetConfig:
         if not self._valid_options and options:
@@ -140,9 +137,6 @@ def __init__(self) -> None:
     def _make_info(self) -> DatasetInfo:
         pass
 
-    ## Could these @properties be simple attributes?
-    ## Also, do we actually need to expose these attributes in the Dataset class, or could we just access dataset.info.the_attr?
-    ## Dumb idea: what if we get rid of the `info` attribute and put everything in the Dataset class namespace?
     @property
     def info(self) -> DatasetInfo:
         return self._info
@@ -159,7 +153,6 @@ def default_config(self) -> DatasetConfig:
     def categories(self) -> Tuple[str, ...]:
         return self.info.categories
 
-    ## Q: why is resources() "public" while _make_datapipe() isn't?
     @abc.abstractmethod
     def resources(self, config: DatasetConfig) -> List[OnlineResource]:
         pass
@@ -174,7 +167,6 @@ def _make_datapipe(
     ) -> IterDataPipe[Dict[str, Any]]:
         pass
 
-    ## Is this related to Manifold support or does this have OSS user-facing impact?
     def supports_sharded(self) -> bool:
         return False
 
@@ -189,14 +181,12 @@ def load(
         if not config:
             config = self.info.default_config
 
-        ## TODO: need to understand this bit
         if use_sharded_dataset() and self.supports_sharded():
            root = os.path.join(root, *config.values())
            dataset_size = self.info.extra["sizes"][config]
            return _make_sharded_datapipe(root, dataset_size)  # type: ignore[no-any-return]
 
         self.info.check_dependencies()
-        ## Nit: a lot of methods are called load() and they all do somewhat different things on different objects
         resource_dps = [
             resource.load(root, skip_integrity_check=skip_integrity_check) for resource in self.resources(config)
         ]
diff --git a/torchvision/prototype/datasets/utils/_internal.py b/torchvision/prototype/datasets/utils/_internal.py
index a304bea8c0b..1b437d50b85 100644
--- a/torchvision/prototype/datasets/utils/_internal.py
+++ b/torchvision/prototype/datasets/utils/_internal.py
@@ -39,7 +39,6 @@
     "BUILTIN_DIR",
     "read_mat",
     "image_buffer_from_array",
-    "SequenceIterator",
     "MappingIterator",
     "Enumerator",
     "getitem",
@@ -80,15 +79,6 @@ def image_buffer_from_array(array: np.ndarray, *, format: str = "png") -> io.Byt
     return buffer
 
 
-class SequenceIterator(IterDataPipe[D]): ## Is this used?
-    def __init__(self, datapipe: IterDataPipe[Sequence[D]]):
-        self.datapipe = datapipe
-
-    def __iter__(self) -> Iterator[D]:
-        for sequence in self.datapipe:
-            yield from iter(sequence)
-
-
 class MappingIterator(IterDataPipe[Union[Tuple[K, D], D]]):
     def __init__(self, datapipe: IterDataPipe[Dict[K, D]], *, drop_key: bool = False) -> None:
         self.datapipe = datapipe
@@ -152,7 +142,6 @@ class CompressionType(enum.Enum):
     LZMA = "lzma"
 
 
-## Is this somewhat redundant with the resources preprocessing logic?
 class Decompressor(IterDataPipe[Tuple[str, io.IOBase]]):
     types = CompressionType
 
diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py
index 54c304fe93c..06eb465a714 100644
--- a/torchvision/prototype/datasets/utils/_resource.py
+++ b/torchvision/prototype/datasets/utils/_resource.py
@@ -24,7 +24,6 @@
 )
 
 
-## We don't seem to use extract, preprocess, or loader params. Maybe we can remove them for now?
 class OnlineResource(abc.ABC):
     def __init__(
         self,
@@ -32,24 +31,19 @@ def __init__(
         file_name: str,
         sha256: Optional[str] = None,
         decompress: bool = False,
-        extract: bool = False, ## Do we ever want to decompress without extracting?
-        preprocess: Optional[Callable[[pathlib.Path], pathlib.Path]] = None,
-        loader: Optional[Callable[[pathlib.Path], IterDataPipe[Tuple[str, IO]]]] = None,
+        extract: bool = False,
     ) -> None:
         self.file_name = file_name
         self.sha256 = sha256
 
-        if preprocess and (decompress or extract):
-            warnings.warn("The parameters 'decompress' and 'extract' are ignored when 'preprocess' is passed.")
-        elif extract:
-            preprocess = self._extract
+        if extract:
+            self._preprocess = self._extract
         elif decompress:
-            preprocess = self._decompress
-        self._preprocess = preprocess
+            self._preprocess = self._decompress
+        else:
+            self._preprocess = None
 
-        if loader is None:
-            loader = self._default_loader
-        self._loader = loader
+        self._loader = self._default_loader
 
     @staticmethod
     def _extract(file: pathlib.Path) -> pathlib.Path:
@@ -73,7 +67,6 @@ def _default_loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
 
         return dp
 
-    ## Overall, what's the plan regarding allowing users to implement their own datasets (and their own loaders etc.)?
     _ARCHIVE_LOADERS = {
         ".tar": TarArchiveReader,
         ".zip": ZipArchiveReader,
@@ -85,8 +78,6 @@ def _guess_archive_loader(
     ) -> Optional[Callable[[IterDataPipe[Tuple[str, IO]]], IterDataPipe[Tuple[str, IO]]]]:
         try:
             _, archive_type, _ = _detect_file_type(path.name)
-            ## Unrelated but it looks like it could be a ValueError. Or instead of
-            ## raising, _detect_file_type could just return an empty tuple or None?
         except RuntimeError:
             return None
         return self._ARCHIVE_LOADERS.get(archive_type)  # type: ignore[arg-type]
@@ -105,12 +96,11 @@ def load(
         path_candidates = {self.download(root, skip_integrity_check=skip_integrity_check)}
         # If the only thing we find is the raw file, we use it and optionally perform some preprocessing steps.
         if path_candidates == {path}:
-            if self._preprocess:
+            if self._preprocess is not None:
                 path = self._preprocess(path)
         # Otherwise we use the path with the fewest suffixes. This gives us the extracted > decompressed > raw priority
         # that we want.
         else:
-            ## Are there current examples where we reach this line?
             path = min(path_candidates, key=lambda path: len(path.suffixes))
 
         return self._loader(path)

From 8b2fbeb726e3226f34afd3208d2abcb7cfe4d377 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 26 Jan 2022 13:12:09 +0000
Subject: [PATCH 3/8] don't need _loader in __init__

---
 torchvision/prototype/datasets/utils/_resource.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py
index 06eb465a714..c5e536f7f2b 100644
--- a/torchvision/prototype/datasets/utils/_resource.py
+++ b/torchvision/prototype/datasets/utils/_resource.py
@@ -43,8 +43,6 @@ def __init__(
         else:
             self._preprocess = None
 
-        self._loader = self._default_loader
-
     @staticmethod
     def _extract(file: pathlib.Path) -> pathlib.Path:
         return pathlib.Path(
@@ -55,7 +53,7 @@ def _extract(file: pathlib.Path) -> pathlib.Path:
     def _decompress(file: pathlib.Path) -> pathlib.Path:
         return pathlib.Path(_decompress(str(file), remove_finished=True))
 
-    def _default_loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
+    def _loader(self, path: pathlib.Path) -> IterDataPipe[Tuple[str, IO]]:
         if path.is_dir():
             return FileOpener(FileLister(str(path), recursive=True), mode="rb")

From 1f2966fc99c481d558cd27b5e7b3ead4a711b8bb Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 27 Jan 2022 09:26:04 +0000
Subject: [PATCH 4/8] list_names -> list_datasets

---
 torchvision/prototype/datasets/__init__.py | 2 +-
 torchvision/prototype/datasets/_api.py     | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchvision/prototype/datasets/__init__.py b/torchvision/prototype/datasets/__init__.py
index ccfe184b140..48bae0b65f5 100644
--- a/torchvision/prototype/datasets/__init__.py
+++ b/torchvision/prototype/datasets/__init__.py
@@ -11,5 +11,5 @@
 from ._home import home
 
 # Load this last, since some parts depend on the above being loaded first
-from ._api import register, list_names, info, load, find  # usort: skip
+from ._api import register, list_datasets, info, load, find  # usort: skip
 from ._folder import from_data_folder, from_image_folder
diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py
index f4910fc02ec..435fd293fd9 100644
--- a/torchvision/prototype/datasets/_api.py
+++ b/torchvision/prototype/datasets/_api.py
@@ -23,7 +23,7 @@ def register(dataset: Dataset) -> None:
         register(obj())
 
 
-def list_names() -> List[str]:
+def list_datasets() -> List[str]:
     return sorted(DATASETS.keys())
 
 

From 73e65e3dd6afebdb058c66733e904aa3b19f1c98 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 27 Jan 2022 09:43:23 +0000
Subject: [PATCH 5/8] Update torchvision/prototype/datasets/utils/_resource.py

Co-authored-by: Philip Meier
---
 torchvision/prototype/datasets/utils/_resource.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py
index c5e536f7f2b..f5195f32938 100644
--- a/torchvision/prototype/datasets/utils/_resource.py
+++ b/torchvision/prototype/datasets/utils/_resource.py
@@ -36,6 +36,7 @@ def __init__(
         self.file_name = file_name
         self.sha256 = sha256
 
+        self._preprocess: Optional[Callable[[pathlib.Path], pathlib.Path]]
         if extract:
             self._preprocess = self._extract
         elif decompress:

From bc3681928599be3e17f35da6225e4c247abd698b Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 27 Jan 2022 10:03:45 +0000
Subject: [PATCH 6/8] Remove unused import

---
 torchvision/prototype/datasets/utils/_resource.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/torchvision/prototype/datasets/utils/_resource.py b/torchvision/prototype/datasets/utils/_resource.py
index c5e536f7f2b..93d44cc80ba 100644
--- a/torchvision/prototype/datasets/utils/_resource.py
+++ b/torchvision/prototype/datasets/utils/_resource.py
@@ -2,7 +2,6 @@
 import hashlib
 import itertools
 import pathlib
-import warnings
 from typing import Optional, Sequence, Tuple, Callable, IO, Any, Union, NoReturn
 from urllib.parse import urlparse

From adfbb5b21466c40153501ec2d01fb374eb0bf791 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 27 Jan 2022 13:26:44 +0000
Subject: [PATCH 7/8] fix tests

---
 test/test_prototype_builtin_datasets.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_prototype_builtin_datasets.py b/test/test_prototype_builtin_datasets.py
index 7b6a22600e1..d8e07314e00 100644
--- a/test/test_prototype_builtin_datasets.py
+++ b/test/test_prototype_builtin_datasets.py
@@ -18,7 +18,7 @@ def test_home(mocker, tmp_path):
 
 
 def test_coverage():
-    untested_datasets = set(datasets.list()) - DATASET_MOCKS.keys()
+    untested_datasets = set(datasets.list_datasets()) - DATASET_MOCKS.keys()
     if untested_datasets:
         raise AssertionError(
             f"The dataset(s) {sequence_to_str(sorted(untested_datasets), separate_last='and ')} "

From 98014e95895107e245ced324323ca5dde505df11 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Thu, 27 Jan 2022 13:27:27 +0000
Subject: [PATCH 8/8] Some missing renames

---
 torchvision/prototype/datasets/_api.py                    | 2 +-
 torchvision/prototype/datasets/generate_category_files.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchvision/prototype/datasets/_api.py b/torchvision/prototype/datasets/_api.py
index 435fd293fd9..f3c398d5552 100644
--- a/torchvision/prototype/datasets/_api.py
+++ b/torchvision/prototype/datasets/_api.py
@@ -38,7 +38,7 @@ def find(name: str) -> Dataset:
                 word=name,
                 possibilities=DATASETS.keys(),
                 alternative_hint=lambda _: (
-                    "You can use torchvision.datasets.list() to get a list of all available datasets."
+                    "You can use torchvision.datasets.list_datasets() to get a list of all available datasets."
                 ),
             )
         ) from error
diff --git a/torchvision/prototype/datasets/generate_category_files.py b/torchvision/prototype/datasets/generate_category_files.py
index 40843ecf50b..3c2bf7e73cb 100644
--- a/torchvision/prototype/datasets/generate_category_files.py
+++ b/torchvision/prototype/datasets/generate_category_files.py
@@ -49,7 +49,7 @@ def parse_args(argv=None):
     args = parser.parse_args(argv or sys.argv[1:])
 
     if not args.names:
-        args.names = datasets.list()
+        args.names = datasets.list_datasets()
 
     return args
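
A minimal usage sketch of how the renamed API reads after this series. This is an illustration only, not part of the patches: the dataset name "mnist", the `split` option, and the sample keys are assumptions that were not verified against this branch.

    from torchvision.prototype import datasets

    # The listing function is now list_datasets(); the old spelling
    # `datasets.list` shadowed the built-in list().
    print(datasets.list_datasets())

    # info() returns the DatasetInfo of a registered dataset without loading it.
    print(datasets.info("mnist").categories)  # "mnist" is an assumed example name

    # load() returns an IterDataPipe of sample dicts; extra keyword arguments
    # such as `split` are forwarded to the dataset config via **options.
    for sample in datasets.load("mnist", split="train"):
        print(sorted(sample.keys()))  # available keys depend on the dataset
        break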