diff --git a/CHANGELOG.md b/CHANGELOG.md index 66bc14c129f53..d2a3fea03411e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed `training_step`, `validation_step`, `test_step` and `predict_step` method signatures in `Accelerator` and updated input from caller side ([#10908](https://github.com/PyTorchLightning/pytorch-lightning/pull/10908)) +- Changed the name of the temporary checkpoint that the `DDPSpawnPlugin` and related plugins save ([#10934](https://github.com/PyTorchLightning/pytorch-lightning/pull/10934)) + + ### Deprecated - Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103)) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 17f610929364a..6df8246e99a27 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -13,7 +13,6 @@ # limitations under the License. import logging import os -import re from collections import UserList from multiprocessing.queues import SimpleQueue from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union @@ -135,19 +134,19 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st def start_training(self, trainer: "pl.Trainer") -> Any: spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer) - self.__recover_results_in_main_process(spawn_output, trainer) + self._recover_results_in_main_process(spawn_output, trainer) # reset optimizers, since main process is never used for training and thus does not have a valid optim state trainer.optimizers = [] return spawn_output.trainer_results def start_evaluating(self, trainer: "pl.Trainer") -> Any: spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer) - self.__recover_results_in_main_process(spawn_output, trainer) + self._recover_results_in_main_process(spawn_output, trainer) return spawn_output.trainer_results def start_predicting(self, trainer: "pl.Trainer") -> Any: spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer) - self.__recover_results_in_main_process(spawn_output, trainer) + self._recover_results_in_main_process(spawn_output, trainer) return spawn_output.trainer_results def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]: @@ -200,7 +199,7 @@ def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]: self.barrier() results = trainer.run_stage() - outputs = self.__collect_rank_zero_results(trainer, results) + outputs = self._collect_rank_zero_results(trainer, results) # ensure that spawned processes go through teardown before joining trainer._call_teardown_hook() @@ -243,7 +242,7 @@ def determine_ddp_device_ids(self): return None return [self.root_device.index] - def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: + def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: rank_zero_warn("cleaning up ddp environment...") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -255,10 +254,10 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op return # save the last weights - last_path = None - if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - self.checkpoint_io.save_checkpoint(state_dict, last_path) + weights_path = None + if trainer.state.fn == TrainerFn.FITTING: + weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") + self.checkpoint_io.save_checkpoint(state_dict, weights_path) # adds the `callback_metrics` to the queue extra = _FakeQueue() @@ -267,21 +266,21 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op self.lightning_module.add_to_queue(extra) self.add_to_queue(trainer, extra) - return _SpawnOutput(best_model_path, last_path, trainer.state, results, extra) + return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra) - def __recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None: + def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None: # transfer back the best path to the trainer - if self.lightning_module.trainer.checkpoint_callback: - self.lightning_module.trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path + if trainer.checkpoint_callback: + trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path # TODO: pass also best score # load last weights - if spawn_output.last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING: + if spawn_output.weights_path is not None: ckpt = self.checkpoint_io.load_checkpoint( - spawn_output.last_path, map_location=(lambda storage, loc: storage) + spawn_output.weights_path, map_location=(lambda storage, loc: storage) ) self.lightning_module.load_state_dict(ckpt) - self.checkpoint_io.remove_checkpoint(spawn_output.last_path) + self.checkpoint_io.remove_checkpoint(spawn_output.weights_path) trainer.state = spawn_output.trainer_state @@ -417,7 +416,7 @@ def empty(self) -> bool: class _SpawnOutput(NamedTuple): best_model_path: Optional[_PATH] - last_path: Optional[_PATH] + weights_path: Optional[_PATH] trainer_state: TrainerState trainer_results: Any extra: _FakeQueue diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 5e10155cc3ca5..04835125928f8 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Any, Dict, Generator, List, Optional, Tuple +from typing import Dict, Generator, List, Optional, Tuple import torch from torch.nn import Module @@ -20,7 +20,7 @@ import pytorch_lightning as pl from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin -from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin +from pytorch_lightning.plugins.training_type.ddp_spawn import _SpawnOutput, DDPSpawnPlugin from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only from pytorch_lightning.utilities.enums import _StrategyType @@ -114,7 +114,7 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None: def post_training_step(self): pass - def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, _FakeQueue]]: + def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin): diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py index 8a9713dce892d..aacdf662e2996 100644 --- a/pytorch_lightning/plugins/training_type/tpu_spawn.py +++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py @@ -13,10 +13,9 @@ # limitations under the License. import io import os -import re import time from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.multiprocessing as mp @@ -155,7 +154,7 @@ def init_dist_connection(self, global_rank: int, world_size: int) -> None: def set_world_ranks(self, process_idx: int = 0) -> None: pass - def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]: + def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]: if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None: trainer.progress_bar_callback.disable() @@ -173,7 +172,7 @@ def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Op results = trainer.run_stage() - outputs = self.__collect_rank_zero_results(trainer, results) + outputs = self._collect_rank_zero_results(trainer, results) # https://github.com/pytorch/xla/issues/1801#issuecomment-602799542 self.barrier("end-process") @@ -193,7 +192,7 @@ def barrier(self, name: Optional[str] = None) -> None: if self.is_distributed: rendezvous(name) - def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: + def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]: rank_zero_warn("cleaning up tpu spawn environment...") checkpoint_callback = trainer.checkpoint_callback best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None @@ -202,10 +201,10 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op state_dict = self.lightning_module.state_dict() # save the last weights - last_path = None - if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0: - last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path) - self.checkpoint_io.save_checkpoint(state_dict, last_path) + weights_path = None + if trainer.state.fn == TrainerFn.FITTING: + weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt") + self.checkpoint_io.save_checkpoint(state_dict, weights_path) # We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training if self.local_rank != 0: @@ -219,7 +218,7 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op else: self.add_to_queue(trainer, extra) - return _SpawnOutput(best_model_path, last_path, trainer.state, results, extra) + return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra) def broadcast(self, obj: object, src: int = 0) -> object: if not self.is_distributed: diff --git a/tests/plugins/test_ddp_spawn_plugin.py b/tests/plugins/test_ddp_spawn_plugin.py index 7997c783ef6a3..44a97eaf4f747 100644 --- a/tests/plugins/test_ddp_spawn_plugin.py +++ b/tests/plugins/test_ddp_spawn_plugin.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path +from unittest.mock import Mock import pytest import torch @@ -143,3 +145,30 @@ def test_ddp_spawn_configure_ddp(tmpdir): trainer.validate(model, dataloaders=model.val_dataloader()) trainer.test(model, dataloaders=model.test_dataloader()) trainer.predict(model, dataloaders=model.predict_dataloader()) + + +@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"]) +def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn): + """Tests that the spawn plugin transfers the new weights to the main process and deletes the temporary file.""" + model = Mock(wraps=BoringModel(), spec=BoringModel) + plugin = DDPSpawnPlugin() + plugin.model = model + trainer = Trainer(default_root_dir=tmpdir) + trainer.state.fn = trainer_fn # pretend we are in a particular trainer state + temp_file = Path(tmpdir, ".temp.ckpt") + + assert not temp_file.exists() + spawn_output = plugin._collect_rank_zero_results(trainer, {}) + + model.state_dict.assert_called_once() + if trainer_fn == TrainerFn.FITTING: + assert spawn_output.weights_path == str(temp_file) + assert temp_file.exists() + else: + assert spawn_output.weights_path is None + assert not temp_file.exists() + + # <-- here would normally be the multiprocessing boundary + plugin._recover_results_in_main_process(spawn_output, trainer) + assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None) + assert not temp_file.exists()