diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7ef6553f8fe45..2e660a87063af 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -569,6 +569,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed calls to `profile("model_forward")` in favor of profiling `training_step` ([#12032](https://github.com/PyTorchLightning/pytorch-lightning/pull/12032))
 
+
+- Removed `get_mp_spawn_kwargs` from `DDPSpawnStrategy` and `TPUSpawnStrategy` in favor of configuration in the `_SpawnLauncher` ([#11966](https://github.com/PyTorchLightning/pytorch-lightning/pull/11966))
+
+
 ### Fixed
 
 - Fixed an issue where `HorovodStrategy.teardown()` did not complete gracefully if an exception was thrown during callback setup [#11752](https://github.com/PyTorchLightning/pytorch-lightning/pull/11752)
diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py
index 0eb4b68651aa8..d38d6faf6886a 100644
--- a/pytorch_lightning/strategies/ddp_spawn.py
+++ b/pytorch_lightning/strategies/ddp_spawn.py
@@ -136,9 +136,6 @@ def set_world_ranks(self, process_idx: int = 0) -> None:
         self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         rank_zero_only.rank = self.cluster_environment.global_rank()
 
-    def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
-        return {"nprocs": self.num_processes}
-
     def _worker_setup(self, process_idx: int):
         reset_seed()
         self.set_world_ranks(process_idx)
diff --git a/pytorch_lightning/strategies/launchers/spawn.py b/pytorch_lightning/strategies/launchers/spawn.py
index 594ef3146fcbb..3b393f4a0c1b1 100644
--- a/pytorch_lightning/strategies/launchers/spawn.py
+++ b/pytorch_lightning/strategies/launchers/spawn.py
@@ -47,6 +47,7 @@ class _SpawnLauncher(_Launcher):
 
     def __init__(self, strategy: Strategy) -> None:
         self._strategy = strategy
+        self._start_method = "spawn"
 
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Spawns processes that run the given function in parallel.
@@ -65,12 +66,13 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
         # This needs to be done in the main process here before spawning to ensure each rank will connect
         # through the same port
         os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port)
-        context = mp.get_context("spawn")
+        context = mp.get_context(self._start_method)
         return_queue = context.SimpleQueue()
         mp.spawn(
             self._wrapping_function,
             args=(trainer, function, args, kwargs, return_queue),
             nprocs=self._strategy.num_processes,
+            start_method=self._start_method,
         )
         spawn_output = return_queue.get()
         if trainer is None:
diff --git a/pytorch_lightning/strategies/launchers/xla_spawn.py b/pytorch_lightning/strategies/launchers/xla_spawn.py
index 50746a463cbed..71acfc1011582 100644
--- a/pytorch_lightning/strategies/launchers/xla_spawn.py
+++ b/pytorch_lightning/strategies/launchers/xla_spawn.py
@@ -14,7 +14,7 @@
 import os
 import time
 from multiprocessing.queues import SimpleQueue
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, TYPE_CHECKING
 
 import torch.multiprocessing as mp
 
@@ -31,6 +31,9 @@
 else:
     xm, xmp, MpDeviceLoader, rendezvous = [None] * 4
 
+if TYPE_CHECKING:
+    from pytorch_lightning.strategies import Strategy
+
 
 class _XLASpawnLauncher(_SpawnLauncher):
     r"""Spawns processes that run a given function in parallel on XLA supported hardware, and joins them all at the end.
@@ -42,8 +45,15 @@ class _XLASpawnLauncher(_SpawnLauncher):
     Note:
         - This launcher requires all objects to be pickleable.
        - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.
+
+    Args:
+        strategy: A reference to the strategy that is used together with this launcher
     """
 
+    def __init__(self, strategy: "Strategy") -> None:
+        super().__init__(strategy)
+        self._start_method = "fork"
+
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Spawns processes that run the given function in parallel.
 
@@ -57,12 +67,13 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
                 a selected set of attributes get restored in the main process after processes join.
             **kwargs: Optional keyword arguments to be passed to the given function.
         """
-        context = mp.get_context(self._strategy.start_method or "fork")
+        context = mp.get_context(self._start_method)
         return_queue = context.SimpleQueue()
         xmp.spawn(
             self._wrapping_function,
             args=(trainer, function, args, kwargs, return_queue),
-            **self._strategy.get_mp_spawn_kwargs()
+            nprocs=len(self._strategy.parallel_devices),
+            start_method=self._start_method,
         )
         spawn_output = return_queue.get()
         if trainer is None:
diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py
index 1606f3bb4478b..c8ae00c366ecc 100644
--- a/pytorch_lightning/strategies/tpu_spawn.py
+++ b/pytorch_lightning/strategies/tpu_spawn.py
@@ -213,12 +213,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[
 
         return output
 
-    def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
-        return {
-            "nprocs": len(self.parallel_devices),
-            "start_method": self.start_method,
-        }
-
     def _worker_setup(self, process_idx: int):
         reset_seed()
         self.tpu_local_core_rank = xm.get_local_ordinal()
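The net effect of the diff above is that the multiprocessing start method now lives on the launcher itself: `_SpawnLauncher` defaults to "spawn" and `_XLASpawnLauncher` overrides it to "fork", instead of each strategy assembling the arguments via the removed `get_mp_spawn_kwargs`. Below is a minimal, standalone sketch of that pattern using only the standard library rather than the real PyTorch Lightning classes; the `BaseLauncher`, `ForkLauncher`, and `_worker` names are hypothetical.

import multiprocessing as mp


class BaseLauncher:
    """Hypothetical stand-in for a launcher that owns its start method."""

    def __init__(self, nprocs: int) -> None:
        self._nprocs = nprocs
        self._start_method = "spawn"  # default, mirroring _SpawnLauncher

    def launch(self, function, *args) -> None:
        # The launcher, not the strategy, decides how worker processes are created.
        ctx = mp.get_context(self._start_method)
        processes = [ctx.Process(target=function, args=(rank, *args)) for rank in range(self._nprocs)]
        for proc in processes:
            proc.start()
        for proc in processes:
            proc.join()


class ForkLauncher(BaseLauncher):
    """Hypothetical subclass that only swaps the start method, as _XLASpawnLauncher now does."""

    def __init__(self, nprocs: int) -> None:
        super().__init__(nprocs)
        self._start_method = "fork"


def _worker(rank: int) -> None:
    # Must be a module-level (picklable) function when the "spawn" start method is used.
    print(f"worker {rank} started")


if __name__ == "__main__":
    BaseLauncher(nprocs=2).launch(_worker)

Keeping the start method on the launcher means the `mp.get_context(...)` call and the spawn call can no longer disagree about how workers are created, which is what the per-strategy `get_mp_spawn_kwargs` indirection previously had to keep in sync.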