158 changes: 117 additions & 41 deletions torchaudio/datasets/utils.py
@@ -1,6 +1,5 @@
import csv
import errno
import gzip
import hashlib
import logging
import os
@@ -12,7 +11,7 @@

import six
import torch
import torchaudio
from six.moves import urllib
from torch.utils.data import Dataset
from torch.utils.model_zoo import tqdm

@@ -53,18 +52,6 @@ def unicode_csv_reader(unicode_csv_data, **kwargs):
        yield line


def gen_bar_updater():
    pbar = tqdm(total=None)

    def bar_update(count, block_size, total_size):
        if pbar.total is None and total_size:
            pbar.total = total_size
        progress_bytes = count * block_size
        pbar.update(progress_bytes - pbar.n)

    return bar_update


def makedir_exist_ok(dirpath):
"""
Python2 support for os.makedirs(.., exist_ok=True)
@@ -78,41 +65,130 @@ def makedir_exist_ok(dirpath):
            raise


def download_url(url, root, filename=None, md5=None):
    """Download a file from a url and place it in root.

    Args:
        url (str): URL to download file from
        root (str): Directory to place downloaded file in
        filename (str, optional): Name to save the file under. If None, use the basename of the URL
        md5 (str, optional): MD5 checksum of the download. If None, do not check
    """
    from six.moves import urllib

    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    fpath = os.path.join(root, filename)

    makedir_exist_ok(root)

    # downloads file
    if os.path.isfile(fpath):
        print("Using downloaded file: " + fpath)
    else:
        try:
            print("Downloading " + url + " to " + fpath)
            urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
        except (urllib.error.URLError, IOError) as e:
            if url[:5] == "https":
                url = url.replace("https:", "http:")
                print(
                    "Failed download. Trying https -> http instead."
                    " Downloading " + url + " to " + fpath
                )
                urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
            else:
                raise e


def stream_url(url, start_byte=None, block_size=32 * 1024, progress_bar=True):
    """Stream url by chunk

    Args:
        url (str): Url.
        start_byte (Optional[int]): Start streaming at that point.
        block_size (int): Size of chunks to stream.
        progress_bar (bool): Display a progress bar.
    """

    # If we already have the whole file, there is no need to download it again
    req = urllib.request.Request(url, method="HEAD")
    url_size = int(urllib.request.urlopen(req).info().get("Content-Length", -1))
    if url_size == start_byte:
        return

    req = urllib.request.Request(url)
    if start_byte:
        req.headers["Range"] = "bytes={}-".format(start_byte)

    with urllib.request.urlopen(req) as upointer, tqdm(
        unit="B",
        unit_scale=True,
        unit_divisor=1024,
        total=url_size,
        disable=not progress_bar,
    ) as pbar:

        num_bytes = 0
        while True:
            chunk = upointer.read(block_size)
            if not chunk:
                break
            yield chunk
            num_bytes += len(chunk)
            pbar.update(len(chunk))


Review thread on stream_url:

Contributor: This could be an excellent target for a multithreaded buffer function :D

vincentqb (Contributor, Author), Oct 30, 2019: Indeed, see here :D
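The multithreaded buffer suggested in this thread is not part of the diff. The sketch below is one hypothetical way to do it, assuming Python 3 module names (queue, threading): a background thread prefetches chunks from a generator such as stream_url into a bounded queue, so network reads can overlap with whatever the consumer does with each chunk.

import queue
import threading


def buffered_stream(chunk_iterator, max_chunks=8):
    # Hypothetical helper, not part of torchaudio. A background thread
    # pulls chunks from chunk_iterator into a bounded queue; the bound
    # caps how far the producer runs ahead, limiting memory use.
    buffer = queue.Queue(maxsize=max_chunks)
    sentinel = object()  # marks the end of the stream

    def producer():
        for chunk in chunk_iterator:
            buffer.put(chunk)
        buffer.put(sentinel)

    threading.Thread(target=producer, daemon=True).start()

    while True:
        chunk = buffer.get()
        if chunk is sentinel:
            return
        yield chunk

A caller would wrap the generator, e.g. for chunk in buffered_stream(stream_url(url)), so the next chunk downloads while the current one is written to disk.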


def download_url(
    url,
    download_folder,
    filename=None,
    hash_value=None,
    hash_type="sha256",
    progress_bar=True,
    resume=False,
):
    """Download file to disk.

    Args:
        url (str): Url.
        download_folder (str): Folder to download file.
        filename (str): Name of downloaded file. If None, it is inferred from the url.
        hash_value (str): Hash for url.
        hash_type (str): Hash type, among "sha256" and "md5".
        progress_bar (bool): Display a progress bar.
        resume (bool): Enable resuming download.
    """

    req = urllib.request.Request(url, method="HEAD")
    req_info = urllib.request.urlopen(req).info()

    # Detect filename
    filename = filename or req_info.get_filename() or os.path.basename(url)
    filepath = os.path.join(download_folder, filename)

    if resume and os.path.exists(filepath):
        mode = "ab"
        local_size = os.path.getsize(filepath)
    elif not resume and os.path.exists(filepath):
        raise RuntimeError(
            "{} already exists. Delete the file manually and retry.".format(filepath)
        )
    else:
        mode = "wb"
        local_size = None

    if hash_value and local_size == int(req_info.get("Content-Length", -1)):
        with open(filepath, "rb") as file_obj:
            if validate_file(file_obj, hash_value, hash_type):
                return
        raise RuntimeError(
            "The hash of {} does not match. Delete the file manually and retry.".format(
                filepath
            )
        )

    with open(filepath, mode) as fpointer:
        for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
            fpointer.write(chunk)

    with open(filepath, "rb") as file_obj:
        if hash_value and not validate_file(file_obj, hash_value, hash_type):
            raise RuntimeError(
                "The hash of {} does not match. Delete the file manually and retry.".format(
                    filepath
                )
            )
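A minimal usage sketch of the new download_url; the URL, folder, and checksum below are placeholders rather than a real dataset:

# Placeholder values for illustration only.
url = "https://example.com/dataset.tar.gz"
checksum = "0000000000000000000000000000000000000000000000000000000000000000"

# Unlike the old version, this download_url does not create the folder.
os.makedirs("./data", exist_ok=True)

# Fresh download: raises if the file already exists and resume=False.
download_url(url, "./data", hash_value=checksum, hash_type="sha256")

# After an interruption, resume=True reopens the file in append mode and
# asks stream_url to restart at the current local size via a Range header.
download_url(url, "./data", hash_value=checksum, hash_type="sha256", resume=True)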


def validate_file(file_obj, hash_value, hash_type="sha256"):
    """Validate a given file object with its hash.

    Args:
        file_obj: File object to read from.
        hash_value (str): Hash for url.
        hash_type (str): Hash type, among "sha256" and "md5".
    """

    if hash_type == "sha256":
        hash_func = hashlib.sha256()
    elif hash_type == "md5":
        hash_func = hashlib.md5()
    else:
        raise ValueError

    while True:
        # Read by chunk to avoid filling memory
        chunk = f.read(1024 ** 2)
        if not chunk:
            break
        hash_func.update(chunk)

    return hash_func.hexdigest() == hash_value


Review thread on validate_file:

Contributor: Does this function actually consume file_obj?

vincentqb (Contributor, Author): This consumes an object with read(chunk_size) signature, see line 147. Is that what you meant?

Contributor: Yes, but should it be file_obj.read rather than f.read?

vincentqb (Contributor, Author), Nov 25, 2019: Thanks for pointing out! Opened #352.
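Setting aside the f.read naming issue tracked in #352 (as written, the loop needs that fix before the helper can run), a minimal sketch of the intended call, with a placeholder standing in for a dataset's published checksum:

# Placeholder: a real expected hash would come from the dataset's
# published checksums (this value is the sha256 of the empty string).
expected_hash = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"

with open("archive.tar.gz", "rb") as file_obj:
    if not validate_file(file_obj, expected_hash, hash_type="sha256"):
        raise RuntimeError("Hash mismatch. Delete the file manually and retry.")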


def extract_archive(from_path, to_path=None, overwrite=False):