Skip to content

Commit 236b587

Browse files
committed
Resume interrupted downloads; validate with MD5 or SHA256.
1 parent 5023bd2 commit 236b587

File tree

2 files changed

+103
-39
lines changed

2 files changed

+103
-39
lines changed

requirements.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,5 @@ pytest
1818

1919
# Testing only Py3 compat
2020
backports.tempfile
21+
22+
requests

torchaudio/datasets/utils.py

Lines changed: 101 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
import tarfile
99
import zipfile
1010

11+
import requests
1112
import six
1213
import torch
1314
import torchaudio
15+
from six.moves import urllib
1416
from torch.utils.data import Dataset
1517
from torch.utils.model_zoo import tqdm
1618

@@ -51,18 +53,6 @@ def unicode_csv_reader(unicode_csv_data, **kwargs):
5153
yield line
5254

5355

54-
def gen_bar_updater():
55-
pbar = tqdm(total=None)
56-
57-
def bar_update(count, block_size, total_size):
58-
if pbar.total is None and total_size:
59-
pbar.total = total_size
60-
progress_bytes = count * block_size
61-
pbar.update(progress_bytes - pbar.n)
62-
63-
return bar_update
64-
65-
6656
def makedir_exist_ok(dirpath):
6757
"""
6858
Python2 support for os.makedirs(.., exist_ok=True)
@@ -76,41 +66,113 @@ def makedir_exist_ok(dirpath):
7666
raise
7767

7868

79-
def download_url(url, root, filename=None, md5=None):
80-
"""Download a file from a url and place it in root.
69+
def download_url_resume(url, download_folder, resume_byte_pos=None):
70+
"""Download url to disk with possible resumption.
8171
8272
Args:
83-
url (str): URL to download file from
84-
root (str): Directory to place downloaded file in
85-
filename (str, optional): Name to save the file under. If None, use the basename of the URL
86-
md5 (str, optional): MD5 checksum of the download. If None, do not check
73+
url (str): Url.
74+
download_folder (str): Folder to download file.
75+
resume_byte_pos (int): Position of byte from where to resume the download.
8776
"""
88-
from six.moves import urllib
77+
# Get size of file
78+
r = requests.head(url)
79+
file_size = int(r.headers.get("content-length", 0))
8980

90-
root = os.path.expanduser(root)
91-
if not filename:
92-
filename = os.path.basename(url)
93-
fpath = os.path.join(root, filename)
81+
# Append information to resume download at specific byte position to header
82+
resume_header = (
83+
{"Range": "bytes={}-".format(resume_byte_pos)} if resume_byte_pos else None
84+
)
9485

95-
makedir_exist_ok(root)
86+
# Establish connection
87+
r = requests.get(url, stream=True, headers=resume_header)
9688

97-
# downloads file
98-
if os.path.isfile(fpath):
99-
print("Using downloaded file: " + fpath)
100-
else:
101-
try:
102-
print("Downloading " + url + " to " + fpath)
103-
urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
104-
except (urllib.error.URLError, IOError) as e:
105-
if url[:5] == "https":
106-
url = url.replace("https:", "http:")
89+
# Set configuration
90+
n_block = 32
91+
block_size = 1024
92+
initial_pos = resume_byte_pos if resume_byte_pos else 0
93+
mode = "ab" if resume_byte_pos else "wb"
94+
95+
filename = os.path.basename(url)
96+
filepath = os.path.join(download_folder, os.path.basename(url))
97+
98+
with open(filepath, mode) as f:
99+
with tqdm(
100+
unit="B", unit_scale=True, unit_divisor=1024, total=file_size
101+
) as pbar:
102+
for chunk in r.iter_content(n_block * block_size):
103+
f.write(chunk)
104+
pbar.update(len(chunk))
105+
106+
107+
def download_url(url, download_folder, hash_value=None, hash_type="sha256"):
108+
"""Execute the correct download operation.
109+
Depending on the size of the file online and offline, resume the
110+
download if the local (offline) file is smaller than the remote (online) one.
111+
112+
Args:
113+
url (str): Url.
114+
download_folder (str): Folder to download file.
115+
hash_value (str): Hash for url.
116+
hash_type (str): Hash type.
117+
"""
118+
# Establish connection to header of file
119+
r = requests.head(url)
120+
121+
# Get filesize of online and offline file
122+
file_size_online = int(r.headers.get("content-length", 0))
123+
filepath = os.path.join(download_folder, os.path.basename(url))
124+
125+
if os.path.exists(filepath):
126+
file_size_offline = os.path.getsize(filepath)
127+
128+
if file_size_online != file_size_offline:
129+
# Resume download
130+
print("File {} is incomplete. Resume download.".format(filepath))
131+
download_url_resume(url, download_folder, file_size_offline)
132+
elif hash_value:
133+
if validate_download_url(url, download_folder, hash_value, hash_type):
134+
print("File {} is validated. Skip download.".format(filepath))
135+
else:
107136
print(
108-
"Failed download. Trying https -> http instead."
109-
" Downloading " + url + " to " + fpath
137+
"File {} is corrupt. Delete it manually and retry.".format(filepath)
110138
)
111-
urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
112-
else:
113-
raise e
139+
else:
140+
# Skip download
141+
print("File {} is complete. Skip download.".format(filepath))
142+
else:
143+
# Start download
144+
print("File {} has not been downloaded. Start download.".format(filepath))
145+
download_url_resume(url, download_folder)
146+
147+
148+
def validate_download_url(url, download_folder, hash_value, hash_type="sha256"):
149+
"""Validate a given file with its hash.
150+
The downloaded file is hashed and compared to a pre-registered
151+
hash value to validate the download procedure.
152+
153+
Args:
154+
url (str): Url.
155+
download_folder (str): Folder to download file.
156+
hash_value (str): Hash for url.
157+
hash_type (str): Hash type.
158+
"""
159+
filepath = os.path.join(download_folder, os.path.basename(url))
160+
161+
if hash_type == "sha256":
162+
sha = hashlib.sha256()
163+
elif hash_type == "md5":
164+
sha = hashlib.md5()
165+
else:
166+
raise ValueError
167+
168+
with open(filepath, "rb") as f:
169+
while True:
170+
chunk = f.read(1000 * 1000) # 1MB so that memory is not exhausted
171+
if not chunk:
172+
break
173+
sha.update(chunk)
174+
175+
return sha.hexdigest() == hash_value
114176

115177

116178
def extract_archive(from_path, to_path=None, overwrite=False):

0 commit comments

Comments
 (0)