Commit e7057d5

Add should_rank_save_checkpoint property to Training Plugins (#7684)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent commit: 3e5d6e9

6 files changed: 26 additions & 22 deletions

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -41,6 +41,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `ddp_fully_sharded` support ([#7487](https://github.com/PyTorchLightning/pytorch-lightning/pull/7487))
 
 
+- Added `should_rank_save_checkpoint` property to Training Plugins ([#7684](https://github.com/PyTorchLightning/pytorch-lightning/pull/7684))
+
+
 ### Changed
 
 - Changed calling of `untoggle_optimizer(opt_idx)` out of the closure function ([#7563](https://github.com/PyTorchLightning/pytorch-lightning/pull/7563)

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 10 additions & 15 deletions
@@ -33,12 +33,11 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.types import _METRIC, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache
-from pytorch_lightning.utilities.xla_device import tpu_training_and_local_rank_zero
 
 log = logging.getLogger(__name__)
 warning_cache = WarningCache()
@@ -473,9 +472,8 @@ def save_function(self, value: Optional[Callable]) -> None:
         )
         self._save_function = value
 
-    @rank_zero_only
-    def _del_model(self, filepath: str) -> None:
-        if self._fs.exists(filepath):
+    def _del_model(self, trainer: 'pl.Trainer', filepath: str) -> None:
+        if trainer.should_rank_save_checkpoint and self._fs.exists(filepath):
             self._fs.rm(filepath)
             log.debug(f"Removed checkpoint: {filepath}")
 
@@ -493,7 +491,7 @@ def _do_save(self, trainer: 'pl.Trainer', filepath: str) -> None:
         trainer.dev_debugger.track_checkpointing_history(filepath)
 
         # make paths
-        if trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer):
+        if trainer.should_rank_save_checkpoint:
             self._fs.makedirs(os.path.dirname(filepath), exist_ok=True)
 
         # delegate the saving to the trainer
@@ -631,7 +629,7 @@ def __resolve_ckpt_dir(self, trainer: 'pl.Trainer') -> None:
 
         self.dirpath = ckpt_path
 
-        if (not trainer.fast_dev_run and (trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer))):
+        if not trainer.fast_dev_run and trainer.should_rank_save_checkpoint:
             self._fs.makedirs(self.dirpath, exist_ok=True)
 
     def _add_backward_monitor_support(self, trainer: 'pl.Trainer') -> None:
@@ -694,11 +692,8 @@ def _save_last_checkpoint(self, trainer: 'pl.Trainer', monitor_candidates: Dict[
 
         self._save_model(trainer, filepath)
 
-        if (
-            self.last_model_path and self.last_model_path != filepath
-            and (trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer))
-        ):
-            self._del_model(self.last_model_path)
+        if self.last_model_path and self.last_model_path != filepath and trainer.should_rank_save_checkpoint:
+            self._del_model(trainer, self.last_model_path)
 
         self.last_model_path = filepath
 
@@ -724,9 +719,9 @@ def _save_none_monitor_checkpoint(self, trainer: 'pl.Trainer', monitor_candidate
 
         if (
             self.save_top_k is None and self.best_model_path and self.best_model_path != filepath
-            and (trainer.is_global_zero or tpu_training_and_local_rank_zero(trainer))
+            and trainer.should_rank_save_checkpoint
         ):
-            self._del_model(self.best_model_path)
+            self._del_model(trainer, self.best_model_path)
 
         self.best_model_path = filepath
 
@@ -773,7 +768,7 @@ def _update_best_and_save(
         self._save_model(trainer, filepath)
 
         if del_filepath is not None and filepath != del_filepath:
-            self._del_model(del_filepath)
+            self._del_model(trainer, del_filepath)
 
     def to_yaml(self, filepath: Optional[Union[str, Path]] = None) -> None:
         """

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 4 additions & 0 deletions
@@ -308,3 +308,7 @@ def teardown(self) -> None:
         # TPU teardown
         os.environ.pop("PT_XLA_DEBUG", None)
         self.barrier("teardown")
+
+    @property
+    def should_rank_save_checkpoint(self) -> bool:
+        return self.local_rank == 0

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 5 additions & 0 deletions
@@ -309,3 +309,8 @@ def teardown(self) -> None:
     @classmethod
     def register_plugins(cls, plugin_registry):
         pass
+
+    @property
+    def should_rank_save_checkpoint(self) -> bool:
+        """Returns whether the checkpoint should be saved (rank based)"""
+        return self.is_global_zero
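
The base plugin keeps the previous default (only global rank zero saves), while subclasses such as the TPU spawn plugin above override it. As a sketch of the extension point, a custom plugin could do the same; note that `DDPPlugin` and its `local_rank` attribute are assumptions based on Lightning 1.3-era code, not part of this diff.

from pytorch_lightning.plugins import DDPPlugin  # assumed import path (PL ~1.3)


class NodeLocalCheckpointDDP(DDPPlugin):
    """Illustrative plugin: save checkpoints on local rank 0 of every node,
    e.g. when nodes do not share a filesystem, instead of only on global rank 0."""

    @property
    def should_rank_save_checkpoint(self) -> bool:
        # Same pattern as TPUSpawnPlugin above, applied to a DDP setting.
        return self.local_rank == 0

An instance would be passed to the Trainer through its plugins argument; nothing in ModelCheckpoint needs to change for the override to take effect.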

pytorch_lightning/trainer/properties.py

Lines changed: 4 additions & 0 deletions
@@ -97,6 +97,10 @@ def world_size(self) -> int:
         # some training types define a world size
         return getattr(self.accelerator.training_type_plugin, "world_size", 1)
 
+    @property
+    def should_rank_save_checkpoint(self) -> bool:
+        return self.accelerator.training_type_plugin.should_rank_save_checkpoint
+
     @property
     def _distrib_type(self) -> DistributedType:
         return self.accelerator_connector._distrib_type
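
With the property surfaced on the Trainer, user code can defer the same decision instead of hard-coding `trainer.is_global_zero`. A hedged example of a custom callback doing so; the class name, file layout, and the use of `callback_metrics` are illustrative, not part of this commit.

import os

from pytorch_lightning.callbacks import Callback


class SaveMetricsSnapshot(Callback):
    """Illustrative callback: write a small metrics file from the rank that the
    training type plugin designates for checkpoint I/O."""

    def __init__(self, dirpath: str = "snapshots") -> None:
        self.dirpath = dirpath

    def on_validation_end(self, trainer, pl_module) -> None:
        # Delegate the "which rank writes to disk" decision to the plugin,
        # so the callback behaves correctly under DDP and TPU spawn alike.
        if not trainer.should_rank_save_checkpoint:
            return
        os.makedirs(self.dirpath, exist_ok=True)
        path = os.path.join(self.dirpath, f"metrics_epoch={trainer.current_epoch}.txt")
        with open(path, "w") as fp:
            fp.write("\n".join(f"{k}: {float(v)}" for k, v in trainer.callback_metrics.items()))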

pytorch_lightning/utilities/xla_device.py

Lines changed: 0 additions & 7 deletions
@@ -17,8 +17,6 @@
 import traceback
 from multiprocessing import Process, Queue
 
-import pytorch_lightning as pl
-from pytorch_lightning.utilities.enums import DeviceType
 from pytorch_lightning.utilities.imports import _XLA_AVAILABLE
 
 if _XLA_AVAILABLE:
@@ -105,8 +103,3 @@ def tpu_device_exists() -> bool:
         if XLADeviceUtils._TPU_AVAILABLE:
             os.environ["PL_TPU_AVAILABLE"] = '1'
         return XLADeviceUtils._TPU_AVAILABLE
-
-
-def tpu_training_and_local_rank_zero(trainer: 'pl.Trainer') -> bool:
-    return trainer._device_type == DeviceType.TPU and \
-        trainer.training_type_plugin.local_rank == 0
