From 44ae59a54f50ef34b869502c5ba7ac8399d1975f Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Sun, 19 Sep 2021 23:11:25 -0700 Subject: [PATCH 01/13] git redo 9585 --- .../plugins/training_type/ddp.py | 22 +++++++++++++++++-- .../plugins/training_type/fully_sharded.py | 20 ++++++++--------- .../training_type/training_type_plugin.py | 5 ++++- .../connectors/checkpoint_connector.py | 12 +--------- pytorch_lightning/trainer/trainer.py | 15 ++++++------- tests/accelerators/test_common.py | 8 ++++--- ..._ddp_fully_sharded_with_full_state_dict.py | 17 +++++++------- .../connectors/test_checkpoint_connector.py | 1 + 8 files changed, 57 insertions(+), 43 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index df0f658bf712a..49d35800d0eab 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -21,7 +21,7 @@ import time from pathlib import Path from time import sleep -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Mapping, Optional, Union import __main__ import numpy as np @@ -51,13 +51,14 @@ from pytorch_lightning.utilities.distributed import ( distributed_available, init_ddp_connection, + rank_zero_info, rank_zero_only, ReduceOp, sync_ddp_if_available, ) from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT if _TORCH_GREATER_EQUAL_1_10: from torch.distributed.optim import DistributedOptimizer, PostLocalSGDOptimizer, ZeroRedundancyOptimizer @@ -127,6 +128,7 @@ def __init__( self._pids: Optional[List[int]] = None self._sync_dir: Optional[str] = None self._rank_0_has_called_call_children_scripts: bool = False + self._self_deleted_checkpoint_state_dict: bool = False self.set_world_ranks() @property @@ -535,3 +537,19 @@ def teardown(self) -> None: self.lightning_module.cpu() # clean up memory torch.cuda.empty_cache() + + def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + if "state_dict" not in checkpoint and self._self_deleted_checkpoint_state_dict: + return + self.lightning_module.load_state_dict(checkpoint["state_dict"]) + + def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: + rank_zero_info(f"DistributedDataParallel has {self.num_processes} processes. 
Serializing to avoid CPU OOMs.") + for current_worker in range(self.num_processes): + if self.local_rank == current_worker: + checkpoint = super().load_checkpoint(checkpoint_path) + del checkpoint["state_dict"] + self._self_deleted_checkpoint_state_dict = True + log.info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.") + self.barrier() + return checkpoint diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 72338e2923c07..bb267e050b953 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -141,16 +141,16 @@ def wrap_policy(*args, **kwargs): ): yield - def setup_environment(self) -> None: - super().setup_environment() - model_call_configure_sharded_model_hook = getattr( - self.lightning_module, "call_configure_sharded_model_hook", False - ) - if not model_call_configure_sharded_model_hook: - # if model has not called configure sharded model, we reset - # the training type plugin's call_configure_sharded_model_hook - # to give trainer a chance to configure. - self.call_configure_sharded_model_hook = True + # def setup_environment(self) -> None: + # super().setup_environment() + # model_call_configure_sharded_model_hook = getattr( + # self.lightning_module, "call_configure_sharded_model_hook", False + # ) + # if not model_call_configure_sharded_model_hook: + # # if model has not called configure sharded model, we reset + # # the training type plugin's call_configure_sharded_model_hook + # # to give trainer a chance to configure. + # self.call_configure_sharded_model_hook = True def configure_ddp(self) -> None: if not self.cpu_offload: diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 13d6f93f5fb97..937661f3aaca5 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -153,7 +153,10 @@ def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: torch.cuda.empty_cache() - return self.checkpoint_io.load_checkpoint(checkpoint_path) + checkpoint = self.checkpoint_io.load_checkpoint(checkpoint_path) + self.lightning_module.on_load_checkpoint(checkpoint) + self.load_model_state_dict(checkpoint) + return checkpoint def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: self.lightning_module.load_state_dict(checkpoint["state_dict"]) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index b750b0f81b26f..65851404c1c76 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -136,16 +136,10 @@ def restore_model(self) -> None: model = self.trainer.lightning_module - # hook: give user access to checkpoint if needed. - model.on_load_checkpoint(self._loaded_checkpoint) - # call hpc specific hook if self.hpc_resume_path is not None: model.on_hpc_load(self._loaded_checkpoint) - # restore model state_dict - self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint) - # reset metrics states on non-rank 0 as all states have been accumulated on rank 0 via syncing on checkpointing. 
if not self.trainer.is_global_zero: for module in self.trainer.lightning_module.modules(): @@ -154,12 +148,8 @@ def restore_model(self) -> None: def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None: """Restore only the model weights.""" - checkpoint = self._loaded_checkpoint if checkpoint_path is not None: - checkpoint = self._load_and_validate_checkpoint(checkpoint_path) - - self.trainer.lightning_module.on_load_checkpoint(checkpoint) - self.trainer.training_type_plugin.load_model_state_dict(checkpoint) + self._load_and_validate_checkpoint(checkpoint_path) def restore_training_state(self) -> None: """Restore the trainer state from the pre-loaded checkpoint. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7cabcb292622a..e1caaaf24dacd 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -976,9 +976,6 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, self.data_connector.prepare_data() self.callback_connector._attach_model_callbacks() - if self._ckpt_path and not self.accelerator.restore_checkpoint_after_pre_dispatch: - self._load_checkpoint_weights() - # ---------------------------- # SET UP TRAINING # ---------------------------- @@ -988,6 +985,8 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, # check if we should delay restoring checkpoint till later if not self.accelerator.restore_checkpoint_after_pre_dispatch: + if self._ckpt_path: + self._load_checkpoint_weights() self._restore_modules_and_callbacks() self._call_configure_sharded_model() # allow user to setup in model sharded environment @@ -1282,14 +1281,14 @@ def _call_configure_sharded_model(self) -> None: # we will not call the hook; the hook has initialized the sharded model for example. 
# used on the model if the user re-create a trainer with resume_from_checkpoint - model = self.lightning_module - model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False) - if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook: + # model = self.lightning_module + # model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False) + # if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook: with self.accelerator.model_sharded_context(): self.call_hook("configure_sharded_model") self.call_hook("on_configure_sharded_model") - model.call_configure_sharded_model_hook = True - self.accelerator.call_configure_sharded_model_hook = False + # model.call_configure_sharded_model_hook = True + # self.accelerator.call_configure_sharded_model_hook = False def _call_teardown_hook(self) -> None: fn = self.state.fn._setup_fn diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index 61f0a1e247215..ff4eefb01a598 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -106,7 +106,8 @@ def call_configure_sharded_model_hook(self) -> bool: ) trainer.fit(model) - assert not model.configure_sharded_model_called + # assert not model.configure_sharded_model_called + assert model.configure_sharded_model_called def test_accelerator_configure_sharded_model_called_once(tmpdir): @@ -115,9 +116,10 @@ def test_accelerator_configure_sharded_model_called_once(tmpdir): model = DummyModel() trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1) - assert trainer.accelerator.call_configure_sharded_model_hook is True + # assert trainer.accelerator.call_configure_sharded_model_hook is True trainer.fit(model) - assert trainer.accelerator.call_configure_sharded_model_hook is False + # assert trainer.accelerator.call_configure_sharded_model_hook is False + assert model.configure_sharded_model_called def test_configure_sharded_model_called_once(tmpdir): diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index c9c29d31c42ae..a2c630564c3e6 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -62,10 +62,11 @@ def setup(self, stage: str) -> None: self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) def configure_sharded_model(self) -> None: - for i, layer in enumerate(self.layer): - if i % 2 == 0: - self.layer[i] = wrap(layer) - self.layer = wrap(self.layer) + if not isinstance(self.layer, FullyShardedDataParallel): + for i, layer in enumerate(self.layer): + if i % 2 == 0: + self.layer[i] = wrap(layer) + self.layer = wrap(self.layer) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: # when loading full state dict, we first need to create a new unwrapped model @@ -131,13 +132,13 @@ def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel): def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): trainer.fit(model) - model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False) - trainer_accelerator_call_configure_sharded_model_hook = trainer.accelerator.call_configure_sharded_model_hook + # model_call_configure_sharded_model_hook = getattr(model, 
"call_configure_sharded_model_hook", False) + # trainer_accelerator_call_configure_sharded_model_hook = trainer.accelerator.call_configure_sharded_model_hook model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path - assert model_call_configure_sharded_model_hook - assert not trainer_accelerator_call_configure_sharded_model_hook + # assert model_call_configure_sharded_model_hook + # assert not trainer_accelerator_call_configure_sharded_model_hook trainer.save_checkpoint(model_path, weights_only=True) _assert_save_equality(trainer, model_path, cls=TestFSDPModel) diff --git a/tests/trainer/connectors/test_checkpoint_connector.py b/tests/trainer/connectors/test_checkpoint_connector.py index 83a45f02224d5..f9f9d73d10561 100644 --- a/tests/trainer/connectors/test_checkpoint_connector.py +++ b/tests/trainer/connectors/test_checkpoint_connector.py @@ -74,6 +74,7 @@ def test_preloaded_checkpoint_lifecycle(tmpdir): ckpt_path = trainer.checkpoint_callback.best_model_path trainer = Trainer(default_root_dir=tmpdir, max_steps=2, resume_from_checkpoint=ckpt_path) connector = trainer.checkpoint_connector + trainer.accelerator.connect(model) connector.resume_start() assert connector.resume_checkpoint_path == ckpt_path assert connector._loaded_checkpoint From 488e09d7745b7616e73909bf035d82f562790089 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 20 Sep 2021 06:20:35 +0000 Subject: [PATCH 02/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../plugins/training_type/fully_sharded.py | 16 ++++++++-------- pytorch_lightning/trainer/trainer.py | 10 +++++----- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index bb267e050b953..16f2dde9e0403 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -143,14 +143,14 @@ def wrap_policy(*args, **kwargs): # def setup_environment(self) -> None: # super().setup_environment() - # model_call_configure_sharded_model_hook = getattr( - # self.lightning_module, "call_configure_sharded_model_hook", False - # ) - # if not model_call_configure_sharded_model_hook: - # # if model has not called configure sharded model, we reset - # # the training type plugin's call_configure_sharded_model_hook - # # to give trainer a chance to configure. - # self.call_configure_sharded_model_hook = True + # model_call_configure_sharded_model_hook = getattr( + # self.lightning_module, "call_configure_sharded_model_hook", False + # ) + # if not model_call_configure_sharded_model_hook: + # # if model has not called configure sharded model, we reset + # # the training type plugin's call_configure_sharded_model_hook + # # to give trainer a chance to configure. 
+ # self.call_configure_sharded_model_hook = True def configure_ddp(self) -> None: if not self.cpu_offload: diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index e1caaaf24dacd..1be4caaa0a2fc 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1284,11 +1284,11 @@ def _call_configure_sharded_model(self) -> None: # model = self.lightning_module # model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False) # if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook: - with self.accelerator.model_sharded_context(): - self.call_hook("configure_sharded_model") - self.call_hook("on_configure_sharded_model") - # model.call_configure_sharded_model_hook = True - # self.accelerator.call_configure_sharded_model_hook = False + with self.accelerator.model_sharded_context(): + self.call_hook("configure_sharded_model") + self.call_hook("on_configure_sharded_model") + # model.call_configure_sharded_model_hook = True + # self.accelerator.call_configure_sharded_model_hook = False def _call_teardown_hook(self) -> None: fn = self.state.fn._setup_fn From 70a06d96db650b9bfca46578b0104f5363fd2d42 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Sun, 19 Sep 2021 23:22:56 -0700 Subject: [PATCH 03/13] testfsdp comment bool --- tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index a2c630564c3e6..b476934930db1 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -56,7 +56,7 @@ def setup(self, stage: str) -> None: return # resetting call_configure_sharded_model_hook attribute so that we could call # configure sharded model - self.call_configure_sharded_model_hook = False + # self.call_configure_sharded_model_hook = False # for loading full state dict, we first need to create a new unwrapped model # to load state dict and then wrapping self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) From 41409f4a9ba41738b92e7a1bc62fb2b18a6d9142 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Mon, 20 Sep 2021 09:50:54 -0700 Subject: [PATCH 04/13] clean up comments --- .../plugins/training_type/fully_sharded.py | 11 ---- pytorch_lightning/trainer/trainer.py | 9 --- tests/accelerators/test_common.py | 56 ------------------- ..._ddp_fully_sharded_with_full_state_dict.py | 8 --- 4 files changed, 84 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/fully_sharded.py b/pytorch_lightning/plugins/training_type/fully_sharded.py index 16f2dde9e0403..74f30b76e383f 100644 --- a/pytorch_lightning/plugins/training_type/fully_sharded.py +++ b/pytorch_lightning/plugins/training_type/fully_sharded.py @@ -141,17 +141,6 @@ def wrap_policy(*args, **kwargs): ): yield - # def setup_environment(self) -> None: - # super().setup_environment() - # model_call_configure_sharded_model_hook = getattr( - # self.lightning_module, "call_configure_sharded_model_hook", False - # ) - # if not model_call_configure_sharded_model_hook: - # # if model has not called configure sharded model, we reset - # # the training type plugin's call_configure_sharded_model_hook - # # to give trainer a chance to configure. 
- # self.call_configure_sharded_model_hook = True - def configure_ddp(self) -> None: if not self.cpu_offload: # When using CPU Offload, FSDP will manage the CUDA movement for us. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 1be4caaa0a2fc..65936cf59ccb3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1277,18 +1277,9 @@ def _call_setup_hook(self) -> None: self.accelerator.barrier("post_setup") def _call_configure_sharded_model(self) -> None: - # Call configure sharded model hook if accelerator requests. In some cases - # we will not call the hook; the hook has initialized the sharded model for example. - - # used on the model if the user re-create a trainer with resume_from_checkpoint - # model = self.lightning_module - # model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False) - # if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook: with self.accelerator.model_sharded_context(): self.call_hook("configure_sharded_model") self.call_hook("on_configure_sharded_model") - # model.call_configure_sharded_model_hook = True - # self.accelerator.call_configure_sharded_model_hook = False def _call_teardown_hook(self) -> None: fn = self.state.fn._setup_fn diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index ff4eefb01a598..b0cb37256a708 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -77,59 +77,3 @@ def configure_sharded_model(self): trainer.fit(model) assert model.configure_sharded_model_called - - -class DummyModel(BoringModel): - def __init__(self): - super().__init__() - self.configure_sharded_model_called = False - - def configure_sharded_model(self): - self.configure_sharded_model_called = True - - -def test_configure_sharded_model_false(tmpdir): - """Ensure ``configure_sharded_model`` is not called, when turned off.""" - - class CustomPlugin(SingleDevicePlugin): - @property - def call_configure_sharded_model_hook(self) -> bool: - return False - - model = DummyModel() - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=2, - limit_val_batches=2, - max_epochs=1, - plugins=CustomPlugin(device=torch.device("cpu")), - ) - trainer.fit(model) - - # assert not model.configure_sharded_model_called - assert model.configure_sharded_model_called - - -def test_accelerator_configure_sharded_model_called_once(tmpdir): - """Ensure that the configure sharded model hook is called, and set to False after to ensure not called - again.""" - - model = DummyModel() - trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1) - # assert trainer.accelerator.call_configure_sharded_model_hook is True - trainer.fit(model) - # assert trainer.accelerator.call_configure_sharded_model_hook is False - assert model.configure_sharded_model_called - - -def test_configure_sharded_model_called_once(tmpdir): - """Ensure ``configure_sharded_model`` is only called once.""" - - model = DummyModel() - trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=2, limit_val_batches=2, max_epochs=1) - trainer.fit(model) - - assert model.configure_sharded_model_called - model.configure_sharded_model_called = False - - assert not model.configure_sharded_model_called diff --git a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py index 
b476934930db1..992f14ce283e1 100644 --- a/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py +++ b/tests/plugins/test_ddp_fully_sharded_with_full_state_dict.py @@ -54,9 +54,6 @@ def setup(self, stage: str) -> None: # when running stages like test, validate, and predict, we will skip setting up, # will directly use the module itself unless we load from checkpoint return - # resetting call_configure_sharded_model_hook attribute so that we could call - # configure sharded model - # self.call_configure_sharded_model_hook = False # for loading full state dict, we first need to create a new unwrapped model # to load state dict and then wrapping self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) @@ -132,13 +129,8 @@ def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel): def _run_multiple_stages(trainer, model, model_path: Optional[str] = None): trainer.fit(model) - # model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False) - # trainer_accelerator_call_configure_sharded_model_hook = trainer.accelerator.call_configure_sharded_model_hook - model_path = model_path if model_path else trainer.checkpoint_callback.last_model_path - # assert model_call_configure_sharded_model_hook - # assert not trainer_accelerator_call_configure_sharded_model_hook trainer.save_checkpoint(model_path, weights_only=True) _assert_save_equality(trainer, model_path, cls=TestFSDPModel) From 195c2d035701a4b311338bd9d2eaa349ec67ea63 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Mon, 20 Sep 2021 10:49:54 -0700 Subject: [PATCH 05/13] clean import --- tests/accelerators/test_common.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/accelerators/test_common.py b/tests/accelerators/test_common.py index b0cb37256a708..93564e27defa9 100644 --- a/tests/accelerators/test_common.py +++ b/tests/accelerators/test_common.py @@ -16,7 +16,6 @@ import tests.helpers.utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.plugins import SingleDevicePlugin from tests.accelerators.test_dp import CustomClassificationModelDP from tests.helpers.boring_model import BoringModel from tests.helpers.datamodules import ClassifDataModule From afb662a8b12d4fcfc618683fc0bd450a7c958ea2 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Thu, 23 Sep 2021 18:21:17 -0700 Subject: [PATCH 06/13] updates --- pytorch_lightning/plugins/training_type/ddp.py | 2 ++ tests/plugins/test_deepspeed_plugin.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 49d35800d0eab..17c54ac1c788c 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -538,6 +538,8 @@ def teardown(self) -> None: # clean up memory torch.cuda.empty_cache() + self._self_deleted_checkpoint_state_dict = False + def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: if "state_dict" not in checkpoint and self._self_deleted_checkpoint_state_dict: return diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index c7ccaab3e72f4..a351237ec5b3a 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -603,7 +603,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir): run_checkpoint_test(tmpdir) -@RunIf(min_gpus=1, deepspeed=True, special=False) +@RunIf(min_gpus=1, deepspeed=True, special=True) 
def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): """Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning that the optimizer state and scheduler states cannot be restored.""" From 35dde07e3742869b6f77700f5946ebc1d3fae445 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Fri, 24 Sep 2021 14:37:12 -0700 Subject: [PATCH 07/13] flip test --- tests/plugins/test_deepspeed_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index a351237ec5b3a..c7ccaab3e72f4 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -603,7 +603,7 @@ def test_deepspeed_multigpu_stage_3_checkpointing(tmpdir): run_checkpoint_test(tmpdir) -@RunIf(min_gpus=1, deepspeed=True, special=True) +@RunIf(min_gpus=1, deepspeed=True, special=False) def test_deepspeed_multigpu_stage_3_warns_resume_training(tmpdir): """Test to ensure with Stage 3 and multiple GPUs that we can resume from training, throwing a warning that the optimizer state and scheduler states cannot be restored.""" From ba898678515b7dff87267852241fbac184646ff3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 27 Sep 2021 16:55:31 +0000 Subject: [PATCH 08/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/ddp.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 24ab7728d20e3..a16529c9e7678 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -50,7 +50,13 @@ ) from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.distributed import group as _group -from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_info, rank_zero_only, ReduceOp, sync_ddp_if_available +from pytorch_lightning.utilities.distributed import ( + init_ddp_connection, + rank_zero_info, + rank_zero_only, + ReduceOp, + sync_ddp_if_available, +) from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT From 3ce5f2e698a505796b8d5a132b4c7adeff149172 Mon Sep 17 00:00:00 2001 From: jjenniferdai <89552168+jjenniferdai@users.noreply.github.com> Date: Thu, 30 Sep 2021 11:51:02 -0700 Subject: [PATCH 09/13] info msg checkpoint loading MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/plugins/training_type/ddp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index a16529c9e7678..4e06f2df13174 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -546,7 +546,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: self.lightning_module.load_state_dict(checkpoint["state_dict"]) def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: - rank_zero_info(f"DistributedDataParallel has {self.num_processes} processes. 
Serializing to avoid CPU OOMs.") + rank_zero_info(f"DistributedDataParallel has {self.num_processes} processes. Serializing checkpoint loading to avoid CPU OOMs.") for current_worker in range(self.num_processes): if self.local_rank == current_worker: checkpoint = super().load_checkpoint(checkpoint_path) From 38099607803649611fa52e497cc49f35391873b3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 30 Sep 2021 18:52:10 +0000 Subject: [PATCH 10/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/ddp.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 4e06f2df13174..1c662507d2158 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -546,7 +546,9 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: self.lightning_module.load_state_dict(checkpoint["state_dict"]) def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: - rank_zero_info(f"DistributedDataParallel has {self.num_processes} processes. Serializing checkpoint loading to avoid CPU OOMs.") + rank_zero_info( + f"DistributedDataParallel has {self.num_processes} processes. Serializing checkpoint loading to avoid CPU OOMs." + ) for current_worker in range(self.num_processes): if self.local_rank == current_worker: checkpoint = super().load_checkpoint(checkpoint_path) From e234e48bc1d1b77d79795266d75efef1bdfc87dd Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Thu, 30 Sep 2021 12:22:59 -0700 Subject: [PATCH 11/13] pytorch_lightning/plugins/training_type/ddp.py --- pytorch_lightning/plugins/training_type/ddp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 1c662507d2158..2322b8ed4064c 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -547,7 +547,8 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: rank_zero_info( - f"DistributedDataParallel has {self.num_processes} processes. Serializing checkpoint loading to avoid CPU OOMs." + f"DistributedDataParallel has {self.num_processes} processes. " + "Serializing checkpoint loading to avoid CPU OOMs." 
) for current_worker in range(self.num_processes): if self.local_rank == current_worker: From 85c573d12b5728a9832faa96cf93849671c7efe7 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Wed, 13 Oct 2021 19:23:36 -0700 Subject: [PATCH 12/13] option2 to option1 --- pytorch_lightning/plugins/training_type/ddp.py | 10 ++++++---- .../plugins/training_type/training_type_plugin.py | 6 ++---- .../trainer/connectors/checkpoint_connector.py | 8 +++++++- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp.py b/pytorch_lightning/plugins/training_type/ddp.py index 2322b8ed4064c..ed10885f5dbd9 100644 --- a/pytorch_lightning/plugins/training_type/ddp.py +++ b/pytorch_lightning/plugins/training_type/ddp.py @@ -128,7 +128,7 @@ def __init__( self._pids: Optional[List[int]] = None self._sync_dir: Optional[str] = None self._rank_0_has_called_call_children_scripts: bool = False - self._self_deleted_checkpoint_state_dict: bool = False + self._has_loaded_state_dict: bool = False self.set_world_ranks() @property @@ -538,10 +538,10 @@ def teardown(self) -> None: # clean up memory torch.cuda.empty_cache() - self._self_deleted_checkpoint_state_dict = False + self._has_loaded_state_dict = False def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: - if "state_dict" not in checkpoint and self._self_deleted_checkpoint_state_dict: + if "state_dict" not in checkpoint and self._has_loaded_state_dict: return self.lightning_module.load_state_dict(checkpoint["state_dict"]) @@ -553,8 +553,10 @@ def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: for current_worker in range(self.num_processes): if self.local_rank == current_worker: checkpoint = super().load_checkpoint(checkpoint_path) + self.lightning_module.on_load_checkpoint(checkpoint) + self.load_model_state_dict(checkpoint) del checkpoint["state_dict"] - self._self_deleted_checkpoint_state_dict = True + self._has_loaded_state_dict = True log.info(f"Rank {self.global_rank}: done loading model states from {checkpoint_path}.") self.barrier() return checkpoint diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 247741fcd2a5b..c9d608ff52986 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -168,12 +168,10 @@ def results(self) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]: def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: torch.cuda.empty_cache() - checkpoint = self.checkpoint_io.load_checkpoint(checkpoint_path) - self.lightning_module.on_load_checkpoint(checkpoint) - self.load_model_state_dict(checkpoint) - return checkpoint + return self.checkpoint_io.load_checkpoint(checkpoint_path) def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + self.lightning_module.on_load_checkpoint(checkpoint) self.lightning_module.load_state_dict(checkpoint["state_dict"]) def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 4fd211ee3edd4..51b6f5a0bbf67 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -148,6 +148,9 @@ def restore_model(self) -> None: if self.hpc_resume_path is not None: 
model.on_hpc_load(self._loaded_checkpoint) + # restore model state_dict + self.trainer.training_type_plugin.load_model_state_dict(self._loaded_checkpoint) + # reset metrics states on non-rank 0 as all states have been accumulated on rank 0 via syncing on checkpointing. if not self.trainer.is_global_zero: for module in self.trainer.lightning_module.modules(): @@ -156,8 +159,11 @@ def restore_model(self) -> None: def restore_model_weights(self, checkpoint_path: Optional[_PATH]) -> None: """Restore only the model weights.""" + checkpoint = self._loaded_checkpoint if checkpoint_path is not None: - self._load_and_validate_checkpoint(checkpoint_path) + checkpoint = self._load_and_validate_checkpoint(checkpoint_path) + + self.trainer.training_type_plugin.load_model_state_dict(checkpoint) def restore_training_state(self) -> None: """Restore the trainer state from the pre-loaded checkpoint. From 701fb6aed2c75b987861bb21d0f2348b1ba53701 Mon Sep 17 00:00:00 2001 From: Jennifer Dai Date: Thu, 14 Oct 2021 11:51:03 -0700 Subject: [PATCH 13/13] update --- pytorch_lightning/plugins/training_type/deepspeed.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index e2e8c316f48d1..fcc180efb66ff 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -746,6 +746,8 @@ def lightning_restore_optimizer_and_schedulers(self) -> bool: return False def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + if "state_dict" not in checkpoint and self._has_loaded_state_dict: + return # override to do nothing, deepspeed engine already loaded the weights in `load_checkpoint()` if self.load_full_weights and self.zero_stage_3: self.model_to_device()
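
Note on the pattern this series converges on: DDPPlugin.load_checkpoint staggers the checkpoint read across local ranks, applies the weights immediately, then deletes "state_dict" from the returned dict and sets _has_loaded_state_dict so the later load_model_state_dict call can return early. A minimal standalone sketch of that idea follows; it assumes an already-initialized torch.distributed process group, and the helper name load_checkpoint_serialized and its arguments are illustrative only, not Lightning API.

import torch
import torch.distributed as dist


def load_checkpoint_serialized(model: torch.nn.Module, checkpoint_path: str,
                               local_rank: int, num_processes: int) -> dict:
    # Illustrative helper (not Lightning API): ranks on a node take turns
    # materializing the full checkpoint, so at most one copy per node sits in
    # CPU memory at any moment -- the "Serializing checkpoint loading to avoid
    # CPU OOMs" path mentioned in the rank_zero_info message.
    checkpoint = None
    for current_worker in range(num_processes):
        if local_rank == current_worker:
            checkpoint = torch.load(checkpoint_path, map_location="cpu")
            model.load_state_dict(checkpoint["state_dict"])
            # The weights are already applied; drop them so the dict kept
            # around for trainer/optimizer state stays small.
            del checkpoint["state_dict"]
        dist.barrier()  # all ranks wait before the next local rank loads
    return checkpoint

In the patches above the same effect is split across two calls: load_checkpoint applies and deletes the weights, and load_model_state_dict then skips restoring when "state_dict" is absent and _has_loaded_state_dict is set; the final patch adds the matching guard to the DeepSpeed plugin, whose engine restores weights itself during load_checkpoint.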