diff --git a/CHANGELOG.md b/CHANGELOG.md
index 888d22a520f75..636ac76305d91 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -201,7 +201,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - LightningLite:
     * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
-
+    * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018))
 
 
 ### Changed
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index eb1acaec4100b..177a58a691bfc 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -15,7 +15,7 @@
 import os
 import re
 from multiprocessing.queues import SimpleQueue
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
@@ -155,38 +155,45 @@ def set_world_ranks(self, process_idx: int = 0) -> None:
         self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         rank_zero_only.rank = self.cluster_environment.global_rank()
 
-    def get_mp_spawn_kwargs(self, trainer: "pl.Trainer") -> dict:
-        return {"args": (trainer, self.mp_queue), "nprocs": self.num_processes}
+    def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
+        return {"nprocs": self.num_processes}
 
     def start_training(self, trainer: "pl.Trainer") -> None:
-        mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+        self.spawn(self.new_process, trainer, self.mp_queue)
         # reset optimizers, since main process is never used for training and thus does not have a valid optim state
         trainer.optimizers = []
 
     def start_evaluating(self, trainer: "pl.Trainer") -> None:
-        mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+        self.spawn(self.new_process, trainer, self.mp_queue)
 
     def start_predicting(self, trainer: "pl.Trainer") -> None:
-        mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+        self.spawn(self.new_process, trainer, self.mp_queue)
 
-    def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
-        self.mp_queue = mp_queue
+    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> None:
+        """Spawn processes that run the given function.
 
-        reset_seed()
+        Args:
+            function: The function to spawn processes from.
+            *args: Optional positional arguments that will be passed to the function in addition to the process index.
+                These arguments must be pickleable.
+            **kwargs: Optional named arguments that will be passed to the function in addition to the process index.
+                These arguments must be pickleable.
+        """
+        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
+        mp.spawn(self._wrapped_function, args=(function, args, kwargs), **self.get_mp_spawn_kwargs())
 
-        self.set_world_ranks(process_idx)
+    def _wrapped_function(self, process_idx: int, function: Callable, args: Any, kwargs: Any) -> None:
+        self._worker_setup(process_idx)
+        function(*args, **kwargs)
 
-        # set warning rank
+    def _worker_setup(self, process_idx: int):
+        reset_seed()
+        self.set_world_ranks(process_idx)
         rank_zero_only.rank = self.global_rank
-
-        # set up server using proc 0's ip address
-        # try to init for 20 times at max in case ports are taken
-        # where to store ip_table
         init_ddp_connection(self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size)
 
-        # TODO: we moved it to the trainer.fit after calling pre_dispatch
-        # ... need to double check that it is the correct place
-        # self.trainer.call_setup_hook(self.model)
+    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
+        self.mp_queue = mp_queue
 
         # move the model to the correct device
         self.model_to_device()
diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py
index 921f89782045b..78b54d029a5f6 100644
--- a/pytorch_lightning/plugins/training_type/sharded_spawn.py
+++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
+from multiprocessing.queues import SimpleQueue
 from typing import Dict, Generator, Optional
 
 import torch
@@ -100,13 +101,13 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None:
     def post_training_step(self):
         pass
 
-    def new_process(self, process_idx, trainer, mp_queue):
+    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
         # Ensure that the scaler points to the correct process group
         # which is re-initialized in a new process
         precision_plugin = trainer.accelerator.precision_plugin
         if isinstance(precision_plugin, ShardedNativeMixedPrecisionPlugin):
             precision_plugin.scaler = ShardedGradScaler()
-        super().new_process(process_idx, trainer, mp_queue)
+        return super().new_process(trainer, mp_queue)
 
     @classmethod
     def register_plugins(cls, plugin_registry: Dict) -> None: