Skip to content
This repository was archived by the owner on Sep 10, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .circleci/unittest/linux/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ dependencies:
- dataclasses
- nltk
- requests
- iopath
- revtok
- pytest
- pytest-cov
Expand Down
2 changes: 0 additions & 2 deletions .circleci/unittest/windows/scripts/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@ channels:
dependencies:
- flake8>=3.7.9
- codecov
- pywin32
- pip
- pip:
- dataclasses
- nltk
- requests
- iopath
- revtok
- pytest
- pytest-cov
Expand Down
1 change: 0 additions & 1 deletion docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
sphinx==2.4.4
iopath
-e git+git://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
1 change: 0 additions & 1 deletion packaging/pkg_helpers.bash
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ setup_pip_pytorch_version() {
# You MUST have populated PYTORCH_VERSION_SUFFIX before hand.
setup_conda_pytorch_constraint() {
CONDA_CHANNEL_FLAGS=${CONDA_CHANNEL_FLAGS:-}
CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c iopath"
if [[ -z "$PYTORCH_VERSION" ]]; then
export CONDA_CHANNEL_FLAGS="${CONDA_CHANNEL_FLAGS} -c pytorch-nightly"
export PYTORCH_VERSION="$(conda search --json 'pytorch[channel=pytorch-nightly]' | python -c "import sys, json, re; print(re.sub(r'\\+.*$', '', json.load(sys.stdin)['pytorch'][-1]['version']))")"
Expand Down
1 change: 0 additions & 1 deletion packaging/torchtext/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ requirements:
run:
- python
- requests
- iopath
- tqdm
{{ environ.get('CONDA_PYTORCH_CONSTRAINT') }}

Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ tqdm

# Downloading data and other files
requests
iopath

# Optional NLP tools
nltk
Expand Down
134 changes: 10 additions & 124 deletions torchtext/_download_hooks.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,6 @@
from typing import List, Optional, Union, IO, Dict, Any
import requests
import os
import logging
import uuid
import re
import shutil
from tqdm import tqdm
from iopath.common.file_io import (
PathHandler,
PathManager,
get_cache_dir,
file_lock,
HTTPURLHandler,
)


def _stream_response(r, chunk_size=16 * 1024):
Expand Down Expand Up @@ -54,118 +42,16 @@ def _get_response_from_google_drive(url):
return response, filename


class GoogleDrivePathHandler(PathHandler):
    """
    iopath ``PathHandler`` that downloads Google Drive URLs and caches
    them on local disk.

    Once registered with a ``PathManager``, any path starting with
    ``https://drive.google.com`` is routed here: the resource is
    downloaded on first request and served from the cache afterwards.
    """

    # Cached filenames longer than this are truncated (with a uuid suffix)
    # to stay within common filesystem name-length limits.
    MAX_FILENAME_LEN = 250

    def __init__(self) -> None:
        # Maps remote URL -> local cached file path, per handler instance.
        self.cache_map: Dict[str, str] = {}

    def _get_supported_prefixes(self) -> List[str]:
        # Only Google Drive URLs are dispatched to this handler.
        return ["https://drive.google.com"]

    def _get_local_path(
        self,
        path: str,
        force: bool = False,
        cache_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """
        Download the remote resource from Google Drive and cache it locally.

        The resource is only downloaded if it was not previously requested
        (or if ``force`` is set, or the cached file disappeared).

        Args:
            path (str): a ``https://drive.google.com`` URL.
            force (bool): re-download even when a cached copy exists.
            cache_dir (Optional[str]): override for the iopath cache directory.

        Returns:
            str: path of the locally cached copy of the resource.
        """
        self._check_kwargs(kwargs)
        # Re-download when forced, never seen before, or the cached file
        # was removed from disk behind our back.
        if (
            force
            or path not in self.cache_map
            or not os.path.exists(self.cache_map[path])
        ):
            logger = logging.getLogger(__name__)
            dirname = get_cache_dir(cache_dir)

            response, filename = _get_response_from_google_drive(path)
            # Truncate overly long names; the uuid suffix keeps distinct
            # URLs from colliding on the truncated prefix.
            if len(filename) > self.MAX_FILENAME_LEN:
                filename = filename[:100] + "_" + uuid.uuid4().hex

            cached = os.path.join(dirname, filename)
            # file_lock serializes concurrent processes downloading the
            # same resource into the shared cache directory.
            with file_lock(cached):
                if not os.path.isfile(cached):
                    logger.info("Downloading {} ...".format(path))
                    with open(cached, 'wb') as f:
                        for data in _stream_response(response):
                            f.write(data)
            logger.info("URL {} cached in {}".format(path, cached))
            self.cache_map[path] = cached
        return self.cache_map[path]

    def _open(
        self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
    ) -> Union[IO[str], IO[bytes]]:
        """
        Open a google drive path. The resource is first downloaded and cached
        locally.

        Args:
            path (str): A URI supported by this PathHandler
            mode (str): Specifies the mode in which the file is opened. It defaults
                to 'r'. Only 'r' and 'rb' are supported.
            buffering (int): Not used for this PathHandler; must stay -1.

        Returns:
            file: a file-like object over the cached local copy.
        """
        self._check_kwargs(kwargs)
        assert mode in ("r", "rb"), "{} does not support open with {} mode".format(
            self.__class__.__name__, mode
        )
        assert (
            buffering == -1
        ), f"{self.__class__.__name__} does not support the `buffering` argument"
        # Reuse the caching download path, then hand back a plain file object.
        local_path = self._get_local_path(path, force=False)
        return open(local_path, mode)


class CombinedInternalPathhandler(PathHandler):
    # NOTE(review): "Pathhandler" casing is inconsistent with "PathHandler",
    # but the public name cannot be changed without breaking callers.
    """
    PathHandler that dispatches any ``http(s)://`` URL to an internal
    ``PathManager`` combining ``HTTPURLHandler`` (generic HTTP) and
    ``GoogleDrivePathHandler`` (Google Drive), then moves the cached
    download to a caller-supplied destination.
    """

    def __init__(self):
        # Internal PathManager: the registered handlers decide, by URL
        # prefix, whether a request is generic HTTP or Google Drive.
        path_manager = PathManager()
        path_manager.register_handler(HTTPURLHandler())
        path_manager.register_handler(GoogleDrivePathHandler())
        self.path_manager = path_manager

    def _get_supported_prefixes(self) -> List[str]:
        # Claims all http/https URLs; the inner manager refines the routing.
        return ["https://", "http://"]

    def _get_local_path(
        self,
        path: str,
        force: bool = False,
        cache_dir: Optional[str] = None,
        **kwargs: Any,
    ) -> str:
        """
        Download ``path`` via the internal PathManager and move the result
        to ``kwargs["destination"]``.

        Args:
            path (str): an ``http(s)://`` URL.
            force (bool): forwarded to the inner ``get_local_path``.
            cache_dir (Optional[str]): accepted but not forwarded here.

        Keyword Args:
            destination (str): required — final location for the file.
                A missing key raises ``KeyError``.

        Returns:
            str: ``destination`` (the file is moved out of the cache).
        """

        # Required keyword; raises KeyError when absent.
        destination = kwargs["destination"]

        local_path = self.path_manager.get_local_path(path, force)

        # Move (not copy): the cache no longer holds the file afterwards,
        # so a repeated request will re-download.
        shutil.move(local_path, destination)

        return destination
class DownloadManager:
    """
    Minimal download helper replacing the iopath-based path handlers.

    Streams a URL — plain HTTP(S) or Google Drive — straight into a
    caller-supplied destination file, with no caching layer.
    """

    def get_local_path(self, url, destination):
        """
        Download ``url`` and write its contents to ``destination``.

        Args:
            url (str): the resource to fetch. URLs containing
                ``drive.google.com`` go through the Google Drive
                confirm-token helper; everything else is a direct GET.
            destination (str): local file path to write the payload to.

        Raises:
            requests.HTTPError: if the server responds with an HTTP
                error status (4xx/5xx).
        """
        if 'drive.google.com' not in url:
            # Some hosts reject requests without a browser-like UA;
            # stream=True avoids buffering the whole payload in memory.
            response = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
        else:
            # Helper performs the Google Drive confirm-token dance and
            # returns an open streaming response (filename is unused here).
            response, _ = _get_response_from_google_drive(url)

        with response:  # ensure the streaming connection is released
            # Fail loudly on HTTP errors instead of silently writing an
            # error page into the destination file.
            response.raise_for_status()
            with open(destination, 'wb') as f:
                for chunk in _stream_response(response):
                    f.write(chunk)


# Module-level singleton used by the dataset download helpers.
_DATASET_DOWNLOAD_MANAGER = DownloadManager()