diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py
index a26b63151f5a8..ed10885f5dbd9 100644
--- a/pytorch_lightning/plugins/training_type/ddp.py
+++ b/pytorch_lightning/plugins/training_type/ddp.py
@@ -21,7 +21,7 @@ import time
 from pathlib import Path
 from time import sleep
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Union
 
 import __main__
 import numpy as np
@@ -50,10 +50,16 @@
 )
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as _group
-from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import (
+    init_ddp_connection,
+    rank_zero_info,
+    rank_zero_only,
+    ReduceOp,
+    sync_ddp_if_available,
+)
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
 
 if _TORCH_GREATER_EQUAL_1_10:
     from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer
@@ -122,6 +128,7 @@ def __init__(
         self._pids: Optional[List[int]] = None
         self._sync_dir: Optional[str] = None
         self._rank_0_has_called_call_children_scripts: bool = False
+        self._has_loaded_state_dict: bool = False
         self.set_world_ranks()
 
     @property
@@ -530,3 +537,26 @@ def teardown(self) -> None:
             self.lightning_module.cpu()
             # clean up memory
             torch.cuda.empty_cache()
+
+        self._has_loaded_state_dict = False
+
+    def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+        if "state_dict" not in checkpoint and self._has_loaded_state_dict:
+            return
+        self.lightning_module.load_state_dict(checkpoint["state_dict"])
+
+    def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
+        rank_zero_info(
+            f"DistributedDataParallel has {self.num_processes} processes. "
+            "Serializing checkpoint loading to avoid CPU OOMs."
+        )
+        for current_worker in range(self.num_processes):
+            if self.local_rank == current_worker:
+                checkpoint = super().load_checkpoint(checkpoint_path)
+                self.lightning_module.on_load_checkpoint(checkpoint)
+                self.load_model_state_dict(checkpoint)
+                del checkpoint["state_dict"]
+                self._has_loaded_state_dict = True
+                log.info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.")
+            self.barrier()
+        return checkpoint
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index e2e8c316f48d1..fcc180efb66ff 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -746,6 +746,8 @@ def lightning_restore_optimizer_and_schedulers(self) -> bool:
         return False
 
     def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+        if "state_dict" not in checkpoint and self._has_loaded_state_dict:
+            return
         # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()`
         if self.load_full_weights and self.zero_stage_3:
             self.model_to_device()
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index cf36a3502702d..c9d608ff52986 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -171,6 +171,7 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]:
         return self.checkpoint_io.load_checkpoint(checkpoint_path)
 
     def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
+        self.lightning_module.on_load_checkpoint(checkpoint)
         self.lightning_module.load_state_dict(checkpoint["state_dict"])
 
     def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index da6a81e8add44..207e078c8176b 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -144,9 +144,6 @@ def restore_model(self) -> None:
 
         model = self.trainer.lightning_module
 
-        # hook: give user access to checkpoint if needed.
-        model.on_load_checkpoint(self._loaded_checkpoint)
-
         # call hpc specific hook
         if self.hpc_resume_path is not None:
             model.on_hpc_load(self._loaded_checkpoint)
@@ -166,7 +163,6 @@ def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None:
         if checkpoint_path is not None:
             checkpoint = self._load_and_validate_checkpoint(checkpoint_path)
 
-        self.trainer.lightning_module.on_load_checkpoint(checkpoint)
         self.trainer.training_type_plugin.load_model_state_dict(checkpoint)
 
     def restore_training_state(self) -> None:
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 610f512324b82..f5cfc924e770a 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -1026,9 +1026,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
         self.data_connector.prepare_data()
         self.callback_connector._attach_model_callbacks()
 
-        if self._ckpt_path and not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
-            self._load_checkpoint_weights()
-
         # ----------------------------
         # SET UP TRAINING
         # ----------------------------
@@ -1038,6 +1035,8 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
 
         # check if we should delay restoring checkpoint till later
         if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
+            if self._ckpt_path:
+                self._load_checkpoint_weights()
             self.checkpoint_connector.resume_start()
             self._restore_modules_and_callbacks()
 
diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py
index ff938598a4ada..4f898eece996d 100644
--- a/tests/trainer/connectors/test_checkpoint_connector.py
+++ b/tests/trainer/connectors/test_checkpoint_connector.py
@@ -78,6 +78,7 @@ def test_preloaded_checkpoint_lifecycle(tmpdir):
     ckpt_path = trainer.checkpoint_callback.best_model_path
     trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint=ckpt_path)
     connector = trainer.checkpoint_connector
+    trainer.accelerator.connect(model)
    connector.resume_start()
     assert connector.resume_checkpoint_path == ckpt_path
     assert connector._loaded_checkpoint
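The DDPPlugin.load_checkpoint override above is the core of the change: ranks take turns deserializing the checkpoint so that only one full copy of the state dict sits in host memory per node at a time, and the "state_dict" entry is dropped as soon as it has been consumed. A minimal standalone sketch of that pattern follows; it is not the Lightning code itself, the function name is made up for illustration, and it assumes a torch.distributed process group is already initialized and the usual torchrun environment variables (LOCAL_RANK, LOCAL_WORLD_SIZE) are set.

import os

import torch
import torch.distributed as dist


def load_checkpoint_serialized(model: torch.nn.Module, checkpoint_path: str) -> dict:
    # Hypothetical helper illustrating rank-serialized checkpoint loading.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", 1))

    checkpoint = {}
    for turn in range(local_world_size):
        if local_rank == turn:
            # map_location="cpu" keeps the tensors off the GPU during deserialization.
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            model.load_state_dict(checkpoint["state_dict"])
            # Free the largest entry immediately, mirroring `del checkpoint["state_dict"]` in the diff.
            del checkpoint["state_dict"]
        # Every rank waits here, so at most one rank per node holds a full state dict.
        dist.barrier()
    return checkpoint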