Skip to content

Commit c32d5ac

Browse files
committed
Add support for files with periods in name
1 parent 183a722 commit c32d5ac

File tree

2 files changed

+26
-37
lines changed

2 files changed

+26
-37
lines changed

test/test_datasets_utils.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,9 @@ def test_detect_file_type(self):
5858
("foo.gz", (".gz", None, ".gz")),
5959
("foo.zip", (".zip", ".zip", None)),
6060
("foo.xz", (".xz", None, ".xz")),
61+
("foo.bar.tar.gz", (".tar.gz", ".tar", ".gz")),
62+
("foo.bar.gz", (".gz", None, ".gz")),
63+
("foo.bar.zip", (".zip", ".zip", None)),
6164
]:
6265
with self.subTest(file=file):
6366
self.assertSequenceEqual(utils._detect_file_type(file), expected)
@@ -66,14 +69,6 @@ def test_detect_file_type_no_ext(self):
6669
with self.assertRaises(RuntimeError):
6770
utils._detect_file_type("foo")
6871

69-
def test_detect_file_type_to_many_exts(self):
70-
with self.assertRaises(RuntimeError):
71-
utils._detect_file_type("foo.bar.tar.gz")
72-
73-
def test_detect_file_type_unknown_archive_type(self):
74-
with self.assertRaises(RuntimeError):
75-
utils._detect_file_type("foo.bar.gz")
76-
7772
def test_detect_file_type_unknown_compression(self):
7873
with self.assertRaises(RuntimeError):
7974
utils._detect_file_type("foo.tar.baz")

torchvision/datasets/utils.py

Lines changed: 23 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -281,53 +281,47 @@ def _extract_zip(from_path: str, to_path: str, compression: Optional[str]) -> No
281281
_FILE_TYPE_ALIASES: Dict[str, Tuple[Optional[str], Optional[str]]] = {".tgz": (".tar", ".gz")}
282282

283283

284-
def _verify_archive_type(archive_type: str) -> None:
285-
if archive_type not in _ARCHIVE_EXTRACTORS.keys():
286-
valid_types = "', '".join(_ARCHIVE_EXTRACTORS.keys())
287-
raise RuntimeError(f"Unknown archive type '{archive_type}'. Known archive types are '{valid_types}'.")
288-
284+
def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
285+
"""Detect the archive type and/or compression of a file.
289286
290-
def _verify_compression(compression: str) -> None:
291-
if compression not in _COMPRESSED_FILE_OPENERS.keys():
292-
valid_types = "', '".join(_COMPRESSED_FILE_OPENERS.keys())
293-
raise RuntimeError(f"Unknown compression '{compression}'. Known compressions are '{valid_types}'.")
287+
Args:
288+
file (str): the filename
294289
290+
Returns:
291+
(tuple): tuple of suffix, archive type, and compression
295292
296-
def _detect_file_type(file: str) -> Tuple[str, Optional[str], Optional[str]]:
297-
path = pathlib.Path(file)
298-
suffix = path.suffix
293+
Raises:
294+
RuntimeError: if file has no suffix or suffix is not supported
295+
"""
299296
suffixes = pathlib.Path(file).suffixes
300297
if not suffixes:
301298
raise RuntimeError(
302299
f"File '{file}' has no suffixes that could be used to detect the archive type and compression."
303300
)
304-
elif len(suffixes) > 2:
305-
raise RuntimeError(
306-
"Archive type and compression detection only works for 1 or 2 suffixes. " f"Got {len(suffixes)} instead."
307-
)
308-
elif len(suffixes) == 2:
309-
# if we have exactly two suffixes we assume the first one is the archive type and the second on is the
310-
# compression
311-
archive_type, compression = suffixes
312-
_verify_archive_type(archive_type)
313-
_verify_compression(compression)
314-
return "".join(suffixes), archive_type, compression
301+
suffix = suffixes[-1]
315302

316303
# check if the suffix is a known alias
317-
with contextlib.suppress(KeyError):
304+
if suffix in _FILE_TYPE_ALIASES:
318305
return (suffix, *_FILE_TYPE_ALIASES[suffix])
319306

320307
# check if the suffix is an archive type
321-
with contextlib.suppress(RuntimeError):
322-
_verify_archive_type(suffix)
308+
if suffix in _ARCHIVE_EXTRACTORS:
323309
return suffix, suffix, None
324310

325311
# check if the suffix is a compression
326-
with contextlib.suppress(RuntimeError):
327-
_verify_compression(suffix)
312+
if suffix in _COMPRESSED_FILE_OPENERS:
313+
# check for suffix hierarchy
314+
if len(suffixes) > 1:
315+
suffix2 = suffixes[-2]
316+
317+
# check if the suffix2 is an archive type
318+
if suffix2 in _ARCHIVE_EXTRACTORS:
319+
return suffix2 + suffix, suffix2, suffix
320+
328321
return suffix, None, suffix
329322

330-
raise RuntimeError(f"Suffix '{suffix}' is neither recognized as archive type nor as compression.")
323+
valid_suffixes = set(_FILE_TYPE_ALIASES) | set(_ARCHIVE_EXTRACTORS) | set(_COMPRESSED_FILE_OPENERS)
324+
raise RuntimeError(f"Unknown compression or archive type: '{suffix}'. Known suffixes are: '{valid_suffixes}'.")
331325

332326

333327
def _decompress(from_path: str, to_path: Optional[str] = None, remove_finished: bool = False) -> str:

0 commit comments

Comments
 (0)