2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -201,7 +201,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- LightningLite:
    * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))

    * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018))

### Changed

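Both changelog entries above describe new plugin hooks. The second, `DDPSpawnPlugin.spawn()`, hands an arbitrary function and its arguments off to freshly spawned processes, so every argument must survive pickling. Below is a minimal, self-contained sketch of that constraint; `assert_spawnable` is a hypothetical helper for illustration, not part of Lightning:

```python
# Hypothetical helper illustrating the pickleability requirement that
# `DDPSpawnPlugin.spawn()` places on its arguments; not Lightning API.
import pickle
from typing import Any


def assert_spawnable(*args: Any, **kwargs: Any) -> None:
    # Arguments to a spawned function are pickled into the child process,
    # so a pickle round-trip catches unsupported values before spawning.
    for value in (*args, *kwargs.values()):
        pickle.dumps(value)


assert_spawnable("a string", 42, config={"devices": [0, 1]})  # all fine
# assert_spawnable(lambda x: x)  # would raise: lambdas cannot be pickled
```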
43 changes: 25 additions & 18 deletions pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -15,7 +15,7 @@
 import os
 import re
 from multiprocessing.queues import SimpleQueue
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
@@ -155,38 +155,45 @@ def set_world_ranks(self, process_idx: int = 0) -> None:
         self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         rank_zero_only.rank = self.cluster_environment.global_rank()
 
-    def get_mp_spawn_kwargs(self, trainer: "pl.Trainer") -> dict:
-        return {"args": (trainer, self.mp_queue), "nprocs": self.num_processes}
+    def get_mp_spawn_kwargs(self, trainer: Optional["pl.Trainer"] = None) -> Dict[str, Any]:
+        return {"nprocs": self.num_processes}
 
     def start_training(self, trainer: "pl.Trainer") -> None:
-        mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+        self.spawn(self.new_process, trainer, self.mp_queue)
         # reset optimizers, since main process is never used for training and thus does not have a valid optim state
         trainer.optimizers = []
 
     def start_evaluating(self, trainer: "pl.Trainer") -> None:
-        mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+        self.spawn(self.new_process, trainer, self.mp_queue)
 
     def start_predicting(self, trainer: "pl.Trainer") -> None:
-        mp.spawn(self.new_process, **self.get_mp_spawn_kwargs(trainer))
+        self.spawn(self.new_process, trainer, self.mp_queue)
 
-    def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
-        self.mp_queue = mp_queue
+    def spawn(self, function: Callable, *args: Any, **kwargs: Any) -> None:
+        """Spawn processes that run the given function.
 
-        reset_seed()
+        Args:
+            function: The function to spawn processes from.
+            *args: Optional positional arguments that will be passed to the function in addition to the process index.
+                These arguments must be pickleable.
+            **kwargs: Optional named arguments that will be passed to the function in addition to the process index.
+                These arguments must be pickleable.
+        """
+        os.environ["MASTER_PORT"] = str(self.cluster_environment.master_port())
+        mp.spawn(self._wrapped_function, args=(function, args, kwargs), **self.get_mp_spawn_kwargs())
 
-        self.set_world_ranks(process_idx)
+    def _wrapped_function(self, process_idx: int, function: Callable, args: Any, kwargs: Any) -> None:
+        self._worker_setup(process_idx)
+        function(*args, **kwargs)
 
-        # set warning rank
+    def _worker_setup(self, process_idx: int):
+        reset_seed()
+        self.set_world_ranks(process_idx)
         rank_zero_only.rank = self.global_rank
 
-        # set up server using proc 0's ip address
-        # try to init for 20 times at max in case ports are taken
-        # where to store ip_table
         init_ddp_connection(self.cluster_environment, self.torch_distributed_backend, self.global_rank, self.world_size)
 
-        # TODO: we moved it to the trainer.fit after calling pre_dispatch
-        # ... need to double check that it is the correct place
-        # self.trainer.call_setup_hook(self.model)
+    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
+        self.mp_queue = mp_queue
 
         # move the model to the correct device
         self.model_to_device()
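The heart of this file's change is the `spawn()` and `_wrapped_function` pair: `torch.multiprocessing.spawn` always passes the process index as the first argument, so the wrapper absorbs it, runs per-process setup, and only then dispatches the user's function. Here is a standalone sketch of the same pattern, assuming only `torch` is installed; `run_spawned` and `_wrapped` are illustrative names, not Lightning API:

```python
# Standalone sketch of the wrapper pattern above; `run_spawned` and
# `_wrapped` are illustrative stand-ins, not part of Lightning.
from typing import Any, Callable

import torch.multiprocessing as mp


def _wrapped(process_idx: int, function: Callable, args: Any, kwargs: Any) -> None:
    # mp.spawn prepends the process index; per-process setup (seeding,
    # rank bookkeeping, DDP init) would happen here, as _worker_setup does.
    function(*args, **kwargs)


def run_spawned(function: Callable, *args: Any, nprocs: int = 2, **kwargs: Any) -> None:
    # args and kwargs cross a process boundary, so they must be pickleable,
    # matching the constraint documented in spawn() above.
    mp.spawn(_wrapped, args=(function, args, kwargs), nprocs=nprocs)


if __name__ == "__main__":
    run_spawned(print, "hello from a worker")
```

Packing `(function, args, kwargs)` into `mp.spawn`'s `args` tuple is what lets the public `spawn()` accept arbitrary signatures, which in turn is why `get_mp_spawn_kwargs()` now only supplies `nprocs`.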
5 changes: 3 additions & 2 deletions pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
+from multiprocessing.queues import SimpleQueue
 from typing import Dict, Generator, Optional
 
 import torch
@@ -100,13 +101,13 @@ def pre_backward(self, closure_loss: torch.Tensor) -> None:
     def post_training_step(self):
         pass
 
-    def new_process(self, process_idx, trainer, mp_queue):
+    def new_process(self, trainer: "pl.Trainer", mp_queue: SimpleQueue) -> None:
         # Ensure that the scaler points to the correct process group
         # which is re-initialized in a new process
         precision_plugin = trainer.accelerator.precision_plugin
         if isinstance(precision_plugin, ShardedNativeMixedPrecisionPlugin):
             precision_plugin.scaler = ShardedGradScaler()
-        super().new_process(process_idx, trainer, mp_queue)
+        return super().new_process(trainer, mp_queue)
 
     @classmethod
     def register_plugins(cls, plugin_registry: Dict) -> None:
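The signature change above drops `process_idx`, which `_wrapped_function` now consumes before `new_process` runs, and the scaler swap exists because objects pickled into a spawned process can still reference the parent's process group. A minimal sketch of that rebuild-in-the-child pattern follows; `ProcessLocalScaler`, `BaseSpawnPlugin`, and `ShardedLikePlugin` are illustrative stand-ins, not Lightning classes:

```python
# Illustrative stand-ins showing why per-process state is rebuilt inside
# new_process(); none of these classes are part of Lightning.
from typing import Any, Optional


class ProcessLocalScaler:
    """Stand-in for ShardedGradScaler: only valid in the process whose
    process group it was created against."""


class BaseSpawnPlugin:
    def new_process(self, trainer: Any, mp_queue: Optional[Any]) -> None:
        print("generic per-process setup")


class ShardedLikePlugin(BaseSpawnPlugin):
    def new_process(self, trainer: Any, mp_queue: Optional[Any]) -> None:
        # State carried across the spawn boundary may point at the parent's
        # process group, so bind a fresh instance here, as the hunk above
        # does with ShardedGradScaler.
        self.scaler = ProcessLocalScaler()
        return super().new_process(trainer, mp_queue)


if __name__ == "__main__":
    ShardedLikePlugin().new_process(trainer=None, mp_queue=None)
```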