1 change: 1 addition & 0 deletions pyproject.toml
@@ -67,6 +67,7 @@ module = [
"pytorch_lightning.utilities.cloud_io",
"pytorch_lightning.utilities.device_dtype_mixin",
"pytorch_lightning.utilities.device_parser",
"pytorch_lightning.utilities.distributed",
"pytorch_lightning.utilities.parsing",
]
ignore_errors = "False"
46 changes: 27 additions & 19 deletions pytorch_lightning/utilities/distributed.py
@@ -16,7 +16,7 @@
import os
from functools import wraps
from platform import python_version
from typing import Any, Optional, Union
from typing import Any, Callable, List, Optional, Tuple, Type, Union

import torch
from torch.nn.parallel.distributed import DistributedDataParallel
@@ -31,21 +31,22 @@

else:

class ReduceOp:
class ReduceOp: # type: ignore # (see https://github.com/python/mypy/issues/1153)
SUM = None

class group:
class group: # type: ignore
WORLD = None


log = logging.getLogger(__name__)


def rank_zero_only(fn):
def rank_zero_only(fn: Callable) -> Callable:
@wraps(fn)
def wrapped_fn(*args, **kwargs):
def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
if rank_zero_only.rank == 0:
return fn(*args, **kwargs)
return None

return wrapped_fn
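
A hedged usage sketch, not part of this diff: the wrapped callable runs only on global rank 0 and every other rank gets None back, which is what the new Optional[Any] return annotation captures. The helper name and file path below are hypothetical.

from pytorch_lightning.utilities.distributed import rank_zero_only

@rank_zero_only
def write_run_note(path: str) -> str:
    # Hypothetical helper: executes on rank 0 only; other ranks receive None.
    with open(path, "w") as f:
        f.write("training started")
    return path

note = write_run_note("run_note.txt")  # str on rank 0, None elsewhere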

@@ -64,7 +65,7 @@ def _get_rank() -> int:
rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank())


def rank_zero_warn(*args, stacklevel: int = 5, **kwargs):
def rank_zero_warn(*args: Any, stacklevel: int = 5, **kwargs: Any) -> None:
from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn

rank_zero_deprecation(
@@ -74,7 +75,7 @@ def rank_zero_warn(*args, stacklevel: int = 5, **kwargs):
return rank_zero_warn(*args, stacklevel=stacklevel, **kwargs)


def rank_zero_deprecation(*args, stacklevel: int = 5, **kwargs):
def rank_zero_deprecation(*args: Any, stacklevel: int = 5, **kwargs: Any) -> None:
from pytorch_lightning.utilities.warnings import rank_zero_deprecation

rank_zero_deprecation(
@@ -84,29 +85,29 @@ def rank_zero_deprecation(*args, stacklevel: int = 5, **kwargs):
return rank_zero_deprecation(*args, stacklevel=stacklevel, **kwargs)


def _info(*args, stacklevel: int = 2, **kwargs):
def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
if python_version() >= "3.8.0":
kwargs["stacklevel"] = stacklevel
log.info(*args, **kwargs)


def _debug(*args, stacklevel: int = 2, **kwargs):
def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
if python_version() >= "3.8.0":
kwargs["stacklevel"] = stacklevel
log.debug(*args, **kwargs)


@rank_zero_only
def rank_zero_debug(*args, stacklevel: int = 4, **kwargs):
def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
_debug(*args, stacklevel=stacklevel, **kwargs)


@rank_zero_only
def rank_zero_info(*args, stacklevel: int = 4, **kwargs):
def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
_info(*args, stacklevel=stacklevel, **kwargs)


def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None):
def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
"""
Function to gather all tensors from several ddp processes onto a list that
is broadcasted to all processes
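
A hedged usage sketch, not part of this diff, assuming torch.distributed has already been initialized (e.g. by the Trainer): each rank passes in its local tensor and receives the full list back, which is what the new List[torch.Tensor] return annotation describes.

import torch
import torch.distributed as dist
from pytorch_lightning.utilities.distributed import gather_all_tensors

# Assumes an initialized default process group; one list entry per rank.
local = torch.tensor([float(dist.get_rank())])
gathered = gather_all_tensors(local)
world_mean = torch.stack(gathered).mean()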
@@ -141,7 +142,7 @@ def distributed_available() -> bool:


def sync_ddp_if_available(
result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
"""
Function to reduce a tensor across worker processes during distributed training
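
A hedged usage sketch, not part of this diff: the function is safe to call in both single-process and DDP runs, returning the tensor unchanged when no process group is available and all-reducing it otherwise.

import torch
from pytorch_lightning.utilities.distributed import sync_ddp_if_available

loss = torch.tensor(0.25)
# No-op without an initialized process group; mean-reduced across ranks with one.
reduced = sync_ddp_if_available(loss, reduce_op="mean")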
@@ -160,7 +161,7 @@ def sync_ddp_if_available(


def sync_ddp(
result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
) -> torch.Tensor:
"""
Function to reduce the tensors from several ddp processes to one master process
@@ -196,7 +197,11 @@ def sync_ddp(

class AllGatherGrad(torch.autograd.Function):
@staticmethod
def forward(ctx, tensor, group=group.WORLD):
def forward(
ctx: Any,
tensor: torch.Tensor,
group: Optional["torch.distributed.ProcessGroup"] = group.WORLD,
) -> torch.Tensor:
ctx.group = group

gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
@@ -207,7 +212,7 @@ def forward(ctx, tensor, group=group.WORLD):
return gathered_tensor

@staticmethod
def backward(ctx, *grad_output):
def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
grad_output = torch.cat(grad_output)

torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
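
A hedged usage sketch, not part of this diff, assuming an initialized DDP run: forward stacks the tensors gathered from every rank, and backward all-reduces the incoming gradients so autograd keeps flowing across processes.

import torch
from pytorch_lightning.utilities.distributed import AllGatherGrad

local = torch.randn(4, requires_grad=True)
gathered = AllGatherGrad.apply(local)  # shape (world_size, 4)
gathered.sum().backward()              # local.grad is populated on every rank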
@@ -216,7 +221,7 @@ def backward(ctx, *grad_output):


def all_gather_ddp_if_available(
tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False
tensor: torch.Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False
) -> torch.Tensor:
"""
Function to gather a tensor from several distributed processes
@@ -241,8 +246,8 @@ def all_gather_ddp_if_available(
def register_ddp_comm_hook(
model: DistributedDataParallel,
ddp_comm_state: Optional[object] = None,
ddp_comm_hook: Optional[callable] = None,
ddp_comm_wrapper: Optional[callable] = None,
ddp_comm_hook: Optional[Callable] = None,
ddp_comm_wrapper: Optional[Callable] = None,
) -> None:
"""
Function to register communication hook for DDP model
@@ -322,6 +327,9 @@ def register_ddp_comm_hook(
return
if ddp_comm_hook is None:
return
# inform mypy that ddp_comm_hook is callable
ddp_comm_hook: Callable = ddp_comm_hook

if ddp_comm_wrapper is not None:
if not _TORCH_GREATER_EQUAL_1_9:
rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.")
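
A hedged usage sketch, not part of this diff, assuming torch >= 1.8, an initialized process group, and a DDP-wrapped model: register one of PyTorch's built-in comm hooks; the Optional[Callable] annotations above describe exactly this kind of argument.

import torch
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning.utilities.distributed import register_ddp_comm_hook

# Assumes the default process group is already initialized.
ddp_model = DistributedDataParallel(torch.nn.Linear(8, 8))
register_ddp_comm_hook(model=ddp_model, ddp_comm_hook=default.fp16_compress_hook)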