diff --git a/CHANGELOG.md b/CHANGELOG.md
index a95b7183fac4a..8ab19b6e16cda 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -354,6 +354,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed double evaluation bug with fault-tolerance enabled where the second call was completely skipped ([#11119](https://github.com/PyTorchLightning/pytorch-lightning/pull/11119))
 
 
+- Fixed `LightningCLI` race condition while saving the config ([#11199](https://github.com/PyTorchLightning/pytorch-lightning/pull/11199))
+
+
 - Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))
 
 
diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py
index f25920a40bf83..51fb22a301035 100644
--- a/pytorch_lightning/utilities/cli.py
+++ b/pytorch_lightning/utilities/cli.py
@@ -404,21 +404,30 @@ def __init__(
     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
         # save the config in `setup` because (1) we want it to save regardless of the trainer function run
         # and we want to save before processes are spawned
-        log_dir = trainer.log_dir
+        log_dir = trainer.log_dir  # this broadcasts the directory
         assert log_dir is not None
         config_path = os.path.join(log_dir, self.config_filename)
-        if not self.overwrite and os.path.isfile(config_path):
-            raise RuntimeError(
-                f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting"
-                " results of a previous run. You can delete the previous config file,"
-                " set `LightningCLI(save_config_callback=None)` to disable config saving,"
-                " or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file."
-            )
+        fs = get_filesystem(log_dir)
+
+        if not self.overwrite:
+            # check if the file exists on rank 0
+            file_exists = fs.isfile(config_path) if trainer.is_global_zero else False
+            # broadcast whether to fail to all ranks
+            file_exists = trainer.strategy.broadcast(file_exists)
+            if file_exists:
+                raise RuntimeError(
+                    f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting"
+                    " results of a previous run. You can delete the previous config file,"
+                    " set `LightningCLI(save_config_callback=None)` to disable config saving,"
+                    " or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file."
+                )
+
+        # save the file on rank 0
         if trainer.is_global_zero:
             # save only on rank zero to avoid race conditions on DDP.
             # the `log_dir` needs to be created as we rely on the logger to do it usually
             # but it hasn't logged anything at this point
-            get_filesystem(log_dir).makedirs(log_dir, exist_ok=True)
+            fs.makedirs(log_dir, exist_ok=True)
             self.parser.save(
                 self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
             )
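
The fix lets rank zero alone check the filesystem and then broadcasts the verdict, so every process either raises or proceeds together instead of racing on the same config file. Below is a minimal sketch of that pattern outside Lightning, assuming a `torch.distributed` process group is already initialized; the helper name `config_already_exists` is hypothetical and not part of this PR.

import os

import torch.distributed as dist


def config_already_exists(config_path: str) -> bool:
    """Rank 0 checks the filesystem; the result is broadcast so all ranks agree."""
    exists = os.path.isfile(config_path) if dist.get_rank() == 0 else False
    payload = [exists]
    # broadcast_object_list overwrites the list in place with rank 0's value on every other rank
    dist.broadcast_object_list(payload, src=0)
    return payload[0]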