diff --git a/pyproject.toml b/pyproject.toml
index c08d2c99bf3f5..013c0a6acc1ce 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,7 +69,6 @@ module = [
     "pytorch_lightning.strategies.fully_sharded",
     "pytorch_lightning.strategies.horovod",
     "pytorch_lightning.strategies.ipu",
-    "pytorch_lightning.strategies.parallel",
     "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.strategies.sharded_spawn",
     "pytorch_lightning.strategies.tpu_spawn",
diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py
index 9a11f61c42e52..d8cd66be2e02f 100644
--- a/src/pytorch_lightning/strategies/fully_sharded_native.py
+++ b/src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -116,6 +116,7 @@ def __init__(  # type: ignore[no-untyped-def]
 
     @property
     def root_device(self) -> torch.device:
+        assert self.parallel_devices is not None
         return self.parallel_devices[self.local_rank]
 
     @property
diff --git a/src/pytorch_lightning/strategies/parallel.py b/src/pytorch_lightning/strategies/parallel.py
index 4fc846870ad59..2517848274e3d 100644
--- a/src/pytorch_lightning/strategies/parallel.py
+++ b/src/pytorch_lightning/strategies/parallel.py
@@ -13,11 +13,10 @@
 # limitations under the License.
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, List, Optional
+from typing import Any, Dict, Generator, List, Optional
 
 import torch
 from torch import Tensor
-from torch.nn.parallel import DistributedDataParallel
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import unwrap_lightning_module
@@ -81,16 +80,18 @@ def is_global_zero(self) -> bool:
         return self.global_rank == 0
 
     @property
-    def parallel_devices(self):
+    def parallel_devices(self) -> Optional[List[torch.device]]:
         return self._parallel_devices
 
     @parallel_devices.setter
-    def parallel_devices(self, parallel_devices):
+    def parallel_devices(self, parallel_devices: Optional[List[torch.device]]) -> None:
         self._parallel_devices = parallel_devices
 
     @property
-    def distributed_sampler_kwargs(self):
-        distributed_sampler_kwargs = dict(num_replicas=len(self.parallel_devices), rank=self.global_rank)
+    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
+        distributed_sampler_kwargs = dict(
+            num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0, rank=self.global_rank
+        )
         return distributed_sampler_kwargs
 
     @property
@@ -104,7 +105,7 @@ def torch_distributed_backend(self) -> str:
             return pg_backend
         return get_default_process_group_backend_for_device(self.root_device)
 
-    def reconciliate_processes(self, trace: str):
+    def reconciliate_processes(self, trace: str) -> None:
         """Function to re-conciliate processes on failure."""
 
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
@@ -118,18 +119,19 @@ def reduce_boolean_decision(self, decision: bool) -> bool:
         return decision
 
     @contextmanager
-    def block_backward_sync(self):
+    def block_backward_sync(self) -> Generator:
         """Blocks ddp sync gradients behaviour on backwards pass.
 
         This is useful for skipping sync when accumulating gradients, reducing communication overhead
         Returns: context manager with sync behaviour off
         """
-        if isinstance(self.model, DistributedDataParallel):
+        if isinstance(self.model, pl.utilities.types.DistributedDataParallel):
             with self.model.no_sync():
                 yield None
         else:
             yield None
 
     def teardown(self) -> None:
+        assert self.cluster_environment is not None
         self.cluster_environment.teardown()
         super().teardown()
diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py
index 0b10b5eebc7b1..e8de1091ec6c1 100644
--- a/src/pytorch_lightning/utilities/types.py
+++ b/src/pytorch_lightning/utilities/types.py
@@ -16,12 +16,14 @@
 - Do not include any `_TYPE` suffix
 - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`)
 """
+from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union
 
 import torch
 from torch import Tensor
+from torch._C._distributed_c10d import ProcessGroup
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from torchmetrics import Metric
@@ -99,6 +101,31 @@ def step(self, metrics: Union[float, int, Tensor], epoch: Optional[int] = None)
         ...
 
 
+# Inferred from `torch.nn.parallel.distributed.pyi`
+# Missing attributes were added to improve typing
+@runtime_checkable
+class DistributedDataParallel(Protocol):
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        device_ids: Optional[List[Union[int, torch.device]]] = None,
+        output_device: Optional[Union[int, torch.device]] = None,
+        dim: int = 0,
+        broadcast_buffers: bool = True,
+        process_group: Optional[ProcessGroup] = None,
+        bucket_cap_mb: int = 25,
+        find_unused_parameters: bool = False,
+        check_reduction: bool = False,
+        gradient_as_bucket_view: bool = False,
+        static_graph: bool = False,
+    ) -> None:
+        ...
+
+    @contextmanager
+    def no_sync(self) -> Generator:
+        ...
+
+
 # todo: improve LRSchedulerType naming/typing
 LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
 LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
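Note (not part of the patch): the `isinstance(self.model, pl.utilities.types.DistributedDataParallel)` check in `block_backward_sync` works because a `runtime_checkable` Protocol is matched structurally at runtime, so `strategies/parallel.py` no longer needs a module-level import of `torch.nn.parallel.DistributedDataParallel`. Below is a minimal, self-contained sketch of that mechanism; `SupportsNoSync` and `FakeDDP` are hypothetical names used only for illustration, not Lightning or PyTorch APIs.

# Minimal sketch (not Lightning code) of how a runtime_checkable Protocol supports
# the isinstance() check used in `block_backward_sync` above. The names
# `SupportsNoSync` and `FakeDDP` are hypothetical and exist only for this example.
from contextlib import contextmanager
from typing import Generator, Protocol, runtime_checkable


@runtime_checkable
class SupportsNoSync(Protocol):
    """Any object exposing a `no_sync()` context manager matches this Protocol."""

    @contextmanager
    def no_sync(self) -> Generator:
        ...


class FakeDDP:
    """Satisfies the Protocol structurally; it never inherits from it."""

    @contextmanager
    def no_sync(self) -> Generator:
        # A real DDP wrapper would disable gradient synchronization here.
        yield


model = FakeDDP()
# isinstance() against a runtime_checkable Protocol only checks that the named
# methods exist; signatures and return types are not validated at runtime.
assert isinstance(model, SupportsNoSync)
with model.no_sync():
    pass  # gradient all-reduce would be skipped inside this block

Because only the presence of the listed methods is verified at runtime, the `DistributedDataParallel` Protocol added to `utilities/types.py` can mirror the real class's constructor signature for static typing without creating an import-time dependency on `torch.nn.parallel`.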