From 1bfd2ee9fccec3d61caaf545e3843064c0c010ec Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Fri, 28 May 2021 22:16:52 -0400 Subject: [PATCH 01/19] modularizing download function --- torchtext/utils.py | 213 +++++++++++++++++++++++---------------------- 1 file changed, 111 insertions(+), 102 deletions(-) diff --git a/torchtext/utils.py b/torchtext/utils.py index f1115e95fb..1e6f1cbf0e 100644 --- a/torchtext/utils.py +++ b/torchtext/utils.py @@ -33,6 +33,80 @@ def inner(b=1, bsize=1, tsize=None): return inner +def validate_file(file_obj, hash_value, hash_type="sha256"): + """Validate a given file object with its hash. + + Args: + file_obj: File object to read from. + hash_value (str): Hash for url. + hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``). + Returns: + bool: return True if its a valid file, else False. + + """ + + if hash_type == "sha256": + hash_func = hashlib.sha256() + elif hash_type == "md5": + hash_func = hashlib.md5() + else: + raise ValueError + + while True: + # Read by chunk to avoid filling memory + chunk = file_obj.read(1024 ** 2) + if not chunk: + break + hash_func.update(chunk) + return hash_func.hexdigest() == hash_value + + +def _check_hash(path, hash_value, hash_type): + logging.info('Validating hash {} matches hash of {}'.format(hash_value, path)) + with open(path, "rb") as file_obj: + if not validate_file(file_obj, hash_value, hash_type): + raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(os.path.abspath(path))) + + +def _stream_response(r, chunk_size=16 * 1024): + total_size = int(r.headers.get('Content-length', 0)) + with tqdm(total=total_size, unit='B', unit_scale=1) as t: + for chunk in r.iter_content(chunk_size): + if chunk: + t.update(len(chunk)) + yield chunk + + +def _get_response_from_google_drive(url): + confirm_token = None + session = requests.Session() + response = session.get(url, stream=True) + for k, v in response.cookies.items(): + if k.startswith("download_warning"): + confirm_token = v + if confirm_token is None: + if "Quota exceeded" in str(response.content): + raise RuntimeError( + "Google drive link {} is currently unavailable, because the quota was exceeded.".format( + url + )) + else: + raise RuntimeError("Internal error: confirm_token was not found in Google drive link.") + + url = url + "&confirm=" + confirm_token + response = session.get(url, stream=True) + + if 'content-disposition' not in response.headers: + raise RuntimeError("Internal error: headers don't contain content-disposition.") + + filename = re.findall("filename=\"(.+)\"", response.headers['content-disposition']) + if filename is None: + raise RuntimeError("Filename could not be autodetected") + filename = filename[0] + + return response, filename + + def download_from_url(url, path=None, root='.data', overwrite=False, hash_value=None, hash_type="sha256"): """Download file, with logic (from tensor2tensor) for Google Drive. Returns @@ -40,6 +114,7 @@ def download_from_url(url, path=None, root='.data', overwrite=False, hash_value= Args: url: the url of the file from URL header. (None) + path: path where file will be saved root: download folder used to store the file in (.data) overwrite: overwrite existing files (False) hash_value (str, optional): hash for url (Default: ``None``). @@ -53,97 +128,59 @@ def download_from_url(url, path=None, root='.data', overwrite=False, hash_value= >>> '.data/validation.tar.gz' """ - if path is not None: - path = os.path.abspath(path) - root = os.path.abspath(root) - - def _check_hash(path): - if hash_value: - logging.info('Validating hash {} matches hash of {}'.format(hash_value, path)) - with open(path, "rb") as file_obj: - if not validate_file(file_obj, hash_value, hash_type): - raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(os.path.abspath(path))) - - def _process_response(r, root, filename): - chunk_size = 16 * 1024 - total_size = int(r.headers.get('Content-length', 0)) - if filename is None: - if 'content-disposition' not in r.headers: - raise RuntimeError("Internal error: headers don't contain content-disposition.") - d = r.headers['content-disposition'] - filename = re.findall("filename=\"(.+)\"", d) - if filename is None: - raise RuntimeError("Filename could not be autodetected") - filename = filename[0] - path = os.path.join(root, filename) - if os.path.exists(path): - logging.info('File %s already exists.' % path) - if not overwrite: - _check_hash(path) - return path - logging.info('Overwriting file %s.' % path) - logging.info('Downloading file {} to {}.'.format(filename, path)) - with open(path, "wb") as file: - with tqdm(total=total_size, unit='B', - unit_scale=1, desc=path.split('/')[-1]) as t: - for chunk in r.iter_content(chunk_size): - if chunk: - file.write(chunk) - t.update(len(chunk)) - logging.info('File {} downloaded.'.format(path)) - - _check_hash(path) - return path - + # figure out filename and root if path is None: _, filename = os.path.split(url) + root = os.path.abspath(root) + path = os.path.join(root, filename) else: + path = os.path.abspath(path) root, filename = os.path.split(os.path.abspath(path)) + # skip requests.get if path exists and not overwrite. + if os.path.exists(path): + logging.info('File %s already exists.' % path) + if not overwrite and hash_value: + _check_hash(path, hash_value, hash_type) + return path + + # make root dir if not exist already if not os.path.exists(root): try: os.makedirs(root) except OSError: - print("Can't create the download directory {}.".format(root)) - raise + raise OSError("Can't create the download directory {}.".format(root)) - if filename is not None: - path = os.path.join(root, filename) - # skip requests.get if path exists and not overwrite. + # get response with special handling of google drive + if 'drive.google.com' not in url: + response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) + else: + logging.info('Downloading from Google Drive; may take a few minutes') + response, filename = _get_response_from_google_drive(url) + + # path where file will be saved + path = os.path.join(root, filename) if os.path.exists(path): logging.info('File %s already exists.' % path) - if not overwrite: - _check_hash(path) + if not overwrite and hash_value: + _check_hash(path, hash_value, hash_type) return path + logging.info('Overwriting file %s.' % path) + logging.info('Downloading file {} to {}.'.format(filename, path)) - if 'drive.google.com' not in url: - response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) - return _process_response(response, root, filename) - else: - # google drive links get filename from google drive - filename = None + # download data to path from streamed response + with open(path, 'wb') as f: + for data in _stream_response(response): + f.write(data) - logging.info('Downloading from Google Drive; may take a few minutes') - confirm_token = None - session = requests.Session() - response = session.get(url, stream=True) - for k, v in response.cookies.items(): - if k.startswith("download_warning"): - confirm_token = v - if confirm_token is None: - if "Quota exceeded" in str(response.content): - raise RuntimeError( - "Google drive link {} is currently unavailable, because the quota was exceeded.".format( - url - )) - else: - raise RuntimeError("Internal error: confirm_token was not found in Google drive link.") + logging.info('File {} downloaded.'.format(path)) - if confirm_token: - url = url + "&confirm=" + confirm_token - response = session.get(url, stream=True) + # validate + if hash_value: + _check_hash(path, hash_value, hash_type) - return _process_response(response, root, filename) + # all good + return path def unicode_csv_reader(unicode_csv_data, **kwargs): @@ -263,31 +300,3 @@ def extract_archive(from_path, to_path=None, overwrite=False): else: raise NotImplementedError( "We currently only support tar.gz, .tgz, .gz and zip achives.") - - -def validate_file(file_obj, hash_value, hash_type="sha256"): - """Validate a given file object with its hash. - - Args: - file_obj: File object to read from. - hash_value (str): Hash for url. - hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``). - Returns: - bool: return True if its a valid file, else False. - - """ - - if hash_type == "sha256": - hash_func = hashlib.sha256() - elif hash_type == "md5": - hash_func = hashlib.md5() - else: - raise ValueError - - while True: - # Read by chunk to avoid filling memory - chunk = file_obj.read(1024 ** 2) - if not chunk: - break - hash_func.update(chunk) - return hash_func.hexdigest() == hash_value From 13bb6dcb530e896b298b05cce90c12ecf1db84e7 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Tue, 1 Jun 2021 11:19:58 -0400 Subject: [PATCH 02/19] delegating download to iopath --- torchtext/_download_hooks.py | 131 +++++++++++++++++++++++++++++++++++ torchtext/utils.py | 72 ++----------------- 2 files changed, 138 insertions(+), 65 deletions(-) create mode 100644 torchtext/_download_hooks.py diff --git a/torchtext/_download_hooks.py b/torchtext/_download_hooks.py new file mode 100644 index 0000000000..53ab2a4064 --- /dev/null +++ b/torchtext/_download_hooks.py @@ -0,0 +1,131 @@ +from typing import List, Optional, Union, IO, Dict, Any +import requests +import os +import logging +import uuid +import re +from tqdm import tqdm +from iopath.common.file_io import ( + PathHandler, + PathManager, + get_cache_dir, + file_lock, + HTTPURLHandler, +) + + +def _stream_response(r, chunk_size=16 * 1024): + total_size = int(r.headers.get('Content-length', 0)) + with tqdm(total=total_size, unit='B', unit_scale=1) as t: + for chunk in r.iter_content(chunk_size): + if chunk: + t.update(len(chunk)) + yield chunk + + +def _get_response_from_google_drive(url): + confirm_token = None + session = requests.Session() + response = session.get(url, stream=True) + for k, v in response.cookies.items(): + if k.startswith("download_warning"): + confirm_token = v + if confirm_token is None: + if "Quota exceeded" in str(response.content): + raise RuntimeError( + "Google drive link {} is currently unavailable, because the quota was exceeded.".format( + url + )) + else: + raise RuntimeError("Internal error: confirm_token was not found in Google drive link.") + + url = url + "&confirm=" + confirm_token + response = session.get(url, stream=True) + + if 'content-disposition' not in response.headers: + raise RuntimeError("Internal error: headers don't contain content-disposition.") + + filename = re.findall("filename=\"(.+)\"", response.headers['content-disposition']) + if filename is None: + raise RuntimeError("Filename could not be autodetected") + filename = filename[0] + + return response, filename + + +class GoogleDrivePathHandler(PathHandler): + """ + Download URLs and cache them to disk. + """ + + MAX_FILENAME_LEN = 250 + + def __init__(self) -> None: + self.cache_map: Dict[str, str] = {} + + def _get_supported_prefixes(self) -> List[str]: + return ["https://drive.google.com"] + + def _get_local_path( + self, + path: str, + force: bool = False, + cache_dir: Optional[str] = None, + **kwargs: Any, + ) -> str: + """ + This implementation downloads the remote resource from google drive and caches it locally. + The resource will only be downloaded if not previously requested. + """ + self._check_kwargs(kwargs) + if ( + force + or path not in self.cache_map + or not os.path.exists(self.cache_map[path]) + ): + logger = logging.getLogger(__name__) + dirname = get_cache_dir(cache_dir) + + response, filename = _get_response_from_google_drive(path) + if len(filename) > self.MAX_FILENAME_LEN: + filename = filename[:100] + "_" + uuid.uuid4().hex + + cached = os.path.join(dirname, filename) + with file_lock(cached): + if not os.path.isfile(cached): + logger.info("Downloading {} ...".format(path)) + with open(cached, 'wb') as f: + for data in _stream_response(response): + f.write(data) + logger.info("URL {} cached in {}".format(path, cached)) + self.cache_map[path] = cached + return self.cache_map[path] + + def _open( + self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any + ) -> Union[IO[str], IO[bytes]]: + """ + Open a google drive path. The resource is first downloaded and cached + locally. + Args: + path (str): A URI supported by this PathHandler + mode (str): Specifies the mode in which the file is opened. It defaults + to 'r'. + buffering (int): Not used for this PathHandler. + Returns: + file: a file-like object. + """ + self._check_kwargs(kwargs) + assert mode in ("r", "rb"), "{} does not support open with {} mode".format( + self.__class__.__name__, mode + ) + assert ( + buffering == -1 + ), f"{self.__class__.__name__} does not support the `buffering` argument" + local_path = self._get_local_path(path, force=False) + return open(local_path, mode) + + +_PATH_MANAGER = PathManager() +_PATH_MANAGER.register_handler(HTTPURLHandler()) +_PATH_MANAGER.register_handler(GoogleDrivePathHandler()) diff --git a/torchtext/utils.py b/torchtext/utils.py index 1e6f1cbf0e..74b9f2eb12 100644 --- a/torchtext/utils.py +++ b/torchtext/utils.py @@ -1,14 +1,13 @@ -import requests import csv import hashlib -from tqdm import tqdm import os import tarfile import logging -import re import sys import zipfile import gzip +import shutil +from ._download_hooks import _PATH_MANAGER def reporthook(t): @@ -68,45 +67,6 @@ def _check_hash(path, hash_value, hash_type): raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(os.path.abspath(path))) -def _stream_response(r, chunk_size=16 * 1024): - total_size = int(r.headers.get('Content-length', 0)) - with tqdm(total=total_size, unit='B', unit_scale=1) as t: - for chunk in r.iter_content(chunk_size): - if chunk: - t.update(len(chunk)) - yield chunk - - -def _get_response_from_google_drive(url): - confirm_token = None - session = requests.Session() - response = session.get(url, stream=True) - for k, v in response.cookies.items(): - if k.startswith("download_warning"): - confirm_token = v - if confirm_token is None: - if "Quota exceeded" in str(response.content): - raise RuntimeError( - "Google drive link {} is currently unavailable, because the quota was exceeded.".format( - url - )) - else: - raise RuntimeError("Internal error: confirm_token was not found in Google drive link.") - - url = url + "&confirm=" + confirm_token - response = session.get(url, stream=True) - - if 'content-disposition' not in response.headers: - raise RuntimeError("Internal error: headers don't contain content-disposition.") - - filename = re.findall("filename=\"(.+)\"", response.headers['content-disposition']) - if filename is None: - raise RuntimeError("Filename could not be autodetected") - filename = filename[0] - - return response, filename - - def download_from_url(url, path=None, root='.data', overwrite=False, hash_value=None, hash_type="sha256"): """Download file, with logic (from tensor2tensor) for Google Drive. Returns @@ -137,41 +97,23 @@ def download_from_url(url, path=None, root='.data', overwrite=False, hash_value= path = os.path.abspath(path) root, filename = os.path.split(os.path.abspath(path)) - # skip requests.get if path exists and not overwrite. + # skip download if path exists and overwrite is not True if os.path.exists(path): logging.info('File %s already exists.' % path) if not overwrite and hash_value: _check_hash(path, hash_value, hash_type) return path - # make root dir if not exist already + # make root dir if does not exist if not os.path.exists(root): try: os.makedirs(root) except OSError: raise OSError("Can't create the download directory {}.".format(root)) - # get response with special handling of google drive - if 'drive.google.com' not in url: - response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) - else: - logging.info('Downloading from Google Drive; may take a few minutes') - response, filename = _get_response_from_google_drive(url) - - # path where file will be saved - path = os.path.join(root, filename) - if os.path.exists(path): - logging.info('File %s already exists.' % path) - if not overwrite and hash_value: - _check_hash(path, hash_value, hash_type) - return path - logging.info('Overwriting file %s.' % path) - logging.info('Downloading file {} to {}.'.format(filename, path)) - - # download data to path from streamed response - with open(path, 'wb') as f: - for data in _stream_response(response): - f.write(data) + # download data to path + local_path = _PATH_MANAGER.get_local_path(url) + shutil.move(local_path, path) logging.info('File {} downloaded.'.format(path)) From d23b1bb68fc2ca301e813a66b115f2fe3d41b830 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Tue, 1 Jun 2021 11:25:21 -0400 Subject: [PATCH 03/19] updating dependency on iopath --- .circleci/unittest/linux/scripts/environment.yml | 1 + .circleci/unittest/windows/scripts/environment.yml | 1 + requirements.txt | 1 + 3 files changed, 3 insertions(+) diff --git a/.circleci/unittest/linux/scripts/environment.yml b/.circleci/unittest/linux/scripts/environment.yml index e1e9f2cda8..73c1cc9e15 100644 --- a/.circleci/unittest/linux/scripts/environment.yml +++ b/.circleci/unittest/linux/scripts/environment.yml @@ -9,6 +9,7 @@ dependencies: - dataclasses - nltk - requests + - iopath - revtok - pytest - pytest-cov diff --git a/.circleci/unittest/windows/scripts/environment.yml b/.circleci/unittest/windows/scripts/environment.yml index 9716f09114..b05290edf4 100644 --- a/.circleci/unittest/windows/scripts/environment.yml +++ b/.circleci/unittest/windows/scripts/environment.yml @@ -9,6 +9,7 @@ dependencies: - dataclasses - nltk - requests + - iopath - revtok - pytest - pytest-cov diff --git a/requirements.txt b/requirements.txt index fd100b8eb3..ceb21c99cf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ tqdm # Downloading data and other files requests +iopath # Optional NLP tools nltk From 601baffceaa4a311fce376071fdd7bbe0106237c Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Tue, 1 Jun 2021 22:49:48 -0400 Subject: [PATCH 04/19] dependency add --- packaging/torchtext/meta.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/packaging/torchtext/meta.yaml b/packaging/torchtext/meta.yaml index 36008e5cf5..61b8f50adb 100644 --- a/packaging/torchtext/meta.yaml +++ b/packaging/torchtext/meta.yaml @@ -20,6 +20,7 @@ requirements: run: - python - requests + - iopath - tqdm {{ environ.get('CONDA_PYTORCH_CONSTRAINT') }} From c89908b4dbcc7520d0513ac773d4687f0034869e Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Tue, 1 Jun 2021 23:51:36 -0400 Subject: [PATCH 05/19] adding iopath channel --- packaging/pkg_helpers.bash | 1 + 1 file changed, 1 insertion(+) diff --git a/packaging/pkg_helpers.bash b/packaging/pkg_helpers.bash index 8f1e24f1da..e0c81205bc 100644 --- a/packaging/pkg_helpers.bash +++ b/packaging/pkg_helpers.bash @@ -180,6 +180,7 @@ setup_pip_pytorch_version() { # You MUST have populated PYTORCH_VERSION_SUFFIX before hand. setup_conda_pytorch_constraint() { CONDA_CHANNEL_FLAGS=${CONDA_CHANNEL_FLAGS:-} + CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c iopath" if [[ -z "$PYTORCH_VERSION" ]]; then export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch-nightly" export PYTORCH_VERSION="$(conda search --json 'pytorch[channel=pytorch-nightly]' | python -c "import sys, json, re; print(re.sub(r'\\+.*$', '', json.load(sys.stdin)['pytorch'][-1]['version']))")" From 367f8ecc67eb64e12aa9f5cbed26869ba2695607 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Wed, 2 Jun 2021 00:04:59 -0400 Subject: [PATCH 06/19] dummy change --- .circleci/cached_datasets_list.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.circleci/cached_datasets_list.txt b/.circleci/cached_datasets_list.txt index c345588868..c32654767e 100644 --- a/.circleci/cached_datasets_list.txt +++ b/.circleci/cached_datasets_list.txt @@ -18,4 +18,5 @@ WikiText103 PennTreebank SQuAD1 SQuAD2 -EnWik9 \ No newline at end of file +EnWik9 +tempchangetoactivatechacherebuild \ No newline at end of file From acae94e3241d31928895f2c8a571370555593d1d Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Wed, 2 Jun 2021 00:27:30 -0400 Subject: [PATCH 07/19] requirements update --- docs/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/requirements.txt b/docs/requirements.txt index 560a2b3600..65538b7b8a 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,2 +1,3 @@ sphinx==2.4.4 +iopath -e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme From 00870b3dbcbb58943dd39ae548a4d3acc654580d Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Wed, 2 Jun 2021 09:56:19 -0400 Subject: [PATCH 08/19] fixing test failures and cache issues --- .circleci/config.yml | 2 +- .circleci/config.yml.in | 2 +- torchtext/utils.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index b2c61d69f6..c3eff5c17d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -46,7 +46,7 @@ commands: name: Generate CCI cache key command: echo "$(date "+%D")" > .cachekey - cat cached_datasets_list.txt >> .cachekey + cat .circleci/cached_datasets_list.txt >> .cachekey - persist_to_workspace: root: . paths: diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in index 3c997d777c..dfdee08450 100644 --- a/.circleci/config.yml.in +++ b/.circleci/config.yml.in @@ -46,7 +46,7 @@ commands: name: Generate CCI cache key command: echo "$(date "+%D")" > .cachekey - cat cached_datasets_list.txt >> .cachekey + cat .circleci/cached_datasets_list.txt >> .cachekey - persist_to_workspace: root: . paths: diff --git a/torchtext/utils.py b/torchtext/utils.py index 74b9f2eb12..f960042e83 100644 --- a/torchtext/utils.py +++ b/torchtext/utils.py @@ -100,8 +100,9 @@ def download_from_url(url, path=None, root='.data', overwrite=False, hash_value= # skip download if path exists and overwrite is not True if os.path.exists(path): logging.info('File %s already exists.' % path) - if not overwrite and hash_value: - _check_hash(path, hash_value, hash_type) + if not overwrite: + if hash_value: + _check_hash(path, hash_value, hash_type) return path # make root dir if does not exist From 44dc54de51c3ac3a75e6a71c6e0cbaa0552f6723 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Wed, 2 Jun 2021 09:56:39 -0400 Subject: [PATCH 09/19] restoring state --- .circleci/cached_datasets_list.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.circleci/cached_datasets_list.txt b/.circleci/cached_datasets_list.txt index c32654767e..c345588868 100644 --- a/.circleci/cached_datasets_list.txt +++ b/.circleci/cached_datasets_list.txt @@ -18,5 +18,4 @@ WikiText103 PennTreebank SQuAD1 SQuAD2 -EnWik9 -tempchangetoactivatechacherebuild \ No newline at end of file +EnWik9 \ No newline at end of file From f9e3b1e6549abec04129d892b0001972e26ab203 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Thu, 3 Jun 2021 22:23:00 -0400 Subject: [PATCH 10/19] minor updates --- .circleci/cached_datasets_list.txt | 3 +- torchtext/_download_hooks.py | 44 ++++++++++++++++++++++++++++-- torchtext/utils.py | 5 ++-- 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/.circleci/cached_datasets_list.txt b/.circleci/cached_datasets_list.txt index c345588868..75d76b687a 100644 --- a/.circleci/cached_datasets_list.txt +++ b/.circleci/cached_datasets_list.txt @@ -18,4 +18,5 @@ WikiText103 PennTreebank SQuAD1 SQuAD2 -EnWik9 \ No newline at end of file +EnWik9 +dummy \ No newline at end of file diff --git a/torchtext/_download_hooks.py b/torchtext/_download_hooks.py index 53ab2a4064..56dcc992a9 100644 --- a/torchtext/_download_hooks.py +++ b/torchtext/_download_hooks.py @@ -4,6 +4,7 @@ import logging import uuid import re +import shutil from tqdm import tqdm from iopath.common.file_io import ( PathHandler, @@ -126,6 +127,45 @@ def _open( return open(local_path, mode) +class CombinedInternalPathhandler(PathHandler): + def __init__(self): + path_manager = PathManager() + path_manager.register_handler(HTTPURLHandler()) + path_manager.register_handler(GoogleDrivePathHandler()) + self.path_manager = path_manager + + def _get_supported_prefixes(self) -> List[str]: + return ["https://", "http://"] + + def _get_local_path( + self, + path: str, + force: bool = False, + cache_dir: Optional[str] = None, + **kwargs: Any, + ) -> str: + + destination = kwargs["destination"] + + local_path = self.path_manager.get_local_path(path, force) + + shutil.move(local_path, destination) + + return destination + + def _open( + self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any + ) -> Union[IO[str], IO[bytes]]: + self._check_kwargs(kwargs) + assert mode in ("r", "rb"), "{} does not support open with {} mode".format( + self.__class__.__name__, mode + ) + assert ( + buffering == -1 + ), f"{self.__class__.__name__} does not support the `buffering` argument" + local_path = self._get_local_path(path, force=False) + return open(local_path, mode) + + _PATH_MANAGER = PathManager() -_PATH_MANAGER.register_handler(HTTPURLHandler()) -_PATH_MANAGER.register_handler(GoogleDrivePathHandler()) +_PATH_MANAGER.register_handler(CombinedInternalPathhandler()) diff --git a/torchtext/utils.py b/torchtext/utils.py index f960042e83..5feba28a9a 100644 --- a/torchtext/utils.py +++ b/torchtext/utils.py @@ -112,9 +112,8 @@ def download_from_url(url, path=None, root='.data', overwrite=False, hash_value= except OSError: raise OSError("Can't create the download directory {}.".format(root)) - # download data to path - local_path = _PATH_MANAGER.get_local_path(url) - shutil.move(local_path, path) + # download data and move to path + _PATH_MANAGER.get_local_path(url, destination=path) logging.info('File {} downloaded.'.format(path)) From 0d9f9994cfa809b5c739e1162232de84db7ed84e Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Thu, 3 Jun 2021 23:40:59 -0400 Subject: [PATCH 11/19] minor fix --- .circleci/cached_datasets_list.txt | 1 - torchtext/utils.py | 1 - 2 files changed, 2 deletions(-) diff --git a/.circleci/cached_datasets_list.txt b/.circleci/cached_datasets_list.txt index 75d76b687a..9989d5f278 100644 --- a/.circleci/cached_datasets_list.txt +++ b/.circleci/cached_datasets_list.txt @@ -19,4 +19,3 @@ PennTreebank SQuAD1 SQuAD2 EnWik9 -dummy \ No newline at end of file diff --git a/torchtext/utils.py b/torchtext/utils.py index 5feba28a9a..1195555fb0 100644 --- a/torchtext/utils.py +++ b/torchtext/utils.py @@ -6,7 +6,6 @@ import sys import zipfile import gzip -import shutil from ._download_hooks import _PATH_MANAGER From e2bd6d7e0751d4d0c5dbff2333b924141d6b7a1f Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Thu, 17 Jun 2021 23:24:14 -0400 Subject: [PATCH 12/19] minor naming changes --- torchtext/_download_hooks.py | 4 ++-- torchtext/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/torchtext/_download_hooks.py b/torchtext/_download_hooks.py index 56dcc992a9..269d251ec3 100644 --- a/torchtext/_download_hooks.py +++ b/torchtext/_download_hooks.py @@ -167,5 +167,5 @@ def _open( return open(local_path, mode) -_PATH_MANAGER = PathManager() -_PATH_MANAGER.register_handler(CombinedInternalPathhandler()) +_DATASET_DOWNLOAD_MANAGER = PathManager() +_DATASET_DOWNLOAD_MANAGER.register_handler(CombinedInternalPathhandler()) diff --git a/torchtext/utils.py b/torchtext/utils.py index 1195555fb0..a05ca8ddca 100644 --- a/torchtext/utils.py +++ b/torchtext/utils.py @@ -6,7 +6,7 @@ import sys import zipfile import gzip -from ._download_hooks import _PATH_MANAGER +from ._download_hooks import _DATASET_DOWNLOAD_MANAGER def reporthook(t): @@ -112,7 +112,7 @@ def download_from_url(url, path=None, root='.data', overwrite=False, hash_value= raise OSError("Can't create the download directory {}.".format(root)) # download data and move to path - _PATH_MANAGER.get_local_path(url, destination=path) + _DATASET_DOWNLOAD_MANAGER.get_local_path(url, destination=path) logging.info('File {} downloaded.'.format(path)) From 376c770051a3056061baf8d76ec699448849e5ba Mon Sep 17 00:00:00 2001 From: Guanheng Zhang Date: Wed, 25 Mar 2020 10:26:06 -0700 Subject: [PATCH 13/19] bump up the version --- .circleci/unittest/windows/scripts/environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.circleci/unittest/windows/scripts/environment.yml b/.circleci/unittest/windows/scripts/environment.yml index b05290edf4..b27818aa99 100644 --- a/.circleci/unittest/windows/scripts/environment.yml +++ b/.circleci/unittest/windows/scripts/environment.yml @@ -21,5 +21,6 @@ dependencies: - tqdm - certifi - future + - pywin32==225 - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0 - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0 From cb5b192587d8d9d7a85b74c6cb748971f2839e04 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Fri, 25 Jun 2021 17:57:49 -0400 Subject: [PATCH 14/19] Revert "bump up the version" This reverts commit 376c770051a3056061baf8d76ec699448849e5ba. --- .circleci/unittest/windows/scripts/environment.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.circleci/unittest/windows/scripts/environment.yml b/.circleci/unittest/windows/scripts/environment.yml index b27818aa99..b05290edf4 100644 --- a/.circleci/unittest/windows/scripts/environment.yml +++ b/.circleci/unittest/windows/scripts/environment.yml @@ -21,6 +21,5 @@ dependencies: - tqdm - certifi - future - - pywin32==225 - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0 - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0 From cd95a10572e54c33e1b30bf15d8b4bb24f6ce1cc Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Fri, 25 Jun 2021 18:02:04 -0400 Subject: [PATCH 15/19] fixing environment issue --- .circleci/unittest/windows/scripts/environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.circleci/unittest/windows/scripts/environment.yml b/.circleci/unittest/windows/scripts/environment.yml index b05290edf4..b27818aa99 100644 --- a/.circleci/unittest/windows/scripts/environment.yml +++ b/.circleci/unittest/windows/scripts/environment.yml @@ -21,5 +21,6 @@ dependencies: - tqdm - certifi - future + - pywin32==225 - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0 - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0 From 09a800f5f236768ac838252d96aa478b72b11927 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Fri, 25 Jun 2021 18:42:52 -0400 Subject: [PATCH 16/19] unspecify version of win32 --- .circleci/unittest/windows/scripts/environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/unittest/windows/scripts/environment.yml b/.circleci/unittest/windows/scripts/environment.yml index b27818aa99..ab8f6f6cd2 100644 --- a/.circleci/unittest/windows/scripts/environment.yml +++ b/.circleci/unittest/windows/scripts/environment.yml @@ -21,6 +21,6 @@ dependencies: - tqdm - certifi - future - - pywin32==225 + - pywin32 - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0 - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0 From b9543546a876657410febde09e7661c7c2605277 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Fri, 25 Jun 2021 19:25:54 -0400 Subject: [PATCH 17/19] trying versioning --- .circleci/unittest/windows/scripts/environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/unittest/windows/scripts/environment.yml b/.circleci/unittest/windows/scripts/environment.yml index ab8f6f6cd2..7d4b795a66 100644 --- a/.circleci/unittest/windows/scripts/environment.yml +++ b/.circleci/unittest/windows/scripts/environment.yml @@ -21,6 +21,6 @@ dependencies: - tqdm - certifi - future - - pywin32 + - pywin32>=224 - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0 - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0 From 683639c6b82260adc849256e2ce99ffa2a2aa61f Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Sat, 26 Jun 2021 03:35:00 -0400 Subject: [PATCH 18/19] minor update on fixing pywin32 version --- .circleci/unittest/windows/scripts/environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/unittest/windows/scripts/environment.yml b/.circleci/unittest/windows/scripts/environment.yml index 7d4b795a66..1104b6607b 100644 --- a/.circleci/unittest/windows/scripts/environment.yml +++ b/.circleci/unittest/windows/scripts/environment.yml @@ -21,6 +21,6 @@ dependencies: - tqdm - certifi - future - - pywin32>=224 + - pywin32=225 - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0 - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0 From 76035fa5b049a2c1a3c878a499f9ed311f408a68 Mon Sep 17 00:00:00 2001 From: Parmeet Singh Bhatia Date: Sun, 27 Jun 2021 23:29:08 -0400 Subject: [PATCH 19/19] fixing dependency --- .circleci/unittest/windows/scripts/environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/unittest/windows/scripts/environment.yml b/.circleci/unittest/windows/scripts/environment.yml index 1104b6607b..04142428c2 100644 --- a/.circleci/unittest/windows/scripts/environment.yml +++ b/.circleci/unittest/windows/scripts/environment.yml @@ -4,6 +4,7 @@ channels: dependencies: - flake8>=3.7.9 - codecov + - pywin32 - pip - pip: - dataclasses @@ -21,6 +22,5 @@ dependencies: - tqdm - certifi - future - - pywin32=225 - https://github.com/explosion/spacy-models/releases/download/de_core_news_sm-3.0.0/de_core_news_sm-3.0.0.tar.gz#egg=de_core_news_sm==3.0.0 - https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.0.0/en_core_web_sm-3.0.0.tar.gz#egg=en_core_web_sm==3.0.0