Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 3 additions & 8 deletions test/test_datasets_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def test_detect_file_type(self):
("foo.gz", (".gz", None, ".gz")),
("foo.zip", (".zip", ".zip", None)),
("foo.xz", (".xz", None, ".xz")),
("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
("foo.bar.gz", (".gz", None, ".gz")),
("foo.bar.zip", (".zip", ".zip", None)),
]:
with self.subTest(file=file):
self.assertSequenceEqual(utils._detect_file_type(file), expected)
Expand All @@ -71,14 +74,6 @@ def test_detect_file_type_no_ext(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo")

def test_detect_file_type_to_many_exts(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo.bar.tar.gz")

def test_detect_file_type_unknown_archive_type(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo.bar.gz")

def test_detect_file_type_unknown_compression(self):
with self.assertRaises(RuntimeError):
utils._detect_file_type("foo.tar.baz")
Expand Down
52 changes: 23 additions & 29 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,53 +291,47 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> No
}


def _verify_archive_type(archive_type: str) -> None:
if archive_type not in _ARCHIVE_EXTRACTORS.keys():
valid_types = "', '".join(_ARCHIVE_EXTRACTORS.keys())
raise RuntimeError(f"Unknown archive type '{archive_type}'. Known archive types are '{valid_types}'.")

def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
"""Detect the archive type and/or compression of a file.

def _verify_compression(compression: str) -> None:
if compression not in _COMPRESSED_FILE_OPENERS.keys():
valid_types = "', '".join(_COMPRESSED_FILE_OPENERS.keys())
raise RuntimeError(f"Unknown compression '{compression}'. Known compressions are '{valid_types}'.")
Args:
file (str): the filename

Returns:
(tuple): tuple of suffix, archive type, and compression

def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
path = pathlib.Path(file)
suffix = path.suffix
Raises:
RuntimeError: if file has no suffix or suffix is not supported
"""
suffixes = pathlib.Path(file).suffixes
if not suffixes:
raise RuntimeError(
f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
)
elif len(suffixes) > 2:
raise RuntimeError(
"Archive type and compression detection only works for 1 or 2 suffixes. " f"Got {len(suffixes)} instead."
)
elif len(suffixes) == 2:
# if we have exactly two suffixes we assume the first one is the archive type and the second on is the
# compression
archive_type, compression = suffixes
_verify_archive_type(archive_type)
_verify_compression(compression)
return "".join(suffixes), archive_type, compression
suffix = suffixes[-1]

# check if the suffix is a known alias
with contextlib.suppress(KeyError):
if suffix in _FILE_TYPE_ALIASES:
return (suffix, *_FILE_TYPE_ALIASES[suffix])

# check if the suffix is an archive type
with contextlib.suppress(RuntimeError):
_verify_archive_type(suffix)
if suffix in _ARCHIVE_EXTRACTORS:
return suffix, suffix, None

# check if the suffix is a compression
with contextlib.suppress(RuntimeError):
_verify_compression(suffix)
if suffix in _COMPRESSED_FILE_OPENERS:
# check for suffix hierarchy
if len(suffixes) > 1:
suffix2 = suffixes[-2]

# check if the suffix2 is an archive type
if suffix2 in _ARCHIVE_EXTRACTORS:
return suffix2 + suffix, suffix2, suffix

return suffix, None, suffix

raise RuntimeError(f"Suffix '{suffix}' is neither recognized as archive type nor as compression.")
valid_suffixes = sorted(set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS))
raise RuntimeError(f"Unknown compression or archive type: '{suffix}'.\nKnown suffixes are: '{valid_suffixes}'.")


def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:
Expand Down