diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py
index 7fe80b8ab56..af6e1a972c2 100644
--- a/torchvision/datasets/utils.py
+++ b/torchvision/datasets/utils.py
@@ -1,4 +1,5 @@
 import bz2
+import contextlib
 import gzip
 import hashlib
 import itertools
@@ -262,6 +263,26 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
 
         _save_response_content(content, fpath)
 
+    # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
+    if os.stat(fpath).st_size < 10 * 1024:
+        with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
+            text = fh.read()
+            # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
+            if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
+                warnings.warn(
+                    f"We detected some HTML elements in the downloaded file. "
+                    f"This most likely means that the download triggered an unhandled API response by GDrive. "
+                    f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
+                    f"the response:\n\n{text}"
+                )
+
+    if md5 and not check_md5(fpath, md5):
+        raise RuntimeError(
+            f"The MD5 checksum of the download file {fpath} does not match the one on record."
+            f"Please delete the file and try again. "
+            f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
+        )
+
 
 def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:
     with tarfile.open(from_path, f"r:{compression[1:]}" if compression else "r") as tar: