diff --git a/torchaudio/datasets/utils.py b/torchaudio/datasets/utils.py
index f3daebcd12..545a12bc41 100644
--- a/torchaudio/datasets/utils.py
+++ b/torchaudio/datasets/utils.py
@@ -1,6 +1,5 @@
 import csv
 import errno
-import gzip
 import hashlib
 import logging
 import os
@@ -12,7 +11,7 @@
 import six
 import torch
-import torchaudio
+from six.moves import urllib
 from torch.utils.data import Dataset
 from torch.utils.model_zoo import tqdm
 
 
@@ -53,18 +52,6 @@ def unicode_csv_reader(unicode_csv_data, **kwargs):
         yield line
 
 
-def gen_bar_updater():
-    pbar = tqdm(total=None)
-
-    def bar_update(count, block_size, total_size):
-        if pbar.total is None and total_size:
-            pbar.total = total_size
-        progress_bytes = count * block_size
-        pbar.update(progress_bytes - pbar.n)
-
-    return bar_update
-
-
 def makedir_exist_ok(dirpath):
     """
     Python2 support for os.makedirs(.., exist_ok=True)
@@ -78,41 +65,130 @@ def makedir_exist_ok(dirpath):
         raise
 
 
-def download_url(url, root, filename=None, md5=None):
-    """Download a file from a url and place it in root.
+def stream_url(url, start_byte=None, block_size=32 * 1024, progress_bar=True):
+    """Stream url by chunk
 
     Args:
-        url (str): URL to download file from
-        root (str): Directory to place downloaded file in
-        filename (str, optional): Name to save the file under. If None, use the basename of the URL
-        md5 (str, optional): MD5 checksum of the download. If None, do not check
+        url (str): Url.
+        start_byte (Optional[int]): Start streaming at that point.
+        block_size (int): Size of chunks to stream.
+        progress_bar (bool): Display a progress bar.
     """
-    from six.moves import urllib
-    root = os.path.expanduser(root)
-    if not filename:
-        filename = os.path.basename(url)
-    fpath = os.path.join(root, filename)
 
+    # If we already have the whole file, there is no need to download it again
+    req = urllib.request.Request(url, method="HEAD")
+    url_size = int(urllib.request.urlopen(req).info().get("Content-Length", -1))
+    if url_size == start_byte:
+        return
+
+    req = urllib.request.Request(url)
+    if start_byte:
+        req.headers["Range"] = "bytes={}-".format(start_byte)
+
+    with urllib.request.urlopen(req) as upointer, tqdm(
+        unit="B",
+        unit_scale=True,
+        unit_divisor=1024,
+        total=url_size,
+        disable=not progress_bar,
+    ) as pbar:
+
+        num_bytes = 0
+        while True:
+            chunk = upointer.read(block_size)
+            if not chunk:
+                break
+            yield chunk
+            num_bytes += len(chunk)
+            pbar.update(len(chunk))
+
+
+def download_url(
+    url,
+    download_folder,
+    filename=None,
+    hash_value=None,
+    hash_type="sha256",
+    progress_bar=True,
+    resume=False,
+):
+    """Download file to disk.
 
-    makedir_exist_ok(root)
+    Args:
+        url (str): Url.
+        download_folder (str): Folder to download file.
+        filename (str): Name of downloaded file. If None, it is inferred from the url.
+        hash_value (str): Hash for url.
+        hash_type (str): Hash type, among "sha256" and "md5".
+        progress_bar (bool): Display a progress bar.
+        resume (bool): Enable resuming download.
+    """
+
+    req = urllib.request.Request(url, method="HEAD")
+    req_info = urllib.request.urlopen(req).info()
+
+    # Detect filename
+    filename = filename or req_info.get_filename() or os.path.basename(url)
+    filepath = os.path.join(download_folder, filename)
 
-    # downloads file
-    if os.path.isfile(fpath):
-        print("Using downloaded file: " + fpath)
+    if resume and os.path.exists(filepath):
+        mode = "ab"
+        local_size = os.path.getsize(filepath)
+    elif not resume and os.path.exists(filepath):
+        raise RuntimeError(
+            "{} already exists. Delete the file manually and retry.".format(filepath)
+        )
     else:
-        try:
-            print("Downloading " + url + " to " + fpath)
-            urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
-        except (urllib.error.URLError, IOError) as e:
-            if url[:5] == "https":
-                url = url.replace("https:", "http:")
-                print(
-                    "Failed download. Trying https -> http instead."
-                    " Downloading " + url + " to " + fpath
+        mode = "wb"
+        local_size = None
+
+    if hash_value and local_size == int(req_info.get("Content-Length", -1)):
+        with open(filepath, "rb") as file_obj:
+            if validate_file(file_obj, hash_value, hash_type):
+                return
+        raise RuntimeError(
+            "The hash of {} does not match. Delete the file manually and retry.".format(
+                filepath
+            )
+        )
+
+    with open(filepath, mode) as fpointer:
+        for chunk in stream_url(url, start_byte=local_size, progress_bar=progress_bar):
+            fpointer.write(chunk)
+
+    with open(filepath, "rb") as file_obj:
+        if hash_value and not validate_file(file_obj, hash_value, hash_type):
+            raise RuntimeError(
+                "The hash of {} does not match. Delete the file manually and retry.".format(
+                    filepath
                 )
-                urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updater())
-            else:
-                raise e
+            )
+
+
+def validate_file(file_obj, hash_value, hash_type="sha256"):
+    """Validate a given file object with its hash.
+
+    Args:
+        file_obj: File object to read from.
+        hash_value (str): Hash for url.
+        hash_type (str): Hash type, among "sha256" and "md5".
+    """
+
+    if hash_type == "sha256":
+        hash_func = hashlib.sha256()
+    elif hash_type == "md5":
+        hash_func = hashlib.md5()
+    else:
+        raise ValueError
+
+    while True:
+        # Read by chunk to avoid filling memory
+        chunk = file_obj.read(1024 ** 2)
+        if not chunk:
+            break
+        hash_func.update(chunk)
+
+    return hash_func.hexdigest() == hash_value
 
 
 def extract_archive(from_path, to_path=None, overwrite=False):
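For context when reviewing, here is a minimal usage sketch of the new `download_url` API; the URL, folder, and checksum below are placeholder values, not part of this patch:

```python
from torchaudio.datasets.utils import download_url

# Placeholder URL and sha256 digest, for illustration only.
download_url(
    "https://example.com/dataset.tar.gz",
    download_folder="./data",
    hash_value="e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
    hash_type="sha256",
    resume=True,  # append to an existing partial file instead of raising
)
```

Note the behavior change: with `resume=False` (the default), an existing file at the target path now raises a `RuntimeError` instead of being silently reused as before.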
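`stream_url` and `validate_file` are also usable on their own; a sketch, again with placeholder names and values:

```python
import os

from torchaudio.datasets.utils import stream_url, validate_file

# Resume a partial download: append chunks starting at the current local size.
# "partial.bin" and the URL are placeholders.
local_size = os.path.getsize("partial.bin")
with open("partial.bin", "ab") as fpointer:
    for chunk in stream_url("https://example.com/file.bin", start_byte=local_size):
        fpointer.write(chunk)

# Validate the completed file against a known digest (placeholder value).
with open("partial.bin", "rb") as file_obj:
    if not validate_file(file_obj, hash_value="0123abcd...", hash_type="md5"):
        raise RuntimeError("Checksum mismatch")
```

Since `stream_url` is a generator, it yields nothing when the server reports a `Content-Length` equal to `start_byte`, which is what lets `download_url` skip re-downloading an already complete file.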