Skip to content

Commit 7d41547

Browse files
authored
add custom user agent for download_url (#3498)
* add custom user agent for download_url * fix progress bar * lint * [test] use repo instead of nightly for download tests
1 parent 89edfaa commit 7d41547

File tree

3 files changed

+28
-17
lines changed

3 files changed

+28
-17
lines changed

.github/workflows/tests-schedule.yml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@ jobs:
2626
- name: Checkout repository
2727
uses: actions/checkout@v2
2828

29-
- name: Install PyTorch from the nightlies
30-
run: |
31-
pip install numpy
32-
pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
29+
- name: Install torch nightly build
30+
run: pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
31+
32+
- name: Install torchvision
33+
run: pip install -e .
3334

3435
- name: Install all optional dataset requirements
3536
run: pip install scipy pandas pycocotools lmdb requests

test/test_datasets_download.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import pytest
1515

1616
from torchvision import datasets
17-
from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive
17+
from torchvision.datasets.utils import download_url, check_integrity, download_file_from_google_drive, USER_AGENT
1818

1919
from common_utils import get_tmp_dir
2020
from fakedata_generation import places365_root
@@ -150,7 +150,7 @@ def assert_server_response_ok():
150150

151151

152152
def assert_url_is_accessible(url, timeout=5.0):
153-
request = Request(url, headers=dict(method="HEAD"))
153+
request = Request(url, headers={"method": "HEAD", "User-Agent": USER_AGENT})
154154
with assert_server_response_ok():
155155
urlopen(request, timeout=timeout)
156156

@@ -160,7 +160,8 @@ def assert_file_downloads_correctly(url, md5, timeout=5.0):
160160
file = path.join(root, path.basename(url))
161161
with assert_server_response_ok():
162162
with open(file, "wb") as fh:
163-
response = urlopen(url, timeout=timeout)
163+
request = Request(url, headers={"User-Agent": USER_AGENT})
164+
response = urlopen(request, timeout=timeout)
164165
fh.write(response.read())
165166

166167
assert check_integrity(file, md5=md5), "The MD5 checksums mismatch"

torchvision/datasets/utils.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,28 @@
77
from typing import Any, Callable, List, Iterable, Optional, TypeVar
88
from urllib.parse import urlparse
99
import zipfile
10+
import urllib
11+
import urllib.request
12+
import urllib.error
1013

1114
import torch
1215
from torch.utils.model_zoo import tqdm
1316

1417

18+
USER_AGENT = "pytorch/vision"
19+
20+
21+
def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
22+
with open(filename, "wb") as fh:
23+
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response:
24+
with tqdm(total=response.length) as pbar:
25+
for chunk in iter(lambda: response.read(chunk_size), ""):
26+
if not chunk:
27+
break
28+
pbar.update(chunk_size)
29+
fh.write(chunk)
30+
31+
1532
def gen_bar_updater() -> Callable[[int, int, int], None]:
1633
pbar = tqdm(total=None)
1734

@@ -83,8 +100,6 @@ def download_url(
83100
md5 (str, optional): MD5 checksum of the download. If None, do not check
84101
max_redirect_hops (int, optional): Maximum number of redirect hops allowed
85102
"""
86-
import urllib
87-
88103
root = os.path.expanduser(root)
89104
if not filename:
90105
filename = os.path.basename(url)
@@ -108,19 +123,13 @@ def download_url(
108123
# download the file
109124
try:
110125
print('Downloading ' + url + ' to ' + fpath)
111-
urllib.request.urlretrieve(
112-
url, fpath,
113-
reporthook=gen_bar_updater()
114-
)
126+
_urlretrieve(url, fpath)
115127
except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
116128
if url[:5] == 'https':
117129
url = url.replace('https:', 'http:')
118130
print('Failed download. Trying https -> http instead.'
119131
' Downloading ' + url + ' to ' + fpath)
120-
urllib.request.urlretrieve(
121-
url, fpath,
122-
reporthook=gen_bar_updater()
123-
)
132+
_urlretrieve(url, fpath)
124133
else:
125134
raise e
126135
# check integrity of downloaded file

0 commit comments

Comments
 (0)