From 31c70733b5592b8b7ba4cba673f5c86a387ce343 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Oct 2021 12:18:48 +0200 Subject: [PATCH 1/6] update spawn logic --- .../plugins/training_type/ddp_spawn.py | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index eb1acaec4100b..07fa4e455df26 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -14,8 +14,9 @@ import logging import os import re +from functools import partial from multiprocessing.queues import SimpleQueue -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import numpy as np import torch @@ -159,34 +160,36 @@ def get_mp_spawn_kwargs(self, trainer: "pl.Trainer") -> dict: return {"args": (trainer, self.mp_queue), "nprocs": self.num_processes} def start_training(self, trainer: "pl.Trainer") -> None: + # TODO: refactor: call self.spawn() here mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer)) # reset optimizers, since main process is never used for training and thus does not have a valid optim state trainer.optimizers = [] def start_evaluating(self, trainer: "pl.Trainer") -> None: + # TODO: refactor: call self.spawn() here mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer)) def start_predicting(self, trainer: "pl.Trainer") -> None: + # TODO: refactor: call self.spawn() here mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer)) - def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: - self.mp_queue = mp_queue + def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> None: + os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) + mp.spawn(self._wrapped_function, args=(function, args, kwargs), nprocs=self.num_processes) - reset_seed() + def _wrapped_function(self, process_idx: int, function: Callable, args: Any, kwargs: Any) -> None: + self._worker_setup(process_idx) + function(*args, **kwargs) + def _worker_setup(self, process_idx: int): + reset_seed() self.set_world_ranks(process_idx) - - # set warning rank rank_zero_only.rank = self.global_rank - - # set up server using proc 0's ip address - # try to init for 20 times at max in case ports are taken - # where to store ip_table init_ddp_connection(self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size) - # TODO: we moved it to the trainer.fit after calling pre_dispatch - # ... need to double check that it is the correct place - # self.trainer.call_setup_hook(self.model) + def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: + self._worker_setup(process_idx) + self.mp_queue = mp_queue # move the model to the correct device self.model_to_device() From 1fbe084cc90eded1c1b173dba7dc834dc26dd990 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Oct 2021 12:36:20 +0200 Subject: [PATCH 2/6] update changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 888d22a520f75..636ac76305d91 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -201,7 +201,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- LightningLite: * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988)) - + * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018)) ### Changed From 0aaabe1dcdbec829d78d051ce444c25e9d1ffb08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Oct 2021 14:28:37 +0200 Subject: [PATCH 3/6] resolve todo's --- .../plugins/training_type/ddp_spawn.py | 25 ++++++++++++------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 07fa4e455df26..60a3025f6b8da 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -156,26 +156,33 @@ def set_world_ranks(self, process_idx: int = 0) -> None: self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) rank_zero_only.rank = self.cluster_environment.global_rank() - def get_mp_spawn_kwargs(self, trainer: "pl.Trainer") -> dict: - return {"args": (trainer, self.mp_queue), "nprocs": self.num_processes} + def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]: + return {"nprocs": self.num_processes} def start_training(self, trainer: "pl.Trainer") -> None: - # TODO: refactor: call self.spawn() here - mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer)) + self.spawn(self.new_process, trainer, self.mp_queue) # reset optimizers, since main process is never used for training and thus does not have a valid optim state trainer.optimizers = [] def start_evaluating(self, trainer: "pl.Trainer") -> None: - # TODO: refactor: call self.spawn() here - mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer)) + self.spawn(self.new_process, trainer, self.mp_queue) def start_predicting(self, trainer: "pl.Trainer") -> None: - # TODO: refactor: call self.spawn() here - mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer)) + self.spawn(self.new_process, trainer, self.mp_queue) def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> None: + """Spawn processes that run the given function. + + Args: + function: The function to spawn processes from. It must at least accept one positional argument for the + process index. + *args: Optional positional arguments that will be passed to the function in addition to the process index. + These arguments must be pickleable. + **kwargs: Optional named arguments that will be passed to the function in addition to the process index. + These arguments must be pickleable. 
+ """ os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port()) - mp.spawn(self._wrapped_function, args=(function, args, kwargs), nprocs=self.num_processes) + mp.spawn(self._wrapped_function, args=(function, args, kwargs), **self.get_mp_spawn_kwargs()) def _wrapped_function(self, process_idx: int, function: Callable, args: Any, kwargs: Any) -> None: self._worker_setup(process_idx) From fa47fb8a9fa06cea03fd9163755671bf0ded774b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Oct 2021 14:34:37 +0200 Subject: [PATCH 4/6] remove unused import --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 60a3025f6b8da..77ff9cc44b2b2 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -14,7 +14,6 @@ import logging import os import re -from functools import partial from multiprocessing.queues import SimpleQueue from typing import Any, Callable, Dict, List, Optional, Union From c9a701f3c37d65177bbf2310cf895149dbef1aff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Oct 2021 15:07:54 +0200 Subject: [PATCH 5/6] fix worker setup --- pytorch_lightning/plugins/training_type/ddp_spawn.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py index 77ff9cc44b2b2..177a58a691bfc 100644 --- a/pytorch_lightning/plugins/training_type/ddp_spawn.py +++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py @@ -173,8 +173,7 @@ def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> None: """Spawn processes that run the given function. Args: - function: The function to spawn processes from. It must at least accept one positional argument for the - process index. + function: The function to spawn processes from. *args: Optional positional arguments that will be passed to the function in addition to the process index. These arguments must be pickleable. **kwargs: Optional named arguments that will be passed to the function in addition to the process index. @@ -193,8 +192,7 @@ def _worker_setup(self, process_idx: int): rank_zero_only.rank = self.global_rank init_ddp_connection(self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size) - def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: - self._worker_setup(process_idx) + def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: self.mp_queue = mp_queue # move the model to the correct device From a754aac4fab8a00c540f88b0083b5bca55e5841d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Tue, 19 Oct 2021 15:31:57 +0200 Subject: [PATCH 6/6] update signature in sharded plugin --- pytorch_lightning/plugins/training_type/sharded_spawn.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py index 921f89782045b..78b54d029a5f6 100644 --- a/pytorch_lightning/plugins/training_type/sharded_spawn.py +++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from contextlib import contextmanager +from multiprocessing.queues import SimpleQueue from typing import Dict, Generator, Optional import torch @@ -100,13 +101,13 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None: def post_training_step(self): pass - def new_process(self, process_idx, trainer, mp_queue): + def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None: # Ensure that the scaler points to the correct process group # which is re-initialized in a new process precision_plugin = trainer.accelerator.precision_plugin if isinstance(precision_plugin, ShardedNativeMixedPrecisionPlugin): precision_plugin.scaler = ShardedGradScaler() - super().new_process(process_idx, trainer, mp_queue) + return super().new_process(trainer, mp_queue) @classmethod def register_plugins(cls, plugin_registry: Dict) -> None:
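
Taken together, the series replaces the direct mp.spawn calls in start_training/start_evaluating/start_predicting with a generic DDPSpawnPlugin.spawn(function, *args, **kwargs) entry point, moves seed/rank/DDP-connection setup into _worker_setup, and drops the process_idx argument from new_process. Below is a minimal usage sketch of the new spawn() API; it is not part of the patches. The plugin construction (parallel_devices, cluster_environment arguments, LightningEnvironment) and the greet function are assumptions made for illustration and may differ from the exact plugin API in your version.

    import torch

    from pytorch_lightning.plugins import DDPSpawnPlugin
    from pytorch_lightning.plugins.environments import LightningEnvironment


    def greet(message: str, value: int = 0) -> None:
        # Runs once in every spawned process. By the time this body executes,
        # _wrapped_function has already called _worker_setup(process_idx), so the
        # ranks are set and the torch.distributed process group is initialized.
        rank = torch.distributed.get_rank()
        print(f"[rank {rank}] {message}, value={value}")


    if __name__ == "__main__":
        # Hypothetical setup: these constructor arguments are an assumption and are
        # not shown in the patches above.
        plugin = DDPSpawnPlugin(
            parallel_devices=[torch.device("cpu"), torch.device("cpu")],
            cluster_environment=LightningEnvironment(),
        )
        # Positional and keyword arguments are forwarded to `greet` in each child
        # process and must be pickleable, as noted in the docstring added in PATCH 3/6.
        plugin.spawn(greet, "hello from spawn()", value=42)

The same call shape is what start_training/start_evaluating/start_predicting now use internally: self.spawn(self.new_process, trainer, self.mp_queue).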