3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -93,6 +93,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Changed `training_step`, `validation_step`, `test_step` and `predict_step` method signatures in `Accelerator` and updated input from caller side ([#10908](https://github.com/PyTorchLightning/pytorch-lightning/pull/10908))


- Changed the name of the temporary checkpoint that the `DDPSpawnPlugin` and related plugins save ([#10934](https://github.com/PyTorchLightning/pytorch-lightning/pull/10934))


### Deprecated

- Deprecated `ClusterEnvironment.master_{address,port}` in favor of `ClusterEnvironment.main_{address,port}` ([#10103](https://github.com/PyTorchLightning/pytorch-lightning/issues/10103))
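For orientation, a minimal sketch of what the rename amounts to (the two path expressions mirror the `ddp_spawn.py` diff below; the concrete paths are made up for illustration):

```python
import os
import re

# Before this PR: the temporary file name was derived from the best checkpoint path,
# and nothing was saved when no best checkpoint path existed.
best_model_path = "/logs/run1/epoch=3-step=100.ckpt"          # illustrative value
old_temp = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)  # "/logs/run1/epoch=3-step=100.tmp_end.ckpt"

# After this PR: the temporary file has a fixed name under the trainer's root dir,
# so it is written on every fit regardless of the checkpoint callback.
default_root_dir = "/logs/run1"                               # illustrative value
new_temp = os.path.join(default_root_dir, ".temp.ckpt")       # "/logs/run1/.temp.ckpt"
```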
35 changes: 17 additions & 18 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -13,7 +13,6 @@
# limitations under the License.
import logging
import os
import re
from collections import UserList
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union
@@ -135,19 +134,19 @@ def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[st

def start_training(self, trainer: "pl.Trainer") -> Any:
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
self.__recover_results_in_main_process(spawn_output, trainer)
self._recover_results_in_main_process(spawn_output, trainer)
# reset optimizers, since main process is never used for training and thus does not have a valid optim state
trainer.optimizers = []
return spawn_output.trainer_results

def start_evaluating(self, trainer: "pl.Trainer") -> Any:
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
self.__recover_results_in_main_process(spawn_output, trainer)
self._recover_results_in_main_process(spawn_output, trainer)
return spawn_output.trainer_results

def start_predicting(self, trainer: "pl.Trainer") -> Any:
spawn_output: _SpawnOutput = self.spawn(self.new_process, trainer)
self.__recover_results_in_main_process(spawn_output, trainer)
self._recover_results_in_main_process(spawn_output, trainer)
return spawn_output.trainer_results

def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> Optional[Union[Any, "_SpawnOutput"]]:
@@ -200,7 +199,7 @@ def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
self.barrier()

results = trainer.run_stage()
outputs = self.__collect_rank_zero_results(trainer, results)
outputs = self._collect_rank_zero_results(trainer, results)

# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook()
@@ -243,7 +242,7 @@ def determine_ddp_device_ids(self):
return None
return [self.root_device.index]

def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
rank_zero_warn("cleaning up ddp environment...")
checkpoint_callback = trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
@@ -255,10 +254,10 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op
return

# save the last weights
last_path = None
if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
self.checkpoint_io.save_checkpoint(state_dict, last_path)
weights_path = None
if trainer.state.fn == TrainerFn.FITTING:
weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
self.checkpoint_io.save_checkpoint(state_dict, weights_path)

# adds the `callback_metrics` to the queue
extra = _FakeQueue()
@@ -267,21 +266,21 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op
self.lightning_module.add_to_queue(extra)
self.add_to_queue(trainer, extra)

return _SpawnOutput(best_model_path, last_path, trainer.state, results, extra)
return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)

def __recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None:
def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer) -> None:
# transfer back the best path to the trainer
if self.lightning_module.trainer.checkpoint_callback:
self.lightning_module.trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path
if trainer.checkpoint_callback:
trainer.checkpoint_callback.best_model_path = spawn_output.best_model_path

# TODO: pass also best score
# load last weights
if spawn_output.last_path is not None and self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
if spawn_output.weights_path is not None:
ckpt = self.checkpoint_io.load_checkpoint(
spawn_output.last_path, map_location=(lambda storage, loc: storage)
spawn_output.weights_path, map_location=(lambda storage, loc: storage)
)
self.lightning_module.load_state_dict(ckpt)
self.checkpoint_io.remove_checkpoint(spawn_output.last_path)
self.checkpoint_io.remove_checkpoint(spawn_output.weights_path)

trainer.state = spawn_output.trainer_state

@@ -417,7 +416,7 @@ def empty(self) -> bool:

class _SpawnOutput(NamedTuple):
best_model_path: Optional[_PATH]
last_path: Optional[_PATH]
weights_path: Optional[_PATH]
trainer_state: TrainerState
trainer_results: Any
extra: _FakeQueue
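To make the hand-off around the spawn boundary easier to follow, here is a self-contained sketch of the pattern the plugin now uses. A plain `nn.Linear` stands in for the LightningModule and `torch.save`/`torch.load` stand in for `checkpoint_io`; the field names mirror `_SpawnOutput`, everything else is illustrative:

```python
import os
import tempfile
from typing import Any, NamedTuple, Optional

import torch
from torch import nn


class SpawnOutputSketch(NamedTuple):
    best_model_path: Optional[str]
    weights_path: Optional[str]
    trainer_results: Any


def collect_rank_zero_results(model: nn.Module, root_dir: str, fitting: bool) -> SpawnOutputSketch:
    """Spawned rank-0 process: dump the weights to a fixed temporary file."""
    weights_path = None
    if fitting:
        weights_path = os.path.join(root_dir, ".temp.ckpt")
        torch.save(model.state_dict(), weights_path)
    return SpawnOutputSketch(best_model_path=None, weights_path=weights_path, trainer_results="done")


def recover_results_in_main_process(model: nn.Module, output: SpawnOutputSketch) -> None:
    """Main process after join: load the weights back and delete the temporary file."""
    if output.weights_path is not None:
        model.load_state_dict(torch.load(output.weights_path, map_location="cpu"))
        os.remove(output.weights_path)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        worker_model, main_model = nn.Linear(2, 2), nn.Linear(2, 2)
        out = collect_rank_zero_results(worker_model, tmp, fitting=True)
        # <-- the real multiprocessing boundary would sit here
        recover_results_in_main_process(main_model, out)
        assert not os.path.exists(out.weights_path)
```

Using a fixed name under `default_root_dir` means the main process can always locate and clean up the temporary file, instead of depending on the checkpoint callback's `best_model_path`.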
6 changes: 3 additions & 3 deletions pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -12,15 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Dict, Generator, List, Optional, Tuple
from typing import Dict, Generator, List, Optional, Tuple

import torch
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.sharded_native_amp import ShardedNativeMixedPrecisionPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import _FakeQueue, DDPSpawnPlugin
from pytorch_lightning.plugins.training_type.ddp_spawn import _SpawnOutput, DDPSpawnPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE, rank_zero_only
from pytorch_lightning.utilities.enums import _StrategyType
@@ -114,7 +114,7 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None:
def post_training_step(self):
pass

def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, _FakeQueue]]:
def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
# Ensure that the scaler points to the correct process group
# which is re-initialized in a new process
if isinstance(self.precision_plugin, ShardedNativeMixedPrecisionPlugin):
19 changes: 9 additions & 10 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -13,10 +13,9 @@
# limitations under the License.
import io
import os
import re
import time
from multiprocessing.queues import SimpleQueue
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
import torch.multiprocessing as mp
@@ -155,7 +154,7 @@ def init_dist_connection(self, global_rank: int, world_size: int) -> None:
def set_world_ranks(self, process_idx: int = 0) -> None:
pass

def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Optional[str], Any, "_FakeQueue"]]:
def new_process(self, trainer: "pl.Trainer") -> Optional["_SpawnOutput"]:
if self.tpu_global_core_rank != 0 and trainer.progress_bar_callback is not None:
trainer.progress_bar_callback.disable()

@@ -173,7 +172,7 @@ def new_process(self, trainer: "pl.Trainer") -> Optional[Tuple[Optional[str], Op

results = trainer.run_stage()

outputs = self.__collect_rank_zero_results(trainer, results)
outputs = self._collect_rank_zero_results(trainer, results)

# https://github.com/pytorch/xla/issues/1801#issuecomment-602799542
self.barrier("end-process")
@@ -193,7 +192,7 @@ def barrier(self, name: Optional[str] = None) -> None:
if self.is_distributed:
rendezvous(name)

def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
rank_zero_warn("cleaning up tpu spawn environment...")
checkpoint_callback = trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
@@ -202,10 +201,10 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op
state_dict = self.lightning_module.state_dict()

# save the last weights
last_path = None
if trainer.state.fn == TrainerFn.FITTING and best_model_path is not None and len(best_model_path) > 0:
last_path = re.sub(".ckpt", ".tmp_end.ckpt", best_model_path)
self.checkpoint_io.save_checkpoint(state_dict, last_path)
weights_path = None
if trainer.state.fn == TrainerFn.FITTING:
weights_path = os.path.join(trainer.default_root_dir, ".temp.ckpt")
self.checkpoint_io.save_checkpoint(state_dict, weights_path)

# We use `local_rank` here as separate filesystems are used for each VM for TPU Pod Training
if self.local_rank != 0:
@@ -219,7 +218,7 @@ def __collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Op
else:
self.add_to_queue(trainer, extra)

return _SpawnOutput(best_model_path, last_path, trainer.state, results, extra)
return _SpawnOutput(best_model_path, weights_path, trainer.state, results, extra)

def broadcast(self, obj: object, src: int = 0) -> object:
if not self.is_distributed:
29 changes: 29 additions & 0 deletions tests/plugins/test_ddp_spawn_plugin.py
@@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from unittest.mock import Mock

import pytest
import torch
@@ -143,3 +145,30 @@ def test_ddp_spawn_configure_ddp(tmpdir):
trainer.validate(model, dataloaders=model.val_dataloader())
trainer.test(model, dataloaders=model.test_dataloader())
trainer.predict(model, dataloaders=model.predict_dataloader())


@pytest.mark.parametrize("trainer_fn", [TrainerFn.FITTING, "other"])
def test_ddp_spawn_transfer_weights(tmpdir, trainer_fn):
"""Tests that the spawn plugin transfers the new weights to the main process and deletes the temporary file."""
model = Mock(wraps=BoringModel(), spec=BoringModel)
plugin = DDPSpawnPlugin()
plugin.model = model
trainer = Trainer(default_root_dir=tmpdir)
trainer.state.fn = trainer_fn # pretend we are in a particular trainer state
temp_file = Path(tmpdir, ".temp.ckpt")

assert not temp_file.exists()
spawn_output = plugin._collect_rank_zero_results(trainer, {})

model.state_dict.assert_called_once()
if trainer_fn == TrainerFn.FITTING:
assert spawn_output.weights_path == str(temp_file)
assert temp_file.exists()
else:
assert spawn_output.weights_path is None
assert not temp_file.exists()

# <-- here would normally be the multiprocessing boundary
plugin._recover_results_in_main_process(spawn_output, trainer)
assert model.load_state_dict.call_count == int(spawn_output.weights_path is not None)
assert not temp_file.exists()
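
To run just the new test locally, something along these lines should work (a sketch; the `-k` expression simply selects the test added above):

```python
import pytest

# Select only the new weight-transfer test from the module touched in this PR.
pytest.main(["tests/plugins/test_ddp_spawn_plugin.py", "-k", "test_ddp_spawn_transfer_weights", "-v"])
```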