From 8c74c9b436ed88df1b80b869dcf94d9dfe43bd52 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 17 Feb 2022 22:29:59 +0100
Subject: [PATCH 1/5] refactor get_mp_spawn_kwargs

---
 pytorch_lightning/strategies/ddp_spawn.py    |  3 ---
 .../strategies/launchers/xla_spawn.py        | 15 +++++++++++++--
 pytorch_lightning/strategies/tpu_spawn.py    |  6 ------
 3 files changed, 13 insertions(+), 11 deletions(-)

diff --git a/pytorch_lightning/strategies/ddp_spawn.py b/pytorch_lightning/strategies/ddp_spawn.py
index 9b58137d2719d..8467bbb7d8375 100644
--- a/pytorch_lightning/strategies/ddp_spawn.py
+++ b/pytorch_lightning/strategies/ddp_spawn.py
@@ -139,9 +139,6 @@ def set_world_ranks(self, process_idx: int = 0) -> None:
         self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         rank_zero_only.rank = self.cluster_environment.global_rank()
 
-    def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
-        return {"nprocs": self.num_processes}
-
     def _worker_setup(self, process_idx: int):
         reset_seed()
         self.set_world_ranks(process_idx)
diff --git a/pytorch_lightning/strategies/launchers/xla_spawn.py b/pytorch_lightning/strategies/launchers/xla_spawn.py
index 8bac7888c568b..4886924e41754 100644
--- a/pytorch_lightning/strategies/launchers/xla_spawn.py
+++ b/pytorch_lightning/strategies/launchers/xla_spawn.py
@@ -19,6 +19,7 @@
 import torch.multiprocessing as mp
 
 import pytorch_lightning as pl
+from pytorch_lightning.strategies import Strategy
 from pytorch_lightning.strategies.launchers.spawn import _FakeQueue, _SpawnLauncher, _SpawnOutput
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities import _TPU_AVAILABLE
@@ -42,8 +43,17 @@ class _XLASpawnLauncher(_SpawnLauncher):
     Note:
         - This launcher requires all objects to be pickleable.
        - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.
+
+    Args:
+        strategy: A reference to the strategy that is used together with this launcher
+        start_method: Start method for `~torch_xla.distributed.xla_multiprocessing.spawn`. Accepted options are
+            ``'spawn'`` or ``'fork'``.
     """
 
+    def __init__(self, strategy: Strategy, start_method: str = "fork") -> None:
+        super().__init__(strategy)
+        self._start_method = start_method
+
     def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
         """Spawns processes that run the given function in parallel.
 
@@ -60,12 +70,13 @@ def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
                 into the function.
""" trainer = kwargs.pop("trainer", None) - context = mp.get_context(self._strategy.start_method or "fork") + context = mp.get_context(self._start_method) return_queue = context.SimpleQueue() xmp.spawn( self._wrapping_function, args=(trainer, function, args, kwargs, return_queue), - **self._strategy.get_mp_spawn_kwargs() + nprocs=len(self._strategy.parallel_devices), + start_method=self._start_method, ) spawn_output = return_queue.get() if trainer is None: diff --git a/pytorch_lightning/strategies/tpu_spawn.py b/pytorch_lightning/strategies/tpu_spawn.py index f3d855b43f8a6..8612b25b36d96 100644 --- a/pytorch_lightning/strategies/tpu_spawn.py +++ b/pytorch_lightning/strategies/tpu_spawn.py @@ -210,12 +210,6 @@ def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ return output - def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]: - return { - "nprocs": len(self.parallel_devices), - "start_method": self.start_method, - } - def _worker_setup(self, process_idx: int): reset_seed() self.tpu_local_core_rank = xm.get_local_ordinal() From 636900fff658332580e887f228dd2de4856e6db1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 17 Feb 2022 22:45:28 +0100 Subject: [PATCH 2/5] resolve circular import --- pytorch_lightning/strategies/launchers/xla_spawn.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/strategies/launchers/xla_spawn.py b/pytorch_lightning/strategies/launchers/xla_spawn.py index 4886924e41754..1f843795dc460 100644 --- a/pytorch_lightning/strategies/launchers/xla_spawn.py +++ b/pytorch_lightning/strategies/launchers/xla_spawn.py @@ -14,12 +14,11 @@ import os import time from multiprocessing.queues import SimpleQueue -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, TYPE_CHECKING import torch.multiprocessing as mp import pytorch_lightning as pl -from pytorch_lightning.strategies import Strategy from pytorch_lightning.strategies.launchers.spawn import _FakeQueue, _SpawnLauncher, _SpawnOutput from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _TPU_AVAILABLE @@ -32,6 +31,9 @@ else: xm, xmp, MpDeviceLoader, rendezvous = [None] * 4 +if TYPE_CHECKING: + from pytorch_lightning.strategies import Strategy + class _XLASpawnLauncher(_SpawnLauncher): r"""Spawns processes that run a given function in parallel on XLA supported hardware, and joins them all at the end. @@ -50,7 +52,7 @@ class _XLASpawnLauncher(_SpawnLauncher): ``'spawn'`` or ``'fork'``. """ - def __init__(self, strategy: Strategy, start_method: str = "fork") -> None: + def __init__(self, strategy: "Strategy", start_method: str = "fork") -> None: super().__init__(strategy) self._start_method = start_method From f74eeca8cf9b47cd022dbd2e35a52c75324f0fdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 17 Feb 2022 22:51:14 +0100 Subject: [PATCH 3/5] add changelog --- CHANGELOG.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ba17b476d3873..59f0f512e47a3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -541,7 +541,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
 - Removed `FitLoop.current_epoch` getter and setter ([#11562](https://github.com/PyTorchLightning/pytorch-lightning/pull/11562))
 
-- Removed access to `_short_id` in `NeptuneLogger` ([#11517](https://github.com/PyTorchLightning/pytorch-lightning/pull/11517))
+- Removed access to `_short_id` in `NeptuneLogger` ([#11517](https://github.com/PyTorchLightning/pytorch-lightning/pull/11517))
+
+
+- Removed `get_mp_spawn_kwargs` from `DDPSpawnStrategy` and `TPUSpawnStrategy` in favor of configuration in the Launcher ([#11966](https://github.com/PyTorchLightning/pytorch-lightning/pull/11966))
 
 ### Fixed
 

From 2e3d3926958ff799a16bfd4f4c92c58fa98463b0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 22 Feb 2022 11:46:57 +0100
Subject: [PATCH 4/5] review

---
 CHANGELOG.md | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 06c481edf2f7a..2e660a87063af 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -569,7 +569,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Removed calls to `profile("model_forward")` in favor of profiling `training_step` ([#12032](https://github.com/PyTorchLightning/pytorch-lightning/pull/12032))
 
-- Removed `get_mp_spawn_kwargs` from `DDPSpawnStrategy` and `TPUSpawnStrategy` in favor of configuration in the Launcher ([#11966](https://github.com/PyTorchLightning/pytorch-lightning/pull/11966))
+
+- Removed `get_mp_spawn_kwargs` from `DDPSpawnStrategy` and `TPUSpawnStrategy` in favor of configuration in the `_SpawnLauncher` ([#11966](https://github.com/PyTorchLightning/pytorch-lightning/pull/11966))
 
 ### Fixed
 

From 9ac34a15f6408bed7857869f435b6e7d7e01cd46 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 22 Feb 2022 11:49:36 +0100
Subject: [PATCH 5/5] remove start_method from constructor

---
 pytorch_lightning/strategies/launchers/spawn.py     | 4 +++-
 pytorch_lightning/strategies/launchers/xla_spawn.py | 6 ++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/strategies/launchers/spawn.py b/pytorch_lightning/strategies/launchers/spawn.py
index 594ef3146fcbb..3b393f4a0c1b1 100644
--- a/pytorch_lightning/strategies/launchers/spawn.py
+++ b/pytorch_lightning/strategies/launchers/spawn.py
@@ -47,6 +47,7 @@ class _SpawnLauncher(_Launcher):
 
     def __init__(self, strategy: Strategy) -> None:
         self._strategy = strategy
+        self._start_method = "spawn"
 
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Spawns processes that run the given function in parallel.
@@ -65,12 +66,13 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
         # This needs to be done in the main process here before spawning to ensure each rank will connect
         # through the same port
         os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port)
-        context = mp.get_context("spawn")
+        context = mp.get_context(self._start_method)
         return_queue = context.SimpleQueue()
         mp.spawn(
             self._wrapping_function,
             args=(trainer, function, args, kwargs, return_queue),
             nprocs=self._strategy.num_processes,
+            start_method=self._start_method,
         )
         spawn_output = return_queue.get()
         if trainer is None:
diff --git a/pytorch_lightning/strategies/launchers/xla_spawn.py b/pytorch_lightning/strategies/launchers/xla_spawn.py
index b17cb5a05cdd0..71acfc1011582 100644
--- a/pytorch_lightning/strategies/launchers/xla_spawn.py
+++ b/pytorch_lightning/strategies/launchers/xla_spawn.py
@@ -48,13 +48,11 @@ class _XLASpawnLauncher(_SpawnLauncher):
 
     Args:
         strategy: A reference to the strategy that is used together with this launcher
-        start_method: Start method for `~torch_xla.distributed.xla_multiprocessing.spawn`. Accepted options are
-            ``'spawn'`` or ``'fork'``.
     """
 
-    def __init__(self, strategy: "Strategy", start_method: str = "fork") -> None:
+    def __init__(self, strategy: "Strategy") -> None:
         super().__init__(strategy)
-        self._start_method = start_method
+        self._start_method = "fork"
 
     def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
         """Spawns processes that run the given function in parallel.
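
Taken together, the series moves the multiprocessing configuration that `get_mp_spawn_kwargs` used to supply (the process count and the start method) onto the launchers themselves: `_SpawnLauncher` now owns `self._start_method = "spawn"`, `_XLASpawnLauncher` owns `"fork"`, and each launcher passes these values to its spawn call explicitly. The snippet below is a minimal, self-contained sketch of that launcher-owned configuration using plain `torch.multiprocessing`; it is illustrative only, not PyTorch Lightning code, and the names `ToySpawnLauncher` and `greet` are hypothetical.

# Minimal sketch of launcher-owned spawn configuration (not Lightning code).
from typing import Any, Callable

import torch.multiprocessing as mp


class ToySpawnLauncher:
    def __init__(self, nprocs: int, start_method: str = "spawn") -> None:
        # The launcher, not a strategy, decides how worker processes are created.
        self._nprocs = nprocs
        self._start_method = start_method

    def launch(self, function: Callable, *args: Any) -> Any:
        # A queue created from the matching context carries rank 0's result back.
        context = mp.get_context(self._start_method)
        return_queue = context.SimpleQueue()
        mp.spawn(
            self._wrapping_function,
            args=(function, args, return_queue),
            nprocs=self._nprocs,
            start_method=self._start_method,
        )
        return return_queue.get()

    @staticmethod
    def _wrapping_function(process_idx: int, function: Callable, args: tuple, return_queue) -> None:
        # torch.multiprocessing.spawn passes the process index as the first argument.
        result = function(process_idx, *args)
        if process_idx == 0:
            return_queue.put(result)


def greet(rank: int, name: str) -> str:
    return f"hello from rank {rank}, {name}"


if __name__ == "__main__":
    # With the "spawn" start method everything handed to the workers must be
    # pickleable, which is why `greet` is a module-level function, not a lambda.
    launcher = ToySpawnLauncher(nprocs=2)
    print(launcher.launch(greet, "world"))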