diff --git a/pyproject.toml b/pyproject.toml
index 32cc6e8452d25..05eba62c50402 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -62,7 +62,6 @@ module = [
     "pytorch_lightning.profilers.pytorch",
     "pytorch_lightning.profilers.simple",
     "pytorch_lightning.strategies.ddp",
-    "pytorch_lightning.strategies.ddp_spawn",
     "pytorch_lightning.strategies.fully_sharded",
     "pytorch_lightning.strategies.ipu",
     "pytorch_lightning.strategies.sharded",
diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py
index fdb0a7d851169..6a3460febbf07 100644
--- a/src/pytorch_lightning/strategies/ddp_spawn.py
+++ b/src/pytorch_lightning/strategies/ddp_spawn.py
@@ -14,7 +14,7 @@
 import logging
 import os
 from datetime import timedelta
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.distributed
@@ -26,12 +26,14 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides import LightningDistributedModule
+from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase
 from pytorch_lightning.overrides.distributed import prepare_for_backward
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher
 from pytorch_lightning.strategies.parallel import ParallelStrategy
+from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.distributed import (
     _get_process_group_backend_from_env,
@@ -49,7 +51,7 @@
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
 from pytorch_lightning.utilities.seed import reset_seed
-from pytorch_lightning.utilities.types import STEP_OUTPUT
+from pytorch_lightning.utilities.types import PredictStep, STEP_OUTPUT, TestStep, ValidationStep
 
 log = logging.getLogger(__name__)
 
@@ -75,8 +77,8 @@ def __init__(
         checkpoint_io: Optional[CheckpointIO] = None,
         precision_plugin: Optional[PrecisionPlugin] = None,
         ddp_comm_state: Optional[object] = None,
-        ddp_comm_hook: Optional[callable] = None,
-        ddp_comm_wrapper: Optional[callable] = None,
+        ddp_comm_hook: Optional[Callable] = None,
+        ddp_comm_wrapper: Optional[Callable] = None,
         process_group_backend: Optional[str] = None,
         timeout: Optional[timedelta] = default_pg_timeout,
         start_method: Literal["spawn", "fork", "forkserver"] = "spawn",
@@ -113,32 +115,36 @@ def local_rank(self) -> int:
         return self._local_rank
 
     @property
-    def root_device(self):
+    def root_device(self) -> torch.device:
+        assert self.parallel_devices is not None
         return self.parallel_devices[self.local_rank]
 
     @property
-    def num_processes(self):
+    def num_processes(self) -> int:
         return len(self.parallel_devices) if self.parallel_devices is not None else 0
 
     @property
-    def distributed_sampler_kwargs(self):
+    def distributed_sampler_kwargs(self) -> Dict[str, int]:
         distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
         return distributed_sampler_kwargs
 
     @property
-    def _is_single_process_single_device(self):
+    def _is_single_process_single_device(self) -> bool:
         return True
 
     @property
     def process_group_backend(self) -> Optional[str]:
         return self._process_group_backend
 
-    def _configure_launcher(self):
+    def _configure_launcher(self) -> None:
         self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method)
 
     def setup(self, trainer: "pl.Trainer") -> None:
+        assert self.cluster_environment is not None
         os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port)
+        assert self.accelerator is not None
         self.accelerator.setup(trainer)
 
         # move the model to the correct device
@@ -148,6 +154,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
         trainer_fn = trainer.state.fn
         if trainer_fn == TrainerFn.FITTING:
             if self._layer_sync:
+                assert self.model is not None
                 self.model = self._layer_sync.apply(self.model)
 
         self.setup_precision_plugin()
@@ -167,11 +174,12 @@ def set_world_ranks(self, process_idx: int = 0) -> None:
         self.cluster_environment.set_world_size(self.num_nodes * self.num_processes)
         rank_zero_only.rank = self.cluster_environment.global_rank()
 
-    def _worker_setup(self, process_idx: int):
+    def _worker_setup(self, process_idx: int) -> None:
         reset_seed()
         self.set_world_ranks(process_idx)
         rank_zero_only.rank = self.global_rank
         self._process_group_backend = self._get_process_group_backend()
+        assert self.cluster_environment is not None
         init_dist_connection(
             self.cluster_environment,
             self._process_group_backend,
@@ -187,7 +195,7 @@ def _get_process_group_backend(self) -> str:
             or get_default_process_group_backend_for_device(self.root_device)
         )
 
-    def pre_configure_ddp(self):
+    def pre_configure_ddp(self) -> None:
         # if unset, default `find_unused_parameters` `True`
         # Many models require setting this parameter to True, as there are corner cases
         # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True.
@@ -198,6 +206,7 @@ def _register_ddp_hooks(self) -> None:
         # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode
         # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084
         if self.root_device.type == "cuda" and self._is_single_process_single_device:
+            assert isinstance(self.model, DistributedDataParallel)
             register_ddp_comm_hook(
                 model=self.model,
                 ddp_comm_state=self._ddp_comm_state,
@@ -207,19 +216,21 @@ def _register_ddp_hooks(self) -> None:
 
     def configure_ddp(self) -> None:
         self.pre_configure_ddp()
+        assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         self.model = self._setup_model(LightningDistributedModule(self.model))
         self._register_ddp_hooks()
 
         # set up optimizers after the wrapped module has been moved to the device
+        assert self.lightning_module is not None
         self.setup_optimizers(self.lightning_module.trainer)
         optimizers_to_device(self.optimizers, self.root_device)
 
-    def determine_ddp_device_ids(self):
+    def determine_ddp_device_ids(self) -> Optional[List[int]]:
         if self.root_device.type == "cpu":
             return None
         return [self.root_device.index]
 
-    def barrier(self, *args, **kwargs) -> None:
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
         if not distributed_available():
             return
         if torch.distributed.get_backend() == "nccl":
@@ -227,27 +238,32 @@ def barrier(self, *args, **kwargs) -> None:
         else:
             torch.distributed.barrier()
 
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         if not distributed_available():
             return obj
         obj = [obj]
         if self.global_rank != src:
-            obj = [None]
+            obj = [None]  # type: ignore[list-item]
         torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
         return obj[0]
 
-    def model_to_device(self):
+    def model_to_device(self) -> None:
         if self.root_device.type == "cuda":
             # set the device on the spawned subprocesses
             torch.cuda.set_device(self.root_device)
+        assert self.model is not None
         self.model.to(self.root_device)
 
     def pre_backward(self, closure_loss: Tensor) -> None:
         """Run before precision plugin executes backward."""
+        assert self.lightning_module is not None
         if not self.lightning_module.automatic_optimization:
+            assert isinstance(self.model, DistributedDataParallel)
             prepare_for_backward(self.model, closure_loss)
 
-    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> Tensor:
+    def reduce(
+        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
+    ) -> Tensor:
         """Reduces a tensor from several distributed processes to one aggregated tensor.
 
         Args:
@@ -263,30 +279,38 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp,
             tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
         return tensor
 
-    def training_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
+        assert self.model is not None
         with self.precision_plugin.train_step_context():
             return self.model(*args, **kwargs)
 
-    def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.val_step_context():
+            assert self.lightning_module is not None
+            assert self.model is not None
             if self.lightning_module.trainer.state.fn == TrainerFn.FITTING:
                 # used when calling `trainer.fit`
                 return self.model(*args, **kwargs)
             else:
                 # used when calling `trainer.validate`
+                assert isinstance(self.model, ValidationStep)
                 return self.model.validation_step(*args, **kwargs)
 
-    def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]:
+    def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
         with self.precision_plugin.test_step_context():
+            assert isinstance(self.model, TestStep)
             return self.model.test_step(*args, **kwargs)
 
-    def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
+    def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         with self.precision_plugin.predict_step_context():
+            assert isinstance(self.model, PredictStep)
            return self.model.predict_step(*args, **kwargs)
 
-    def post_training_step(self):
+    def post_training_step(self) -> None:
+        assert self.lightning_module is not None
         if not self.lightning_module.automatic_optimization:
-            self.model.require_backward_grad_sync = True
+            assert self.model is not None
+            self.model.require_backward_grad_sync = True  # type: ignore[assignment]
 
     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
@@ -315,7 +339,7 @@ def teardown(self) -> None:
             if (
                 _TORCH_GREATER_EQUAL_1_11
                 and not self.model.static_graph
-                and self.model._get_ddp_logging_data().get("can_set_static_graph")
+                and self.model._get_ddp_logging_data().get("can_set_static_graph")  # type: ignore[operator]
             ):
                 rank_zero_info(
                     "Your model can run with static graph optimizations. For future training runs, we suggest you"
@@ -332,5 +356,6 @@ def teardown(self) -> None:
             and pl_module._trainer.state.fn == TrainerFn.FITTING
             and self._layer_sync
         ):
+            assert self.model is not None
             self.model = self._layer_sync.revert(self.model)
         super().teardown()
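
Usage note (not part of the diff): the `ddp_comm_hook` and `ddp_comm_wrapper` parameters are now annotated as `typing.Callable` instead of the builtin `callable`, which is a function, not a type, so `Optional[callable]` was never a valid annotation. A minimal sketch of how a hook flows through this strategy into `_register_ddp_hooks`, assuming a PyTorch Lightning version where `DDPSpawnStrategy` is importable from `pytorch_lightning.strategies`:

```python
# Minimal sketch, not part of this PR: passing a DDP communication hook
# through DDPSpawnStrategy. fp16_compress_hook is a built-in PyTorch hook
# that casts gradients to float16 before the all-reduce.
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPSpawnStrategy

trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy=DDPSpawnStrategy(ddp_comm_hook=fp16_compress_hook),
)
```

Any callable with the DDP comm-hook signature `(state, bucket) -> torch.futures.Future` type-checks here, which is why the loose `Optional[Callable]` is used rather than a more specific alias.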