From 62417ae557c2fc0a5488f35fdae8c98c48fa10f3 Mon Sep 17 00:00:00 2001 From: Cyprien-Ricque Date: Wed, 6 Jul 2022 17:35:08 +0200 Subject: [PATCH 1/4] fix mypy typing in parallel.py --- pyproject.toml | 1 - .../strategies/fully_sharded_native.py | 1 + src/pytorch_lightning/strategies/parallel.py | 22 ++++++++------ src/pytorch_lightning/utilities/types.py | 29 ++++++++++++++++++- 4 files changed, 42 insertions(+), 11 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c08d2c99bf3f5..013c0a6acc1ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,6 @@ module = [ "pytorch_lightning.strategies.fully_sharded", "pytorch_lightning.strategies.horovod", "pytorch_lightning.strategies.ipu", - "pytorch_lightning.strategies.parallel", "pytorch_lightning.strategies.sharded", "pytorch_lightning.strategies.sharded_spawn", "pytorch_lightning.strategies.tpu_spawn", diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py index 9a11f61c42e52..d8cd66be2e02f 100644 --- a/src/pytorch_lightning/strategies/fully_sharded_native.py +++ b/src/pytorch_lightning/strategies/fully_sharded_native.py @@ -116,6 +116,7 @@ def __init__( # type: ignore[no-untyped-def] @property def root_device(self) -> torch.device: + assert self.parallel_devices is not None return self.parallel_devices[self.local_rank] @property diff --git a/src/pytorch_lightning/strategies/parallel.py b/src/pytorch_lightning/strategies/parallel.py index 4fc846870ad59..bc85d80820453 100644 --- a/src/pytorch_lightning/strategies/parallel.py +++ b/src/pytorch_lightning/strategies/parallel.py @@ -13,11 +13,10 @@ # limitations under the License. from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any, List, Optional +from typing import Any, List, Optional, Dict, Generator import torch from torch import Tensor -from torch.nn.parallel import DistributedDataParallel import pytorch_lightning as pl from pytorch_lightning.overrides.base import unwrap_lightning_module @@ -81,16 +80,19 @@ def is_global_zero(self) -> bool: return self.global_rank == 0 @property - def parallel_devices(self): + def parallel_devices(self) -> Optional[List[torch.device]]: return self._parallel_devices @parallel_devices.setter - def parallel_devices(self, parallel_devices): + def parallel_devices(self, parallel_devices: Optional[List[torch.device]]) -> None: self._parallel_devices = parallel_devices @property - def distributed_sampler_kwargs(self): - distributed_sampler_kwargs = dict(num_replicas=len(self.parallel_devices), rank=self.global_rank) + def distributed_sampler_kwargs(self) -> Dict[str, int]: + distributed_sampler_kwargs = dict( + num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0, + rank=self.global_rank + ) return distributed_sampler_kwargs @property @@ -104,7 +106,7 @@ def torch_distributed_backend(self) -> str: return pg_backend return get_default_process_group_backend_for_device(self.root_device) - def reconciliate_processes(self, trace: str): + def reconciliate_processes(self, trace: str) -> None: """Function to re-conciliate processes on failure.""" def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor: @@ -118,18 +120,20 @@ def reduce_boolean_decision(self, decision: bool) -> bool: return decision @contextmanager - def block_backward_sync(self): + def block_backward_sync(self) -> Generator: """Blocks ddp sync gradients behaviour on backwards pass. 
This is useful for skipping sync when accumulating gradients, reducing communication overhead Returns: context manager with sync behaviour off """ - if isinstance(self.model, DistributedDataParallel): + if isinstance(self.model, torch.nn.parallel.DistributedDataParallel) and \ + isinstance(self.model, pl.utilities.types.DistributedDataParallel): with self.model.no_sync(): yield None else: yield None def teardown(self) -> None: + assert self.cluster_environment is not None self.cluster_environment.teardown() super().teardown() diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index 0b10b5eebc7b1..c3f548e6d6baf 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -16,12 +16,14 @@ - Do not include any `_TYPE` suffix - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`) """ +from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union, Generator import torch from torch import Tensor +from torch._C._distributed_c10d import ProcessGroup from torch.optim import Optimizer from torch.utils.data import DataLoader from torchmetrics import Metric @@ -99,6 +101,31 @@ def step(self, metrics: Union[float, int, Tensor], epoch: Optional[int] = None) ... +# Inferred from `torch.nn.parallel.distributed.pyi` +# Missing attributes were added to improve typing +@runtime_checkable +class DistributedDataParallel(Protocol): + def __init__( + self, + module: torch.nn.Module, + device_ids: Optional[List[Union[int, torch.device]]] = None, + output_device: Optional[Union[int, torch.device]] = None, + dim: int = 0, + broadcast_buffers: bool = True, + process_group: Optional[ProcessGroup] = None, + bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, + ) -> None: + ... + + @contextmanager + def no_sync(self) -> Generator: + ... + + # todo: improve LRSchedulerType naming/typing LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau) LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau] From 84b14d4fb7bdc2e0038d1a673a60ca43018ef29e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 6 Jul 2022 15:42:15 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/strategies/parallel.py | 10 ++++---- src/pytorch_lightning/utilities/types.py | 26 ++++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/pytorch_lightning/strategies/parallel.py b/src/pytorch_lightning/strategies/parallel.py index bc85d80820453..52ae0e99b0ed2 100644 --- a/src/pytorch_lightning/strategies/parallel.py +++ b/src/pytorch_lightning/strategies/parallel.py @@ -13,7 +13,7 @@ # limitations under the License. 
from abc import ABC, abstractmethod from contextlib import contextmanager -from typing import Any, List, Optional, Dict, Generator +from typing import Any, Dict, Generator, List, Optional import torch from torch import Tensor @@ -90,8 +90,7 @@ def parallel_devices(self, parallel_devices: Optional[List[torch.device]]) -> No @property def distributed_sampler_kwargs(self) -> Dict[str, int]: distributed_sampler_kwargs = dict( - num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0, - rank=self.global_rank + num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0, rank=self.global_rank ) return distributed_sampler_kwargs @@ -126,8 +125,9 @@ def block_backward_sync(self) -> Generator: This is useful for skipping sync when accumulating gradients, reducing communication overhead Returns: context manager with sync behaviour off """ - if isinstance(self.model, torch.nn.parallel.DistributedDataParallel) and \ - isinstance(self.model, pl.utilities.types.DistributedDataParallel): + if isinstance(self.model, torch.nn.parallel.DistributedDataParallel) and isinstance( + self.model, pl.utilities.types.DistributedDataParallel + ): with self.model.no_sync(): yield None else: diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index c3f548e6d6baf..e8de1091ec6c1 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -19,7 +19,7 @@ from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path -from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union, Generator +from typing import Any, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union import torch from torch import Tensor @@ -106,18 +106,18 @@ def step(self, metrics: Union[float, int, Tensor], epoch: Optional[int] = None) @runtime_checkable class DistributedDataParallel(Protocol): def __init__( - self, - module: torch.nn.Module, - device_ids: Optional[List[Union[int, torch.device]]] = None, - output_device: Optional[Union[int, torch.device]] = None, - dim: int = 0, - broadcast_buffers: bool = True, - process_group: Optional[ProcessGroup] = None, - bucket_cap_mb: int = 25, - find_unused_parameters: bool = False, - check_reduction: bool = False, - gradient_as_bucket_view: bool = False, - static_graph: bool = False, + self, + module: torch.nn.Module, + device_ids: Optional[List[Union[int, torch.device]]] = None, + output_device: Optional[Union[int, torch.device]] = None, + dim: int = 0, + broadcast_buffers: bool = True, + process_group: Optional[ProcessGroup] = None, + bucket_cap_mb: int = 25, + find_unused_parameters: bool = False, + check_reduction: bool = False, + gradient_as_bucket_view: bool = False, + static_graph: bool = False, ) -> None: ... 
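Note: the `DistributedDataParallel` class added to `utilities/types.py` above is a `runtime_checkable` `Protocol`, so an `isinstance()` check against it only verifies that the object exposes the declared members (here, `no_sync`); signatures are not validated at runtime. This is also why the final patch below can rely on the Protocol check alone in `block_backward_sync`. A minimal, self-contained sketch of the pattern — the names `_SupportsNoSync` and `_ToyDDP` are illustrative only and do not appear in the patch (imports assume Python 3.8+, where `Protocol` lives in `typing`):

    from contextlib import contextmanager
    from typing import Generator, Protocol, runtime_checkable


    @runtime_checkable
    class _SupportsNoSync(Protocol):
        @contextmanager
        def no_sync(self) -> Generator:
            ...


    class _ToyDDP:
        # Stands in for torch.nn.parallel.DistributedDataParallel in this sketch.
        @contextmanager
        def no_sync(self) -> Generator:
            # A real DDP wrapper would disable gradient synchronization here.
            yield


    model = _ToyDDP()
    # Passes because `no_sync` is present; the Protocol acts as a structural type.
    if isinstance(model, _SupportsNoSync):
        with model.no_sync():
            pass
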
From e99ab099a9f5e7fa5701cf554252573a6c114713 Mon Sep 17 00:00:00 2001 From: Cyprien Ricque <48893621+Cyprien-Ricque@users.noreply.github.com> Date: Thu, 7 Jul 2022 16:52:31 +0200 Subject: [PATCH 3/4] distributed_sampler_kwargs property can return "Any" type in dict values MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- src/pytorch_lightning/strategies/parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/parallel.py b/src/pytorch_lightning/strategies/parallel.py index 52ae0e99b0ed2..251b8fdb8a6bc 100644 --- a/src/pytorch_lightning/strategies/parallel.py +++ b/src/pytorch_lightning/strategies/parallel.py @@ -88,7 +88,7 @@ def parallel_devices(self, parallel_devices: Optional[List[torch.device]]) -> No self._parallel_devices = parallel_devices @property - def distributed_sampler_kwargs(self) -> Dict[str, int]: + def distributed_sampler_kwargs(self) -> Dict[str, Any]: distributed_sampler_kwargs = dict( num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0, rank=self.global_rank ) From 19bd0d5b4221d94c4d0cb7cb8ce4caced482f42b Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Tue, 12 Jul 2022 15:22:51 +0530 Subject: [PATCH 4/4] Update src/pytorch_lightning/strategies/parallel.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- src/pytorch_lightning/strategies/parallel.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/pytorch_lightning/strategies/parallel.py b/src/pytorch_lightning/strategies/parallel.py index 251b8fdb8a6bc..2517848274e3d 100644 --- a/src/pytorch_lightning/strategies/parallel.py +++ b/src/pytorch_lightning/strategies/parallel.py @@ -125,9 +125,7 @@ def block_backward_sync(self) -> Generator: This is useful for skipping sync when accumulating gradients, reducing communication overhead Returns: context manager with sync behaviour off """ - if isinstance(self.model, torch.nn.parallel.DistributedDataParallel) and isinstance( - self.model, pl.utilities.types.DistributedDataParallel - ): + if isinstance(self.model, pl.utilities.types.DistributedDataParallel): with self.model.no_sync(): yield None else:
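
Note: besides the Protocol, the series relies on `assert`-based narrowing so that mypy accepts attributes declared as `Optional` (see `root_device` in `fully_sharded_native.py` and `teardown` in `parallel.py`). A minimal sketch of that pattern, using a toy `_ToyParallelStrategy` class that is not part of the patch:

    from typing import List, Optional

    import torch


    class _ToyParallelStrategy:
        def __init__(self, parallel_devices: Optional[List[torch.device]] = None) -> None:
            self.parallel_devices = parallel_devices
            self.local_rank = 0

        @property
        def root_device(self) -> torch.device:
            # Without this assert, mypy rejects indexing an Optional[List[torch.device]];
            # the assert narrows the type to List[torch.device] for the line below.
            assert self.parallel_devices is not None
            return self.parallel_devices[self.local_rank]


    print(_ToyParallelStrategy([torch.device("cpu")]).root_device)  # device(type='cpu')

At runtime the assert fails fast with an `AssertionError` on a misconfigured strategy instead of a less obvious `TypeError` from indexing `None`, which matches how the patch treats `parallel_devices` and `cluster_environment` as required by the time these methods run.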