1 change: 0 additions & 1 deletion pyproject.toml
@@ -62,7 +62,6 @@ module = [
     "pytorch_lightning.profilers.pytorch",
     "pytorch_lightning.profilers.simple",
     "pytorch_lightning.strategies.ddp",
-    "pytorch_lightning.strategies.ddp_spawn",
     "pytorch_lightning.strategies.fully_sharded",
     "pytorch_lightning.strategies.ipu",
     "pytorch_lightning.strategies.sharded",
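Note: dropping "pytorch_lightning.strategies.ddp_spawn" from this list opts the module back into mypy checking, which is what drives all the annotations and asserts in the file below. For orientation, a sketch of the kind of overrides table such a module list usually sits in; the exact keys in the repo's pyproject.toml are an assumption here:

    # hypothetical mypy overrides table; modules listed here remain exempt from checking
    [[tool.mypy.overrides]]
    module = [
        "pytorch_lightning.strategies.ddp",
        "pytorch_lightning.strategies.fully_sharded",
    ]
    ignore_errors = true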
73 changes: 49 additions & 24 deletions src/pytorch_lightning/strategies/ddp_spawn.py
@@ -14,7 +14,7 @@
 import logging
 import os
 from datetime import timedelta
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.distributed
@@ -26,12 +26,14 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides import LightningDistributedModule
+from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
 from pytorch_lightning.strategies.parallel import ParallelStrategy
+from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.distributed import (
     _get_process_group_backend_from_env,
@@ -49,7 +51,7 @@
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
 
 log = logging.getLogger(__name__)
 
@@ -75,8 +77,8 @@ def __init__(
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[PrecisionPlugin] = None,
         ddp_comm_state: Optional[object] = None,
-        ddp_comm_hook: Optional[callable] = None,
-        ddp_comm_wrapper: Optional[callable] = None,
+        ddp_comm_hook: Optional[Callable] = None,
+        ddp_comm_wrapper: Optional[Callable] = None,
         process_group_backend: Optional[str] = None,
         timeout: Optional[timedelta] = default_pg_timeout,
         start_method: Literal["spawn", "fork", "forkserver"] = "spawn",
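The callable -> Callable change is a genuine fix, not cosmetics: builtin callable is a predicate function, not a type, and mypy rejects it in annotations (roughly, 'Function "builtins.callable" is not valid as a type'). A minimal illustration, not repo code:

    from typing import Callable, Optional

    # hook: Optional[callable] = None   # rejected by mypy: `callable` is a function, not a type
    hook: Optional[Callable] = None     # accepts any callable; Callable[[int], str] would pin a signature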
@@ -113,32 +115,36 @@ def local_rank(self) -> int:
         return self._local_rank
 
     @property
-    def root_device(self):
+    def root_device(self) -> torch.device:
+        assert self.parallel_devices is not None
         return self.parallel_devices[self.local_rank]
 
     @property
-    def num_processes(self):
+    def num_processes(self) -> int:
         return len(self.parallel_devices) if self.parallel_devices is not None else 0
 
     @property
-    def distributed_sampler_kwargs(self):
+    def distributed_sampler_kwargs(self) -> Dict[str, int]:
         distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
         return distributed_sampler_kwargs
 
     @property
-    def _is_single_process_single_device(self):
+    def _is_single_process_single_device(self) -> bool:
         return True
 
     @property
     def process_group_backend(self) -> Optional[str]:
         return self._process_group_backend
 
-    def _configure_launcher(self):
+    def _configure_launcher(self) -> None:
         self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method)
 
     def setup(self, trainer: "pl.Trainer") -> None:
+
+        assert self.cluster_environment is not None
         os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
 
+        assert self.accelerator is not None
         self.accelerator.setup(trainer)
 
         # move the model to the correct device
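Most additions in this hunk follow one pattern: assert x is not None narrows an Optional attribute so the next line type-checks, and also fails fast if the strategy is used before it is fully set up. A standalone sketch of the narrowing; names are illustrative, not repo code:

    from typing import List, Optional

    import torch

    def first_device(parallel_devices: Optional[List[torch.device]]) -> torch.device:
        # Without the assert, mypy complains that Optional[List[...]] is not indexable.
        assert parallel_devices is not None  # narrows Optional[List[...]] to List[...]
        return parallel_devices[0]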
@@ -148,6 +154,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
         trainer_fn = trainer.state.fn
         if trainer_fn == TrainerFn.FITTING:
             if self._layer_sync:
+                assert self.model is not None
                 self.model = self._layer_sync.apply(self.model)
 
         self.setup_precision_plugin()
@@ -167,11 +174,12 @@ def set_world_ranks(self, process_idx: int = 0) -> None:
         self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         rank_zero_only.rank = self.cluster_environment.global_rank()
 
-    def _worker_setup(self, process_idx: int):
+    def _worker_setup(self, process_idx: int) -> None:
         reset_seed()
         self.set_world_ranks(process_idx)
         rank_zero_only.rank = self.global_rank
         self._process_group_backend = self._get_process_group_backend()
+        assert self.cluster_environment is not None
         init_dist_connection(
             self.cluster_environment,
             self._process_group_backend,
@@ -187,7 +195,7 @@ def _get_process_group_backend(self) -> str:
             or get_default_process_group_backend_for_device(self.root_device)
         )
 
-    def pre_configure_ddp(self):
+    def pre_configure_ddp(self) -> None:
         # if unset, default `find_unused_parameters` `True`
         # Many models require setting this parameter to True, as there are corner cases
         # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True.
@@ -198,6 +206,7 @@ def _register_ddp_hooks(self) -> None:
         # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
         # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
         if self.root_device.type == "cuda" and self._is_single_process_single_device:
+            assert isinstance(self.model, DistributedDataParallel)
             register_ddp_comm_hook(
                 model=self.model,
                 ddp_comm_state=self._ddp_comm_state,
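The isinstance assert is the class-based counterpart of the None-narrowing above: after it, mypy treats self.model as DistributedDataParallel, so DDP-only attributes type-check. A hedged sketch, not repo code:

    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel

    def uses_static_graph(model: nn.Module) -> bool:
        # Here mypy only knows `model` is an nn.Module; narrow before touching DDP-only attributes.
        assert isinstance(model, DistributedDataParallel)
        return bool(model.static_graph)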
@@ -207,47 +216,54 @@ def _register_ddp_hooks(self) -> None:
 
     def configure_ddp(self) -> None:
         self.pre_configure_ddp()
+        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         self.model = self._setup_model(LightningDistributedModule(self.model))
         self._register_ddp_hooks()
 
         # set up optimizers after the wrapped module has been moved to the device
+        assert self.lightning_module is not None
         self.setup_optimizers(self.lightning_module.trainer)
         optimizers_to_device(self.optimizers, self.root_device)
 
-    def determine_ddp_device_ids(self):
+    def determine_ddp_device_ids(self) -> Optional[List[int]]:
         if self.root_device.type == "cpu":
             return None
         return [self.root_device.index]
 
-    def barrier(self, *args, **kwargs) -> None:
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
         if not distributed_available():
             return
         if torch.distributed.get_backend() == "nccl":
             torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
         else:
             torch.distributed.barrier()
 
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         if not distributed_available():
             return obj
         obj = [obj]
         if self.global_rank != src:
-            obj = [None]
+            obj = [None]  # type: ignore[list-item]
         torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
         return obj[0]
 
-    def model_to_device(self):
+    def model_to_device(self) -> None:
         if self.root_device.type == "cuda":
             # set the device on the spawned subprocesses
             torch.cuda.set_device(self.root_device)
+        assert self.model is not None
         self.model.to(self.root_device)
 
     def pre_backward(self, closure_loss: Tensor) -> None:
         """Run before precision plugin executes backward."""
+        assert self.lightning_module is not None
         if not self.lightning_module.automatic_optimization:
+            assert isinstance(self.model, DistributedDataParallel)
             prepare_for_backward(self.model, closure_loss)
 
-    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> Tensor:
+    def reduce(
+        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
+    ) -> Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor.
 
         Args:
@@ -263,30 +279,38 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp,
         tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
         return tensor
 
-    def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        assert self.model is not None
         with self.precision_plugin.train_step_context():
             return self.model(*args, **kwargs)
 
-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.val_step_context():
+            assert self.lightning_module is not None
+            assert self.model is not None
             if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
                 # used when calling `trainer.fit`
                 return self.model(*args, **kwargs)
             else:
                 # used when calling `trainer.validate`
+                assert isinstance(self.model, ValidationStep)
                 return self.model.validation_step(*args, **kwargs)
 
-    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.test_step_context():
+            assert isinstance(self.model, TestStep)
             return self.model.test_step(*args, **kwargs)
 
-    def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         with self.precision_plugin.predict_step_context():
+            assert isinstance(self.model, PredictStep)
             return self.model.predict_step(*args, **kwargs)
 
-    def post_training_step(self):
+    def post_training_step(self) -> None:
+        assert self.lightning_module is not None
         if not self.lightning_module.automatic_optimization:
-            self.model.require_backward_grad_sync = True
+            assert self.model is not None
+            self.model.require_backward_grad_sync = True  # type: ignore[assignment]
 
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
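The asserts against ValidationStep, TestStep, and PredictStep only work if those classes (from pytorch_lightning.utilities.types) are runtime-checkable protocols; isinstance() against a plain typing.Protocol raises TypeError. A minimal sketch with a hypothetical stand-in protocol:

    from typing import Any, Protocol, runtime_checkable

    @runtime_checkable
    class SupportsValidationStep(Protocol):  # hypothetical stand-in for ValidationStep
        def validation_step(self, *args: Any, **kwargs: Any) -> Any:
            ...

    class Model:
        def validation_step(self) -> str:
            return "ok"

    model: object = Model()
    # The runtime check only verifies the method exists, not its signature;
    # for mypy, the assert narrows `model` to SupportsValidationStep.
    assert isinstance(model, SupportsValidationStep)
    print(model.validation_step())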
@@ -315,7 +339,7 @@ def teardown(self) -> None:
         if (
             _TORCH_GREATER_EQUAL_1_11
             and not self.model.static_graph
-            and self.model._get_ddp_logging_data().get("can_set_static_graph")
+            and self.model._get_ddp_logging_data().get("can_set_static_graph")  # type: ignore[operator]
         ):
             rank_zero_info(
                 "Your model can run with static graph optimizations. For future training runs, we suggest you"
@@ -332,5 +356,6 @@ def teardown(self) -> None:
             and pl_module._trainer.state.fn == TrainerFn.FITTING
             and self._layer_sync
         ):
+            assert self.model is not None
             self.model = self._layer_sync.revert(self.model)
         super().teardown()
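A note on the broadcast change in this diff: TBroadcast, imported from strategies/strategy.py, is presumably a plain TypeVar, so callers now get back the type they pass in rather than object. A rough sketch under that assumption:

    from typing import List, TypeVar

    TBroadcast = TypeVar("TBroadcast")  # assumed to mirror the definition in strategies/strategy.py

    def broadcast(obj: TBroadcast, src: int = 0) -> TBroadcast:
        # Generic in, generic out: broadcast("abc") is typed str, broadcast(3) is typed int.
        # The old `obj: object -> object` signature collapsed both to object.
        buffer: List[TBroadcast] = [obj]
        return buffer[0]

    epoch: int = broadcast(5)  # accepted under the generic signature; an error under the old one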