CHANGELOG.md (3 additions, 0 deletions)
@@ -354,6 +354,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed double evaluation bug with fault-tolerance enabled where the second call was completely skipped ([#11119](https://github.com/PyTorchLightning/pytorch-lightning/pull/11119))


+- Fixed `LightningCLI` race condition while saving the config ([#11199](https://github.com/PyTorchLightning/pytorch-lightning/pull/11199))
+
+
 - Fixed an issue with the `TPUSpawnPlugin` handling the `XLA_USE_BF16` environment variable incorrectly ([#10990](https://github.com/PyTorchLightning/pytorch-lightning/pull/10990))


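For readers unfamiliar with the feature being fixed: `LightningCLI` writes the fully parsed config into the trainer's log directory through `SaveConfigCallback`, and the options named in the error message below control that behavior. A usage sketch, assuming a hypothetical `MyModel` subclass:

```python
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities.cli import LightningCLI


class MyModel(LightningModule):
    """Hypothetical minimal module, only for illustration."""


# default: save the parsed config to `trainer.log_dir`, refusing to overwrite
cli = LightningCLI(MyModel)

# disable config saving entirely:
#   LightningCLI(MyModel, save_config_callback=None)
# or allow overwriting a config left over from a previous run:
#   LightningCLI(MyModel, save_config_overwrite=True)
```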
pytorch_lightning/utilities/cli.py (18 additions, 9 deletions)
@@ -404,21 +404,30 @@ def __init__(
     def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
         # save the config in `setup` because (1) we want it to save regardless of the trainer function run
         # and we want to save before processes are spawned
-        log_dir = trainer.log_dir
+        log_dir = trainer.log_dir  # this broadcasts the directory
         assert log_dir is not None
         config_path = os.path.join(log_dir, self.config_filename)
-        if not self.overwrite and os.path.isfile(config_path):
-            raise RuntimeError(
-                f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting"
-                " results of a previous run. You can delete the previous config file,"
-                " set `LightningCLI(save_config_callback=None)` to disable config saving,"
-                " or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file."
-            )
+        fs = get_filesystem(log_dir)
+
+        if not self.overwrite:
+            # check if the file exists on rank 0
+            file_exists = fs.isfile(config_path) if trainer.is_global_zero else False
+            # broadcast whether to fail to all ranks
+            file_exists = trainer.strategy.broadcast(file_exists)
+            if file_exists:
+                raise RuntimeError(
+                    f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting"
+                    " results of a previous run. You can delete the previous config file,"
+                    " set `LightningCLI(save_config_callback=None)` to disable config saving,"
+                    " or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file."
+                )
+
+        # save the file on rank 0
         if trainer.is_global_zero:
             # save only on rank zero to avoid race conditions on DDP.
             # the `log_dir` needs to be created as we rely on the logger to do it usually
             # but it hasn't logged anything at this point
-            get_filesystem(log_dir).makedirs(log_dir, exist_ok=True)
+            fs.makedirs(log_dir, exist_ok=True)
             self.parser.save(
                 self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
             )
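For context, the new guard is an instance of a standard distributed idiom: exactly one rank touches the filesystem, and its answer is broadcast so that every rank takes the same branch. Below is a minimal sketch of that idiom written against plain `torch.distributed` rather than Lightning's `trainer.strategy.broadcast`; the helper name and the use of `broadcast_object_list` are illustrative assumptions, not the library's implementation.

```python
import os

import torch.distributed as dist


def file_exists_on_rank_zero(path: str) -> bool:
    """Stat a file on rank 0 only, then share the answer with all ranks."""
    is_rank_zero = not dist.is_initialized() or dist.get_rank() == 0
    # only rank 0 queries the filesystem, so ranks cannot race each other
    exists = os.path.isfile(path) if is_rank_zero else False
    if dist.is_initialized():
        # send rank 0's answer to every rank; all ranks then branch identically
        payload = [exists]
        dist.broadcast_object_list(payload, src=0)
        exists = payload[0]
    return exists
```

Broadcasting the boolean matters for two reasons: a non-zero rank could otherwise observe the file mid-write and raise spuriously, and if only some ranks raised, the collectives still running on the remaining ranks could deadlock. Caching `fs = get_filesystem(log_dir)` also keeps the existence check and the later `makedirs` call on the same (possibly remote) filesystem object.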