From afb6a92f44a2cd237aeb728211a911b30e3b6b3a Mon Sep 17 00:00:00 2001 From: Gael Le Lan Date: Mon, 3 Apr 2023 08:45:24 -0700 Subject: [PATCH 1/2] Add file_lock when loading SentencePieceModel Summary: SentencePieceModel loading can cause a RuntimeError when concurrent threads try to load/download it (e.g. when using T5 tokenizer in a DDP model training). Adding a file lock ensures the first thread to acquire the lock will actually download the model and the other ones will just use the existing path (which will not contain a partially downloaded model). This diff was inspired by D42686913 and reverts D44566854 behavior (there is no need to overwrite anymore). It should also disable unit test flakiness such as https://www.internalfb.com/intern/test/281475067136403?ref_report_id=0 and solve https://fb.workplace.com/groups/pytorchtext/permalink/920234369294862/. Reviewed By: joecummings Differential Revision: D44604474 fbshipit-source-id: 1c117fb6d1e72cce31cbf30bf72d513ad535b0d4 --- torchtext/utils.py | 56 +++++++++++++++++++++++++--------------------- 1 file changed, 30 insertions(+), 26 deletions(-) diff --git a/torchtext/utils.py b/torchtext/utils.py index fb5dde7c32..229311d69e 100644 --- a/torchtext/utils.py +++ b/torchtext/utils.py @@ -10,6 +10,7 @@ from ._download_hooks import _DATASET_DOWNLOAD_MANAGER +from iopath.common.file_io import file_lock logger = logging.getLogger(__name__) @@ -96,32 +97,35 @@ def download_from_url(url, path=None, root=".data", overwrite=False, hash_value= path = os.path.abspath(path) root, filename = os.path.split(os.path.abspath(path)) - # skip download if path exists and overwrite is not True - if os.path.exists(path): - logger.info("File %s already exists." % path) - if not overwrite: - if hash_value: - _check_hash(path, hash_value, hash_type) - return path - - # make root dir if does not exist - if not os.path.exists(root): - try: - os.makedirs(root) - except OSError: - raise OSError("Can't create the download directory {}.".format(root)) - - # download data and move to path - _DATASET_DOWNLOAD_MANAGER.get_local_path(url, destination=path) - - logger.info("File {} downloaded.".format(path)) - - # validate - if hash_value: - _check_hash(path, hash_value, hash_type) - - # all good - return path + # In a concurrent setting, adding a file lock ensures the first thread to acquire will actually download the model + # and the other ones will just use the existing path (which will not contain a partially downloaded model). + with file_lock(path): + # skip download if path exists and overwrite is not True + if os.path.exists(path): + logger.info("File %s already exists." % path) + if not overwrite: + if hash_value: + _check_hash(path, hash_value, hash_type) + return path + + # make root dir if does not exist + if not os.path.exists(root): + try: + os.makedirs(root) + except OSError as exc: + raise OSError("Can't create the download directory {}.".format(root)) from exc + + # download data and move to path + _DATASET_DOWNLOAD_MANAGER.get_local_path(url, destination=path) + + logger.info("File {} downloaded.".format(path)) + + # validate + if hash_value: + _check_hash(path, hash_value, hash_type) + + # all good + return path def extract_archive(from_path, to_path=None, overwrite=False): From c4ac4057961cdbfc8ed0b655afcf5130c17c7398 Mon Sep 17 00:00:00 2001 From: Joe Cummings Date: Tue, 4 Apr 2023 22:38:01 -0400 Subject: [PATCH 2/2] Formatting fixes --- torchtext/utils.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchtext/utils.py b/torchtext/utils.py index 229311d69e..921673a0f0 100644 --- a/torchtext/utils.py +++ b/torchtext/utils.py @@ -6,12 +6,11 @@ import zipfile import torch +from iopath.common.file_io import file_lock from torchtext import _CACHE_DIR from ._download_hooks import _DATASET_DOWNLOAD_MANAGER -from iopath.common.file_io import file_lock - logger = logging.getLogger(__name__)