From d467132e929e0e1c66a574b6d79d38eeabecb2dd Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Wed, 30 Mar 2022 09:54:24 +0200 Subject: [PATCH 1/2] improve error handling for GDrive downloads --- torchvision/datasets/utils.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index dbc9cf2a6b4..ed40f4f91b3 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -243,6 +243,27 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[ _save_response_content(itertools.chain((first_chunk,), response_content_generator), fpath) response.close() + if md5 and not check_md5(fpath, md5): + msg = f"The MD5 checksum of the download file {fpath} does not match the one on record." + remediation_msg = ( + "Please delete the file and try again. " + "If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues." + ) + # GDrive API responses should be smaller than 10kB + if os.stat(fpath).st_size < 10 * 1024: + with open(fpath) as fh: + content = fh.read() + # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604 + if re.search(r"</?\s*[a-z-][^>]*\s*>|(\&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", content): + remediation_msg = ( + f"We detected some HTML elements in the downloaded file. " + f"This most likely means that the download triggered an unhandled API response by GDrive. 
" + f"Please report this to torchvision at https://github.com/pytorch/vision/issues including " + f"the response:\n\n{content}" + ) + + raise RuntimeError(f"{msg} {remediation_msg}") + def _get_confirm_token(response: requests.models.Response) -> Optional[str]: for key, value in response.cookies.items(): From 558481008daf9cefb84ad2c7bfdd056dadf98baa Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Mon, 11 Apr 2022 11:47:00 +0200 Subject: [PATCH 2/2] perform HTML check regardless of MD5 check --- torchvision/datasets/utils.py | 36 +++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index c69ab054afa..c85ee68d8c1 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -1,4 +1,5 @@ import bz2 +import contextlib import gzip import hashlib import itertools @@ -261,26 +262,25 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[ _save_response_content(content, fpath) + # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text + if os.stat(fpath).st_size < 10 * 1024: + with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh: + text = fh.read() + # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604 + if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text): + warnings.warn( + f"We detected some HTML elements in the downloaded file. " + f"This most likely means that the download triggered an unhandled API response by GDrive. " + f"Please report this to torchvision at https://github.com/pytorch/vision/issues including " + f"the response:\n\n{text}" + ) + if md5 and not check_md5(fpath, md5): - msg = f"The MD5 checksum of the download file {fpath} does not match the one on record." - remediation_msg = ( - "Please delete the file and try again. 
" - "If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues." + raise RuntimeError( + f"The MD5 checksum of the download file {fpath} does not match the one on record." + f"Please delete the file and try again. " + f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues." ) - # GDrive API responses should be smaller than 10kB - if os.stat(fpath).st_size < 10 * 1024: - with open(fpath) as fh: - content = fh.read() - # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604 - if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", content): - remediation_msg = ( - f"We detected some HTML elements in the downloaded file. " - f"This most likely means that the download triggered an unhandled API response by GDrive. " - f"Please report this to torchvision at https://github.com/pytorch/vision/issues including " - f"the response:\n\n{content}" - ) - - raise RuntimeError(f"{msg} {remediation_msg}") def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None: