CHANGELOG.md (4 additions, 0 deletions)
@@ -569,6 +569,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Removed calls to `profile("model_forward")` in favor of profiling `training_step` ([#12032](https://github.com/PyTorchLightning/pytorch-lightning/pull/12032))
 
+
+- Removed `get_mp_spawn_kwargs` from `DDPSpawnStrategy` and `TPUSpawnStrategy` in favor of configuration in the `_SpawnLauncher` ([#11966](https://github.com/PyTorchLightning/pytorch-lightning/pull/11966))
+
+
 ### Fixed
 
 - Fixed an issue where `HorovodStrategy.teardown()` did not complete gracefully if an exception was thrown during callback setup [#11752](https://github.com/PyTorchLightning/pytorch-lightning/pull/11752)
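For readers skimming the changelog entry above: the net effect of #11966 is that the per-strategy `get_mp_spawn_kwargs` hook disappears and the launcher itself decides both the start method and the number of processes. The following is a minimal sketch of that shape, not the actual Lightning code; the simplified strategy object with a `num_processes` attribute is assumed purely for illustration.

```python
# Sketch only: mirrors the post-PR arrangement in which the launcher, not the
# strategy, owns the multiprocessing start method and derives nprocs from the strategy.
import torch.multiprocessing as mp


class SketchSpawnLauncher:
    def __init__(self, strategy) -> None:
        self._strategy = strategy
        # Previously supplied by Strategy.get_mp_spawn_kwargs(); now a launcher attribute.
        self._start_method = "spawn"

    def launch(self, function, *args):
        context = mp.get_context(self._start_method)
        return_queue = context.SimpleQueue()
        # The spawned function receives (process_idx, *args, return_queue) and is
        # expected to put its result on the queue from rank 0.
        mp.spawn(
            function,
            args=(*args, return_queue),
            nprocs=self._strategy.num_processes,
            start_method=self._start_method,
        )
        return return_queue.get()
```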
pytorch_lightning/strategies/ddp_spawn.py (0 additions, 3 deletions)
@@ -136,9 +136,6 @@ def set_world_ranks(self, process_idx: int = 0) -> None:
         self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         rank_zero_only.rank = self.cluster_environment.global_rank()
 
-    def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
-        return {"nprocs": self.num_processes}
-
     def _worker_setup(self, process_idx: int):
         reset_seed()
         self.set_world_ranks(process_idx)
pytorch_lightning/strategies/launchers/spawn.py (3 additions, 1 deletion)
@@ -47,6 +47,7 @@ class _SpawnLauncher(_Launcher):

     def __init__(self, strategy: Strategy) -> None:
         self._strategy = strategy
+        self._start_method = "spawn"
 
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Spawns processes that run the given function in parallel.
@@ -65,12 +66,13 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
         # This needs to be done in the main process here before spawning to ensure each rank will connect
         # through the same port
         os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port)
-        context = mp.get_context("spawn")
+        context = mp.get_context(self._start_method)
         return_queue = context.SimpleQueue()
         mp.spawn(
             self._wrapping_function,
             args=(trainer, function, args, kwargs, return_queue),
             nprocs=self._strategy.num_processes,
+            start_method=self._start_method,
         )
         spawn_output = return_queue.get()
         if trainer is None:
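One detail worth noting in the hunk above: the same `self._start_method` now feeds both `mp.get_context(...)`, which creates the `SimpleQueue`, and `mp.spawn(..., start_method=...)`, so the return queue and the worker processes always come from matching multiprocessing contexts. Below is a small standalone sketch of that pattern; no Lightning is involved, and the worker function and process count are made up for illustration.

```python
# Standalone sketch: create the return queue from the same multiprocessing context
# that will be used to start the workers, matching the _SpawnLauncher pattern.
import torch.multiprocessing as mp


def _worker(process_idx: int, return_queue) -> None:
    # Only rank 0 reports a result back to the main process.
    if process_idx == 0:
        return_queue.put(f"result from process {process_idx}")


if __name__ == "__main__":
    start_method = "spawn"
    context = mp.get_context(start_method)
    return_queue = context.SimpleQueue()
    mp.spawn(_worker, args=(return_queue,), nprocs=2, start_method=start_method)
    print(return_queue.get())
```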
pytorch_lightning/strategies/launchers/xla_spawn.py (14 additions, 3 deletions)
@@ -14,7 +14,7 @@
 import os
 import time
 from multiprocessing.queues import SimpleQueue
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, TYPE_CHECKING
 
 import torch.multiprocessing as mp
 
@@ -31,6 +31,9 @@
 else:
     xm, xmp, MpDeviceLoader, rendezvous = [None] * 4
 
+if TYPE_CHECKING:
+    from pytorch_lightning.strategies import Strategy
+
 
 class _XLASpawnLauncher(_SpawnLauncher):
     r"""Spawns processes that run a given function in parallel on XLA supported hardware, and joins them all at the end.
@@ -42,8 +45,15 @@ class _XLASpawnLauncher(_SpawnLauncher):
     Note:
         - This launcher requires all objects to be pickleable.
         - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.
+
+    Args:
+        strategy: A reference to the strategy that is used together with this launcher
     """
 
+    def __init__(self, strategy: "Strategy") -> None:
+        super().__init__(strategy)
+        self._start_method = "fork"
+
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Spawns processes that run the given function in parallel.
 
@@ -57,12 +67,13 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
                 a selected set of attributes get restored in the main process after processes join.
             **kwargs: Optional keyword arguments to be passed to the given function.
         """
-        context = mp.get_context(self._strategy.start_method or "fork")
+        context = mp.get_context(self._start_method)
         return_queue = context.SimpleQueue()
         xmp.spawn(
             self._wrapping_function,
             args=(trainer, function, args, kwargs, return_queue),
-            **self._strategy.get_mp_spawn_kwargs()
+            nprocs=len(self._strategy.parallel_devices),
+            start_method=self._start_method,
         )
         spawn_output = return_queue.get()
         if trainer is None:
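The XLA launcher now expresses its TPU-specific behaviour by overriding the base launcher rather than by asking the strategy for a kwargs dict: it forces `"fork"` as the start method in its constructor and computes `nprocs` from the strategy's `parallel_devices` at the `xmp.spawn` call site. A schematic of that override pattern follows; the class names are illustrative, and the real launcher calls `xmp.spawn` rather than exposing an `nprocs()` helper.

```python
# Sketch of the subclass-override pattern used here: the base launcher defaults to
# "spawn", and the XLA variant only swaps the start method in its constructor.
class BaseSpawnLauncherSketch:
    def __init__(self, strategy) -> None:
        self._strategy = strategy
        self._start_method = "spawn"


class XLASpawnLauncherSketch(BaseSpawnLauncherSketch):
    def __init__(self, strategy) -> None:
        super().__init__(strategy)
        # TPU processes are started with "fork" in this code path.
        self._start_method = "fork"

    def nprocs(self) -> int:
        # Process count is derived from the strategy's device list, not a kwargs dict.
        return len(self._strategy.parallel_devices)
```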
pytorch_lightning/strategies/tpu_spawn.py (0 additions, 6 deletions)
@@ -213,12 +213,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[

         return output
 
-    def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
-        return {
-            "nprocs": len(self.parallel_devices),
-            "start_method": self.start_method,
-        }
-
     def _worker_setup(self, process_idx: int):
         reset_seed()
         self.tpu_local_core_rank = xm.get_local_ordinal()
Expand Down