diff --git a/torchtext/utils.py b/torchtext/utils.py
index fb5dde7c32..921673a0f0 100644
--- a/torchtext/utils.py
+++ b/torchtext/utils.py
@@ -6,11 +6,11 @@
 import zipfile
 
 import torch
+from iopath.common.file_io import file_lock
 from torchtext import _CACHE_DIR
 
 from ._download_hooks import _DATASET_DOWNLOAD_MANAGER
 
-
 logger = logging.getLogger(__name__)
 
 
@@ -96,32 +96,35 @@ def download_from_url(url, path=None, root=".data", overwrite=False, hash_value=
     path = os.path.abspath(path)
     root, filename = os.path.split(os.path.abspath(path))
 
-    # skip download if path exists and overwrite is not True
-    if os.path.exists(path):
-        logger.info("File %s already exists." % path)
-        if not overwrite:
-            if hash_value:
-                _check_hash(path, hash_value, hash_type)
-            return path
-
-    # make root dir if does not exist
-    if not os.path.exists(root):
-        try:
-            os.makedirs(root)
-        except OSError:
-            raise OSError("Can't create the download directory {}.".format(root))
-
-    # download data and move to path
-    _DATASET_DOWNLOAD_MANAGER.get_local_path(url, destination=path)
-
-    logger.info("File {} downloaded.".format(path))
-
-    # validate
-    if hash_value:
-        _check_hash(path, hash_value, hash_type)
-
-    # all good
-    return path
+    # In a concurrent setting, adding a file lock ensures the first thread to acquire will actually download the model
+    # and the other ones will just use the existing path (which will not contain a partially downloaded model).
+    with file_lock(path):
+        # skip download if path exists and overwrite is not True
+        if os.path.exists(path):
+            logger.info("File %s already exists." % path)
+            if not overwrite:
+                if hash_value:
+                    _check_hash(path, hash_value, hash_type)
+                return path
+
+        # make root dir if does not exist
+        if not os.path.exists(root):
+            try:
+                os.makedirs(root)
+            except OSError as exc:
+                raise OSError("Can't create the download directory {}.".format(root)) from exc
+
+        # download data and move to path
+        _DATASET_DOWNLOAD_MANAGER.get_local_path(url, destination=path)
+
+        logger.info("File {} downloaded.".format(path))
+
+        # validate
+        if hash_value:
+            _check_hash(path, hash_value, hash_type)
+
+        # all good
+        return path
 
 
 def extract_archive(from_path, to_path=None, overwrite=False):