From 6ac25c2e8e3222fe1017f5330f27f76104b5c543 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 4 Dec 2021 00:12:10 +0100 Subject: [PATCH 1/8] don't use best model path to determine file location --- .../plugins/training_type/ddp_spawn.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 563f39a1f0cf4..23b4fc18c4f32 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -16,6 +16,7 @@ import re from collections import UserList from multiprocessing.queues import SimpleQueue +from pathlib import Path from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union import numpy as np @@ -255,10 +256,10 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op return # save the last weights - last_path = None - if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - self.checkpoint_io.save_checkpoint(state_dict, last_path) + weights_path = None + if trainer.state.fn == TrainerFn.FITTING: + weights_path = Path(checkpoint_callback.dirpath if checkpoint_callback is not None else ".") / ".temp.ckpt" + self.checkpoint_io.save_checkpoint(state_dict, weights_path) # adds the `callback_metrics` to the queue extra = _FakeQueue() @@ -268,7 +269,7 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op else: self.add_to_queue(trainer, extra) - return _SpawnOutput(best_model_path, last_path, results, extra) + return _SpawnOutput(best_model_path, weights_path, results, extra) def __recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None: # transfer back the best path to the trainer @@ -277,11 +278,12 @@ def __recover_results_in_main_process(self, spawn_output: "_SpawnOutput", traine # TODO: pass also best score # load last weights - if spawn_output.last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING: + if spawn_output.weights_path is not None: ckpt = self.checkpoint_io.load_checkpoint( - spawn_output.last_path, map_location=(lambda storage, loc: storage) + spawn_output.weights_path, map_location=(lambda storage, loc: storage) ) self.lightning_module.load_state_dict(ckpt) + self.checkpoint_io.remove_checkpoint(spawn_output.weights_path) # get the `callback_metrics` and set it to the trainer if is_overridden("get_from_queue", self.lightning_module): @@ -416,6 +418,6 @@ def empty(self) -> bool: class _SpawnOutput(NamedTuple): best_model_path: Optional[_PATH] - last_path: Optional[_PATH] + weights_path: Optional[_PATH] trainer_results: Any extra: _FakeQueue From e1adab3d78b19253976f1fb1530beeb9da1448f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 4 Dec 2021 00:34:41 +0100 Subject: [PATCH 2/8] update changelog --- CHANGELOG.md | 3 +++ pytorch_lightning/plugins/training_type/ddp_spawn.py | 1 - 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a168db7b486d..999cfcf691552 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -89,6 +89,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Changed `training_step`, `validation_step`, `test_step` and `predict_step` method signatures in `Accelerator` and updated input from caller side ([#10908](https://github.com/PyTorchLightning/pytorch-lightning/pull/10908)) +- Changed the name of the temporary checkpoint that the `DDPSpawnPlugin` and related plugins save ([#10934](https://github.com/PyTorchLightning/pytorch-lightning/pull/10934)) + + ### Deprecated - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103)) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 23b4fc18c4f32..db6502b191306 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -283,7 +283,6 @@ def __recover_results_in_main_process(self, spawn_output: "_SpawnOutput", traine spawn_output.weights_path, map_location=(lambda storage, loc: storage) ) self.lightning_module.load_state_dict(ckpt) - self.checkpoint_io.remove_checkpoint(spawn_output.weights_path) # get the `callback_metrics` and set it to the trainer if is_overridden("get_from_queue", self.lightning_module): From c3e49377cb0553efc9554c47d5afb5b589581c55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 4 Dec 2021 00:58:30 +0100 Subject: [PATCH 3/8] update tpu spawn --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 5 ++--- pytorch_lightning/plugins/training_type/tpu_spawn.py | 12 ++++++------ 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index db6502b191306..d71fe10c9f619 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -13,7 +13,6 @@ # limitations under the License. import logging import os -import re from collections import UserList from multiprocessing.queues import SimpleQueue from pathlib import Path @@ -273,8 +272,8 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op def __recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None: # transfer back the best path to the trainer - if self.lightning_module.trainer.checkpoint_callback: - self.lightning_module.trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path + if trainer.checkpoint_callback: + trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path # TODO: pass also best score # load last weights diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 6aafe60e40bce..5bd496ac23f7a 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -13,9 +13,9 @@ # limitations under the License. 
import io import os -import re import time from multiprocessing.queues import SimpleQueue +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -201,10 +201,10 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op state_dict = self.lightning_module.state_dict() # save the last weights - last_path = None - if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - self.checkpoint_io.save_checkpoint(state_dict, last_path) + weights_path = None + if trainer.state.fn == TrainerFn.FITTING: + weights_path = Path(checkpoint_callback.dirpath if checkpoint_callback is not None else ".") / ".temp.ckpt" + self.checkpoint_io.save_checkpoint(state_dict, weights_path) # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training if self.local_rank != 0: @@ -218,7 +218,7 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op else: self.add_to_queue(trainer, extra) - return _SpawnOutput(best_model_path, last_path, results, extra) + return _SpawnOutput(best_model_path, weights_path, results, extra) def broadcast(self, obj: object, src: int = 0) -> object: if not self.is_distributed: From a4890d6573c11a6120bd6db80f2b11537ed766e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 4 Dec 2021 01:16:20 +0100 Subject: [PATCH 4/8] use default root dir --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 3 +-- pytorch_lightning/plugins/training_type/tpu_spawn.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index d71fe10c9f619..a279056d47d8e 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -15,7 +15,6 @@ import os from collections import UserList from multiprocessing.queues import SimpleQueue -from pathlib import Path from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union import numpy as np @@ -257,7 +256,7 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op # save the last weights weights_path = None if trainer.state.fn == TrainerFn.FITTING: - weights_path = Path(checkpoint_callback.dirpath if checkpoint_callback is not None else ".") / ".temp.ckpt" + weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") self.checkpoint_io.save_checkpoint(state_dict, weights_path) # adds the `callback_metrics` to the queue diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 5bd496ac23f7a..56ebb71efb6f8 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -15,7 +15,6 @@ import os import time from multiprocessing.queues import SimpleQueue -from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -203,7 +202,7 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op # save the last weights weights_path = None if trainer.state.fn == TrainerFn.FITTING: - weights_path = Path(checkpoint_callback.dirpath if checkpoint_callback is not None else ".") / ".temp.ckpt" + weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") 
self.checkpoint_io.save_checkpoint(state_dict, weights_path) # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training From b8812d3e8e356b5b6f409cc714d0859ca4a59ed6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 4 Dec 2021 02:41:02 +0100 Subject: [PATCH 5/8] add test --- .../plugins/training_type/ddp_spawn.py | 12 ++++---- .../plugins/training_type/tpu_spawn.py | 4 +-- tests/plugins/test_ddp_spawn_plugin.py | 29 +++++++++++++++++++ 3 files changed, 37 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index a279056d47d8e..bed74873bc9a1 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -134,19 +134,19 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st def start_training(self, trainer: "pl.Trainer") -> Any: spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer) - self.__recover_results_in_main_process(spawn_output, trainer) + self._recover_results_in_main_process(spawn_output, trainer) # reset optimizers, since main process is never used for training and thus does not have a valid optim state trainer.optimizers = [] return spawn_output.trainer_results def start_evaluating(self, trainer: "pl.Trainer") -> Any: spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer) - self.__recover_results_in_main_process(spawn_output, trainer) + self._recover_results_in_main_process(spawn_output, trainer) return spawn_output.trainer_results def start_predicting(self, trainer: "pl.Trainer") -> Any: spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer) - self.__recover_results_in_main_process(spawn_output, trainer) + self._recover_results_in_main_process(spawn_output, trainer) return spawn_output.trainer_results def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: @@ -199,7 +199,7 @@ def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]: self.barrier() results = trainer.run_stage() - outputs = self.__collect_rank_zero_results(trainer, results) + outputs = self._collect_rank_zero_results(trainer, results) # ensure that spawned processes go through teardown before joining trainer._call_teardown_hook() @@ -242,7 +242,7 @@ def determine_ddp_device_ids(self): return None return [self.root_device.index] - def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: + def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: rank_zero_warn("cleaning up ddp environment...") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -269,7 +269,7 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op return _SpawnOutput(best_model_path, weights_path, results, extra) - def __recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None: + def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None: # transfer back the best path to the trainer if trainer.checkpoint_callback: trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 56ebb71efb6f8..61874f887b5c4 
100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -171,7 +171,7 @@ def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Op results = trainer.run_stage() - outputs = self.__collect_rank_zero_results(trainer, results) + outputs = self._collect_rank_zero_results(trainer, results) # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 self.barrier("end-process") @@ -191,7 +191,7 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: + def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: rank_zero_warn("cleaning up tpu spawn environment...") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index 7997c783ef6a3..44a97eaf4f747 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path +from unittest.mock import Mock import pytest import torch @@ -143,3 +145,30 @@ def test_ddp_spawn_configure_ddp(tmpdir): trainer.validate(model, dataloaders=model.val_dataloader()) trainer.test(model, dataloaders=model.test_dataloader()) trainer.predict(model, dataloaders=model.predict_dataloader()) + + +@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"]) +def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn): + """Tests that the spawn plugin transfers the new weights to the main process and deletes the temporary file.""" + model = Mock(wraps=BoringModel(), spec=BoringModel) + plugin = DDPSpawnPlugin() + plugin.model = model + trainer = Trainer(default_root_dir=tmpdir) + trainer.state.fn = trainer_fn # pretend we are in a particular trainer state + temp_file = Path(tmpdir, ".temp.ckpt") + + assert not temp_file.exists() + spawn_output = plugin._collect_rank_zero_results(trainer, {}) + + model.state_dict.assert_called_once() + if trainer_fn == TrainerFn.FITTING: + assert spawn_output.weights_path == str(temp_file) + assert temp_file.exists() + else: + assert spawn_output.weights_path is None + assert not temp_file.exists() + + # <-- here would normally be the multiprocessing boundary + plugin._recover_results_in_main_process(spawn_output, trainer) + assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None) + assert not temp_file.exists() From 00e6087d25afa725c15d8cf7673c9b0a863a787b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Sat, 4 Dec 2021 20:39:35 +0100 Subject: [PATCH 6/8] update typezzz --- pytorch_lightning/plugins/training_type/sharded_spawn.py | 6 +++--- pytorch_lightning/plugins/training_type/tpu_spawn.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 5e10155cc3ca5..136cdc40de10c 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -12,7 +12,7 @@ # See the License for 
the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Optional, Tuple +from typing import Dict, Generator, List, Optional, Tuple import torch from torch.nn import Module @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin +from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin, _SpawnOutput from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.enums import _StrategyType @@ -114,7 +114,7 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None: def post_training_step(self): pass - def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, _FakeQueue]]: + def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 61874f887b5c4..38d68c18d7912 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -15,7 +15,7 @@ import os import time from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.multiprocessing as mp @@ -153,7 +153,7 @@ def init_dist_connection(self, global_rank: int, world_size: int) -> None: def set_world_ranks(self, process_idx: int = 0) -> None: pass - def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]: + def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]: if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None: trainer.progress_bar_callback.disable() From b5e45cb91fa35b50c963e222fbfb8c31f1fbfc0c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Dec 2021 19:40:57 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/plugins/training_type/sharded_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 136cdc40de10c..04835125928f8 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin, _SpawnOutput +from pytorch_lightning.plugins.training_type.ddp_spawn import _SpawnOutput, DDPSpawnPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.enums import 
_StrategyType From bde350ba0e379e0189c6cde3f07adeb4207aab1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 6 Dec 2021 15:14:03 +0100 Subject: [PATCH 8/8] fix merge conflict --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index a721ae57d7c56..af08a6df19ff3 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -280,7 +280,7 @@ def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer spawn_output.weights_path, map_location=(lambda storage, loc: storage) ) self.lightning_module.load_state_dict(ckpt) - self.checkpoint_io.remove_checkpoint(spawn_output.last_path) + self.checkpoint_io.remove_checkpoint(spawn_output.weights_path) # get the `callback_metrics` and set it to the trainer if is_overridden("get_from_queue", self.lightning_module):
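
---

For readers skimming the series rather than applying it: the net effect is that the spawn plugins stop deriving the temporary "last weights" path from `best_model_path` (`*.tmp_end.ckpt` via `re.sub`) and instead write a fixed `.temp.ckpt` under `trainer.default_root_dir`, which the main process loads back and then removes. Below is a condensed, hedged sketch of that hand-off across the spawn boundary under stated assumptions: `SimpleCheckpointIO`, `collect_rank_zero_results`, and `recover_results_in_main_process` are illustrative stand-ins for the plugin's `checkpoint_io`, `_collect_rank_zero_results`, and `_recover_results_in_main_process`, not the real Lightning API.

```python
# Minimal sketch of the weight hand-off introduced by this series (not the
# actual DDPSpawnPlugin code): rank zero saves the final weights to a fixed
# ".temp.ckpt" in default_root_dir, the main process restores them and
# deletes the temporary file.
import os
from typing import Optional

import torch
import torch.nn as nn


class SimpleCheckpointIO:
    """Stand-in for Lightning's CheckpointIO plugin (assumption, not the real class)."""

    def save_checkpoint(self, state_dict, path) -> None:
        torch.save(state_dict, path)

    def load_checkpoint(self, path, map_location=None):
        return torch.load(path, map_location=map_location)

    def remove_checkpoint(self, path) -> None:
        if os.path.exists(path):
            os.remove(path)


def collect_rank_zero_results(
    model: nn.Module, default_root_dir: str, is_fitting: bool, checkpoint_io: SimpleCheckpointIO
) -> Optional[str]:
    """Runs in the spawned rank-zero process: persist the final weights."""
    if not is_fitting:
        return None
    # Fixed location, independent of any checkpoint callback or best_model_path.
    weights_path = os.path.join(default_root_dir, ".temp.ckpt")
    checkpoint_io.save_checkpoint(model.state_dict(), weights_path)
    return weights_path


def recover_results_in_main_process(
    model: nn.Module, weights_path: Optional[str], checkpoint_io: SimpleCheckpointIO
) -> None:
    """Runs in the main process: restore the weights, then clean up the temp file."""
    if weights_path is None:
        return
    ckpt = checkpoint_io.load_checkpoint(weights_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(ckpt)
    checkpoint_io.remove_checkpoint(weights_path)


if __name__ == "__main__":
    io = SimpleCheckpointIO()
    trained = nn.Linear(4, 2)  # pretend these weights were produced in the spawned process
    path = collect_rank_zero_results(trained, default_root_dir=".", is_fitting=True, checkpoint_io=io)

    # <-- the multiprocessing boundary would sit here
    fresh = nn.Linear(4, 2)  # the main process still holds stale weights
    recover_results_in_main_process(fresh, path, io)
    assert torch.equal(fresh.weight, trained.weight)
    assert not os.path.exists(".temp.ckpt")
```

This mirrors what the new `test_ddp_spawn_transfer_weights` test in patch 5 exercises: the temporary file exists only between collection and recovery, and only when the trainer state is `FITTING`.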