Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import bz2
import contextlib
import gzip
import hashlib
import itertools
Expand Down Expand Up @@ -262,6 +263,26 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[

_save_response_content(content, fpath)

# In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
if os.stat(fpath).st_size < 10 * 1024:
with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
text = fh.read()
# Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
warnings.warn(
f"We detected some HTML elements in the downloaded file. "
f"This most likely means that the download triggered an unhandled API response by GDrive. "
f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
f"the response:\n\n{text}"
)

if md5 and not check_md5(fpath, md5):
raise RuntimeError(
f"The MD5 checksum of the download file {fpath} does not match the one on record."
f"Please delete the file and try again. "
f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
)


def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar:
Expand Down