diff --git a/.circleci/cached_datasets_list.txt b/.circleci/cached_datasets_list.txt
index c345588868..9989d5f278 100644
--- a/.circleci/cached_datasets_list.txt
+++ b/.circleci/cached_datasets_list.txt
@@ -18,4 +18,4 @@ WikiText103
 PennTreebank
 SQuAD1
 SQuAD2
-EnWik9
\ No newline at end of file
+EnWik9
diff --git a/.circleci/config.yml b/.circleci/config.yml
index 953cd8b9be..5a9ca3ef1b 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -46,7 +46,7 @@ commands:
           name: Generate CCI cache key
           command: |
             echo "$(date "+%D")" > .cachekey
-            cat cached_datasets_list.txt >> .cachekey
+            cat .circleci/cached_datasets_list.txt >> .cachekey
       - persist_to_workspace:
           root: .
           paths:
diff --git a/.circleci/config.yml.in b/.circleci/config.yml.in
index f97eff89fc..5473c45f65 100644
--- a/.circleci/config.yml.in
+++ b/.circleci/config.yml.in
@@ -46,7 +46,7 @@ commands:
           name: Generate CCI cache key
           command: |
             echo "$(date "+%D")" > .cachekey
-            cat cached_datasets_list.txt >> .cachekey
+            cat .circleci/cached_datasets_list.txt >> .cachekey
       - persist_to_workspace:
           root: .
           paths:
diff --git a/.circleci/unittest/linux/scripts/environment.yml b/.circleci/unittest/linux/scripts/environment.yml
index e1e9f2cda8..73c1cc9e15 100644
--- a/.circleci/unittest/linux/scripts/environment.yml
+++ b/.circleci/unittest/linux/scripts/environment.yml
@@ -9,6 +9,7 @@ dependencies:
       - dataclasses
       - nltk
       - requests
+      - iopath
       - revtok
       - pytest
       - pytest-cov
diff --git a/.circleci/unittest/windows/scripts/environment.yml b/.circleci/unittest/windows/scripts/environment.yml
index 9716f09114..04142428c2 100644
--- a/.circleci/unittest/windows/scripts/environment.yml
+++ b/.circleci/unittest/windows/scripts/environment.yml
@@ -4,11 +4,13 @@ channels:
 dependencies:
   - flake8>=3.7.9
   - codecov
+  - pywin32
   - pip
   - pip:
       - dataclasses
       - nltk
       - requests
+      - iopath
       - revtok
       - pytest
       - pytest-cov
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 560a2b3600..65538b7b8a 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,2 +1,3 @@
 sphinx==2.4.4
+iopath
 -e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
diff --git a/packaging/pkg_helpers.bash b/packaging/pkg_helpers.bash
index 8f1e24f1da..e0c81205bc 100644
--- a/packaging/pkg_helpers.bash
+++ b/packaging/pkg_helpers.bash
@@ -180,6 +180,7 @@ setup_pip_pytorch_version() {
 # You MUST have populated PYTORCH_VERSION_SUFFIX before hand.
 setup_conda_pytorch_constraint() {
   CONDA_CHANNEL_FLAGS=${CONDA_CHANNEL_FLAGS:-}
+  CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c iopath"
   if [[ -z "$PYTORCH_VERSION" ]]; then
     export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch-nightly"
     export PYTORCH_VERSION="$(conda search --json 'pytorch[channel=pytorch-nightly]' | python -c "import sys, json, re; print(re.sub(r'\\+.*$', '', json.load(sys.stdin)['pytorch'][-1]['version']))")"
diff --git a/packaging/torchtext/meta.yaml b/packaging/torchtext/meta.yaml
index 36008e5cf5..61b8f50adb 100644
--- a/packaging/torchtext/meta.yaml
+++ b/packaging/torchtext/meta.yaml
@@ -20,6 +20,7 @@ requirements:
   run:
     - python
     - requests
+    - iopath
     - tqdm
     {{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}
 
diff --git a/requirements.txt b/requirements.txt
index fd100b8eb3..ceb21c99cf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ tqdm
 
 # Downloading data and other files
 requests
+iopath
 
 # Optional NLP tools
 nltk
diff --git a/torchtext/_download_hooks.py b/torchtext/_download_hooks.py
new file mode 100644
index 0000000000..269d251ec3
--- /dev/null
+++ b/torchtext/_download_hooks.py
@@ -0,0 +1,171 @@
+from typing import List, Optional, Union, IO, Dict, Any
+import requests
+import os
+import logging
+import uuid
+import re
+import shutil
+from tqdm import tqdm
+from iopath.common.file_io import (
+    PathHandler,
+    PathManager,
+    get_cache_dir,
+    file_lock,
+    HTTPURLHandler,
+)
+
+
+def _stream_response(r, chunk_size=16 * 1024):
+    total_size = int(r.headers.get('Content-length', 0))
+    with tqdm(total=total_size, unit='B', unit_scale=1) as t:
+        for chunk in r.iter_content(chunk_size):
+            if chunk:
+                t.update(len(chunk))
+                yield chunk
+
+
+def _get_response_from_google_drive(url):
+    confirm_token = None
+    session = requests.Session()
+    response = session.get(url, stream=True)
+    for k, v in response.cookies.items():
+        if k.startswith("download_warning"):
+            confirm_token = v
+    if confirm_token is None:
+        if "Quota exceeded" in str(response.content):
+            raise RuntimeError(
+                "Google drive link {} is currently unavailable, because the quota was exceeded.".format(
+                    url
+                ))
+        else:
+            raise RuntimeError("Internal error: confirm_token was not found in Google drive link.")
+
+    url = url + "&confirm=" + confirm_token
+    response = session.get(url, stream=True)
+
+    if 'content-disposition' not in response.headers:
+        raise RuntimeError("Internal error: headers don't contain content-disposition.")
+
+    filename = re.findall("filename=\"(.+)\"", response.headers['content-disposition'])
+    if filename is None:
+        raise RuntimeError("Filename could not be autodetected")
+    filename = filename[0]
+
+    return response, filename
+
+
+class GoogleDrivePathHandler(PathHandler):
+    """
+    Download URLs and cache them to disk.
+    """
+
+    MAX_FILENAME_LEN = 250
+
+    def __init__(self) -> None:
+        self.cache_map: Dict[str, str] = {}
+
+    def _get_supported_prefixes(self) -> List[str]:
+        return ["https://drive.google.com"]
+
+    def _get_local_path(
+        self,
+        path: str,
+        force: bool = False,
+        cache_dir: Optional[str] = None,
+        **kwargs: Any,
+    ) -> str:
+        """
+        This implementation downloads the remote resource from google drive and caches it locally.
+        The resource will only be downloaded if not previously requested.
+        """
+        self._check_kwargs(kwargs)
+        if (
+            force
+            or path not in self.cache_map
+            or not os.path.exists(self.cache_map[path])
+        ):
+            logger = logging.getLogger(__name__)
+            dirname = get_cache_dir(cache_dir)
+
+            response, filename = _get_response_from_google_drive(path)
+            if len(filename) > self.MAX_FILENAME_LEN:
+                filename = filename[:100] + "_" + uuid.uuid4().hex
+
+            cached = os.path.join(dirname, filename)
+            with file_lock(cached):
+                if not os.path.isfile(cached):
+                    logger.info("Downloading {} ...".format(path))
+                    with open(cached, 'wb') as f:
+                        for data in _stream_response(response):
+                            f.write(data)
+            logger.info("URL {} cached in {}".format(path, cached))
+            self.cache_map[path] = cached
+        return self.cache_map[path]
+
+    def _open(
+        self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
+    ) -> Union[IO[str], IO[bytes]]:
+        """
+        Open a google drive path. The resource is first downloaded and cached
+        locally.
+        Args:
+            path (str): A URI supported by this PathHandler
+            mode (str): Specifies the mode in which the file is opened. It defaults
+                to 'r'.
+            buffering (int): Not used for this PathHandler.
+        Returns:
+            file: a file-like object.
+        """
+        self._check_kwargs(kwargs)
+        assert mode in ("r", "rb"), "{} does not support open with {} mode".format(
+            self.__class__.__name__, mode
+        )
+        assert (
+            buffering == -1
+        ), f"{self.__class__.__name__} does not support the `buffering` argument"
+        local_path = self._get_local_path(path, force=False)
+        return open(local_path, mode)
+
+
+class CombinedInternalPathhandler(PathHandler):
+    def __init__(self):
+        path_manager = PathManager()
+        path_manager.register_handler(HTTPURLHandler())
+        path_manager.register_handler(GoogleDrivePathHandler())
+        self.path_manager = path_manager
+
+    def _get_supported_prefixes(self) -> List[str]:
+        return ["https://", "http://"]
+
+    def _get_local_path(
+        self,
+        path: str,
+        force: bool = False,
+        cache_dir: Optional[str] = None,
+        **kwargs: Any,
+    ) -> str:
+
+        destination = kwargs["destination"]
+
+        local_path = self.path_manager.get_local_path(path, force)
+
+        shutil.move(local_path, destination)
+
+        return destination
+
+    def _open(
+        self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
+    ) -> Union[IO[str], IO[bytes]]:
+        self._check_kwargs(kwargs)
+        assert mode in ("r", "rb"), "{} does not support open with {} mode".format(
+            self.__class__.__name__, mode
+        )
+        assert (
+            buffering == -1
+        ), f"{self.__class__.__name__} does not support the `buffering` argument"
+        local_path = self._get_local_path(path, force=False)
+        return open(local_path, mode)
+
+
+_DATASET_DOWNLOAD_MANAGER = PathManager()
+_DATASET_DOWNLOAD_MANAGER.register_handler(CombinedInternalPathhandler())
diff --git a/torchtext/utils.py b/torchtext/utils.py
index f1115e95fb..a05ca8ddca 100644
--- a/torchtext/utils.py
+++ b/torchtext/utils.py
@@ -1,14 +1,12 @@
-import requests
 import csv
 import hashlib
-from tqdm import tqdm
 import os
 import tarfile
 import logging
-import re
 import sys
 import zipfile
 import gzip
+from ._download_hooks import _DATASET_DOWNLOAD_MANAGER
 
 
 def reporthook(t):
@@ -33,6 +31,41 @@ def inner(b=1, bsize=1, tsize=None):
     return inner
 
 
+def validate_file(file_obj, hash_value, hash_type="sha256"):
+    """Validate a given file object with its hash.
+
+    Args:
+        file_obj: File object to read from.
+        hash_value (str): Hash for url.
+        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
+    Returns:
+        bool: return True if its a valid file, else False.
+
+    """
+
+    if hash_type == "sha256":
+        hash_func = hashlib.sha256()
+    elif hash_type == "md5":
+        hash_func = hashlib.md5()
+    else:
+        raise ValueError
+
+    while True:
+        # Read by chunk to avoid filling memory
+        chunk = file_obj.read(1024 ** 2)
+        if not chunk:
+            break
+        hash_func.update(chunk)
+    return hash_func.hexdigest() == hash_value
+
+
+def _check_hash(path, hash_value, hash_type):
+    logging.info('Validating hash {} matches hash of {}'.format(hash_value, path))
+    with open(path, "rb") as file_obj:
+        if not validate_file(file_obj, hash_value, hash_type):
+            raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(os.path.abspath(path)))
+
+
 def download_from_url(url, path=None, root='.data', overwrite=False, hash_value=None,
                       hash_type="sha256"):
     """Download file, with logic (from tensor2tensor) for Google Drive. Returns
@@ -40,6 +73,7 @@
 
     Args:
         url: the url of the file from URL header. (None)
+        path: path where file will be saved
        root: download folder used to store the file in (.data)
        overwrite: overwrite existing files (False)
        hash_value (str, optional): hash for url (Default: ``None``).
@@ -53,97 +87,41 @@
        >>> '.data/validation.tar.gz'
 
     """
-    if path is not None:
-        path = os.path.abspath(path)
-    root = os.path.abspath(root)
-
-    def _check_hash(path):
-        if hash_value:
-            logging.info('Validating hash {} matches hash of {}'.format(hash_value, path))
-            with open(path, "rb") as file_obj:
-                if not validate_file(file_obj, hash_value, hash_type):
-                    raise RuntimeError("The hash of {} does not match. Delete the file manually and retry.".format(os.path.abspath(path)))
-
-    def _process_response(r, root, filename):
-        chunk_size = 16 * 1024
-        total_size = int(r.headers.get('Content-length', 0))
-        if filename is None:
-            if 'content-disposition' not in r.headers:
-                raise RuntimeError("Internal error: headers don't contain content-disposition.")
-            d = r.headers['content-disposition']
-            filename = re.findall("filename=\"(.+)\"", d)
-            if filename is None:
-                raise RuntimeError("Filename could not be autodetected")
-            filename = filename[0]
-        path = os.path.join(root, filename)
-        if os.path.exists(path):
-            logging.info('File %s already exists.' % path)
-            if not overwrite:
-                _check_hash(path)
-                return path
-            logging.info('Overwriting file %s.' % path)
-        logging.info('Downloading file {} to {}.'.format(filename, path))
-        with open(path, "wb") as file:
-            with tqdm(total=total_size, unit='B',
-                      unit_scale=1, desc=path.split('/')[-1]) as t:
-                for chunk in r.iter_content(chunk_size):
-                    if chunk:
-                        file.write(chunk)
-                        t.update(len(chunk))
-        logging.info('File {} downloaded.'.format(path))
-
-        _check_hash(path)
-        return path
-
+    # figure out filename and root
     if path is None:
         _, filename = os.path.split(url)
+        root = os.path.abspath(root)
+        path = os.path.join(root, filename)
     else:
+        path = os.path.abspath(path)
         root, filename = os.path.split(os.path.abspath(path))
 
+    # skip download if path exists and overwrite is not True
+    if os.path.exists(path):
+        logging.info('File %s already exists.' % path)
+        if not overwrite:
+            if hash_value:
+                _check_hash(path, hash_value, hash_type)
+            return path
+
+    # make root dir if does not exist
     if not os.path.exists(root):
         try:
             os.makedirs(root)
         except OSError:
-            print("Can't create the download directory {}.".format(root))
-            raise
+            raise OSError("Can't create the download directory {}.".format(root))
 
-    if filename is not None:
-        path = os.path.join(root, filename)
-    # skip requests.get if path exists and not overwrite.
-    if os.path.exists(path):
-        logging.info('File %s already exists.' % path)
-        if not overwrite:
-            _check_hash(path)
-            return path
+    # download data and move to path
+    _DATASET_DOWNLOAD_MANAGER.get_local_path(url, destination=path)
 
-    if 'drive.google.com' not in url:
-        response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
-        return _process_response(response, root, filename)
-    else:
-        # google drive links get filename from google drive
-        filename = None
-
-    logging.info('Downloading from Google Drive; may take a few minutes')
-    confirm_token = None
-    session = requests.Session()
-    response = session.get(url, stream=True)
-    for k, v in response.cookies.items():
-        if k.startswith("download_warning"):
-            confirm_token = v
-    if confirm_token is None:
-        if "Quota exceeded" in str(response.content):
-            raise RuntimeError(
-                "Google drive link {} is currently unavailable, because the quota was exceeded.".format(
-                    url
-                ))
-        else:
-            raise RuntimeError("Internal error: confirm_token was not found in Google drive link.")
-
-    if confirm_token:
-        url = url + "&confirm=" + confirm_token
-        response = session.get(url, stream=True)
-
-    return _process_response(response, root, filename)
+    logging.info('File {} downloaded.'.format(path))
+
+    # validate
+    if hash_value:
+        _check_hash(path, hash_value, hash_type)
+
+    # all good
+    return path
 
 
 def unicode_csv_reader(unicode_csv_data, **kwargs):
@@ -263,31 +241,3 @@ def extract_archive(from_path, to_path=None, overwrite=False):
     else:
         raise NotImplementedError(
             "We currently only support tar.gz, .tgz, .gz and zip achives.")
-
-
-def validate_file(file_obj, hash_value, hash_type="sha256"):
-    """Validate a given file object with its hash.
-
-    Args:
-        file_obj: File object to read from.
-        hash_value (str): Hash for url.
-        hash_type (str, optional): Hash type, among "sha256" and "md5" (Default: ``"sha256"``).
-    Returns:
-        bool: return True if its a valid file, else False.
-
-    """
-
-    if hash_type == "sha256":
-        hash_func = hashlib.sha256()
-    elif hash_type == "md5":
-        hash_func = hashlib.md5()
-    else:
-        raise ValueError
-
-    while True:
-        # Read by chunk to avoid filling memory
-        chunk = file_obj.read(1024 ** 2)
-        if not chunk:
-            break
-        hash_func.update(chunk)
-    return hash_func.hexdigest() == hash_value
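
A minimal usage sketch (not part of the patch itself), assuming the example URL from the download_from_url docstring above: the public API is unchanged, but downloads now go through _DATASET_DOWNLOAD_MANAGER, which dispatches plain HTTP(S) URLs to iopath's HTTPURLHandler and drive.google.com links to GoogleDrivePathHandler before moving the cached file to the requested destination.

from torchtext.utils import download_from_url, extract_archive

# Plain HTTP(S) URL: fetched and cached by iopath, then moved under root.
url = 'http://www.quest.dcs.shef.ac.uk/wmt16_files_mmt/validation.tar.gz'
archive = download_from_url(url, root='.data')  # returns the local path under .data/
files = extract_archive(archive)                # unpacks next to the archive

# A drive.google.com URL takes the GoogleDrivePathHandler branch instead;
# pass hash_value/hash_type to have the downloaded file verified afterwards.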