1 change: 0 additions & 1 deletion pyproject.toml
@@ -69,7 +69,6 @@ module = [
     "pytorch_lightning.strategies.fully_sharded",
     "pytorch_lightning.strategies.horovod",
     "pytorch_lightning.strategies.ipu",
-    "pytorch_lightning.strategies.parallel",
     "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.strategies.sharded_spawn",
     "pytorch_lightning.strategies.tpu_spawn",
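Removing `pytorch_lightning.strategies.parallel` from this per-module list (presumably the mypy ignore/override section of `pyproject.toml`) means the module is now type-checked in full, which is what motivates the annotations and asserts in the diffs below. As a rough illustration of the kind of error full checking would surface without those changes (hypothetical snippet, not repo code; the mypy message is paraphrased):

```python
from typing import List, Optional

import torch


def first_device(parallel_devices: Optional[List[torch.device]]) -> torch.device:
    # Once the module is no longer exempt, mypy rejects this line
    # (roughly: "Optional[List[device]]" is not indexable) until the
    # Optional is narrowed, e.g. with the asserts added in this PR.
    return parallel_devices[0]
```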
1 change: 1 addition & 0 deletions src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -116,6 +116,7 @@ def __init__( # type: ignore[no-untyped-def]
 
     @property
     def root_device(self) -> torch.device:
+        assert self.parallel_devices is not None
         return self.parallel_devices[self.local_rank]
 
     @property
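The added assert doubles as a type guard: `parallel_devices` is `Optional[List[torch.device]]`, and mypy only accepts the indexing on the following line once the Optional has been narrowed. A minimal sketch of the same pattern outside Lightning (the `_DeviceHolder` class is illustrative, not part of the PR):

```python
from typing import List, Optional

import torch


class _DeviceHolder:
    """Illustrative stand-in for a strategy that may not have devices assigned yet."""

    def __init__(self, parallel_devices: Optional[List[torch.device]] = None) -> None:
        self.parallel_devices = parallel_devices
        self.local_rank = 0

    @property
    def root_device(self) -> torch.device:
        # Without this assert, mypy reports that Optional[List[torch.device]]
        # is not indexable; with it, the type narrows to List[torch.device],
        # and we also fail loudly if devices were never set.
        assert self.parallel_devices is not None
        return self.parallel_devices[self.local_rank]


print(_DeviceHolder([torch.device("cpu")]).root_device)  # cpu
```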
20 changes: 11 additions & 9 deletions src/pytorch_lightning/strategies/parallel.py
@@ -13,11 +13,10 @@
 # limitations under the License.
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, List, Optional
+from typing import Any, Dict, Generator, List, Optional
 
 import torch
 from torch import Tensor
-from torch.nn.parallel import DistributedDataParallel
 
 import pytorch_lightning as pl
 from pytorch_lightning.overrides.base import unwrap_lightning_module
@@ -81,16 +80,18 @@ def is_global_zero(self) -> bool:
         return self.global_rank == 0
 
     @property
-    def parallel_devices(self):
+    def parallel_devices(self) -> Optional[List[torch.device]]:
         return self._parallel_devices
 
     @parallel_devices.setter
-    def parallel_devices(self, parallel_devices):
+    def parallel_devices(self, parallel_devices: Optional[List[torch.device]]) -> None:
         self._parallel_devices = parallel_devices
 
     @property
-    def distributed_sampler_kwargs(self):
-        distributed_sampler_kwargs = dict(num_replicas=len(self.parallel_devices), rank=self.global_rank)
+    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
+        distributed_sampler_kwargs = dict(
+            num_replicas=len(self.parallel_devices) if self.parallel_devices is not None else 0, rank=self.global_rank
+        )
         return distributed_sampler_kwargs
 
     @property
@@ -104,7 +105,7 @@ def torch_distributed_backend(self) -> str:
             return pg_backend
         return get_default_process_group_backend_for_device(self.root_device)
 
-    def reconciliate_processes(self, trace: str):
+    def reconciliate_processes(self, trace: str) -> None:
         """Function to re-conciliate processes on failure."""
 
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
@@ -118,18 +119,19 @@ def reduce_boolean_decision(self, decision: bool) -> bool:
         return decision
 
     @contextmanager
-    def block_backward_sync(self):
+    def block_backward_sync(self) -> Generator:
         """Blocks ddp sync gradients behaviour on backwards pass.
 
         This is useful for skipping sync when accumulating gradients, reducing communication overhead
         Returns: context manager with sync behaviour off
         """
-        if isinstance(self.model, DistributedDataParallel):
+        if isinstance(self.model, pl.utilities.types.DistributedDataParallel):
             with self.model.no_sync():
                 yield None
         else:
             yield None
 
     def teardown(self) -> None:
+        assert self.cluster_environment is not None
         self.cluster_environment.teardown()
         super().teardown()
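In `block_backward_sync`, only the `isinstance` target changes: instead of importing `torch.nn.parallel.DistributedDataParallel` in this module, the model is checked against the structural `DistributedDataParallel` protocol added to `utilities/types.py` below. The behaviour is unchanged: enter DDP's `no_sync()` so gradient all-reduce is skipped while gradients accumulate. A hedged sketch of that calling pattern (the loop and the `strategy`/`losses` names are illustrative, not Lightning's trainer code):

```python
from contextlib import nullcontext
from typing import List

from torch import Tensor


def accumulate_and_step(strategy, losses: List[Tensor]) -> None:
    """Illustrative gradient accumulation: sync gradients only on the last backward."""
    for i, loss in enumerate(losses):
        last = i == len(losses) - 1
        # block_backward_sync() yields inside DDP's no_sync() when the wrapped model
        # matches the DistributedDataParallel protocol, and is a plain no-op otherwise.
        sync_ctx = nullcontext() if last else strategy.block_backward_sync()
        with sync_ctx:
            loss.backward()
```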
29 changes: 28 additions & 1 deletion src/pytorch_lightning/utilities/types.py
@@ -16,12 +16,14 @@
 - Do not include any `_TYPE` suffix
 - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no leading `_`)
 """
+from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union
 
 import torch
 from torch import Tensor
+from torch._C._distributed_c10d import ProcessGroup
 from torch.optim import Optimizer
 from torch.utils.data import DataLoader
 from torchmetrics import Metric
@@ -99,6 +101,31 @@ def step(self, metrics: Union[float, int, Tensor], epoch: Optional[int] = None)
         ...
 
 
+# Inferred from `torch.nn.parallel.distributed.pyi`
+# Missing attributes were added to improve typing
+@runtime_checkable
+class DistributedDataParallel(Protocol):
+    def __init__(
+        self,
+        module: torch.nn.Module,
+        device_ids: Optional[List[Union[int, torch.device]]] = None,
+        output_device: Optional[Union[int, torch.device]] = None,
+        dim: int = 0,
+        broadcast_buffers: bool = True,
+        process_group: Optional[ProcessGroup] = None,
+        bucket_cap_mb: int = 25,
+        find_unused_parameters: bool = False,
+        check_reduction: bool = False,
+        gradient_as_bucket_view: bool = False,
+        static_graph: bool = False,
+    ) -> None:
+        ...
+
+    @contextmanager
+    def no_sync(self) -> Generator:
+        ...
+
+
 # todo: improve LRSchedulerType naming/typing
 LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
 LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
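The new `DistributedDataParallel` here is a `typing.Protocol`, not a subclass of torch's wrapper: `@runtime_checkable` makes `isinstance` compare structurally, by attribute presence, which is how `ParallelStrategy.block_backward_sync` can match the real `torch.nn.parallel.DistributedDataParallel` without importing it. A minimal, self-contained sketch of the mechanism with toy classes (none of these names come from the PR):

```python
from contextlib import contextmanager
from typing import Generator, Protocol, runtime_checkable


@runtime_checkable
class SupportsNoSync(Protocol):
    """Anything exposing a `no_sync` method matches this protocol structurally."""

    @contextmanager
    def no_sync(self) -> Generator:
        ...


class FakeDDPWrapper:
    """Toy model wrapper providing `no_sync`, like torch's DistributedDataParallel."""

    @contextmanager
    def no_sync(self) -> Generator:
        yield  # pretend gradient synchronization is disabled in here


class PlainModel:
    pass


print(isinstance(FakeDDPWrapper(), SupportsNoSync))  # True: method name is present
print(isinstance(PlainModel(), SupportsNoSync))      # False
# Note: runtime_checkable protocols only check attribute presence, not signatures.
```

Keeping the check structural avoids pulling the partially typed torch class into `strategies/parallel.py` while leaving the runtime behaviour of the original `isinstance` check intact.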