diff --git a/pyproject.toml b/pyproject.toml
index bb3f093e1be28..76ce8a23a1017 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -67,6 +67,7 @@ module = [
     "pytorch_lightning.utilities.cloud_io",
     "pytorch_lightning.utilities.device_dtype_mixin",
     "pytorch_lightning.utilities.device_parser",
+    "pytorch_lightning.utilities.distributed",
     "pytorch_lightning.utilities.parsing",
 ]
 ignore_errors = "False"
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index ff11df650296f..71292cf8a75b2 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -16,7 +16,7 @@
 import os
 from functools import wraps
 from platform import python_version
-from typing import Any, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Type, Union

 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel
@@ -31,21 +31,22 @@ else:

-    class ReduceOp:
+    class ReduceOp:  # type: ignore # (see https://github.com/python/mypy/issues/1153)
         SUM = None

-    class group:
+    class group:  # type: ignore
         WORLD = None


 log = logging.getLogger(__name__)


-def rank_zero_only(fn):
+def rank_zero_only(fn: Callable) -> Callable:
     @wraps(fn)
-    def wrapped_fn(*args, **kwargs):
+    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
         if rank_zero_only.rank == 0:
             return fn(*args, **kwargs)
+        return None

     return wrapped_fn
@@ -64,7 +65,7 @@ def _get_rank() -> int:
 rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank())


-def rank_zero_warn(*args, stacklevel: int = 5, **kwargs):
+def rank_zero_warn(*args: Any, stacklevel: int = 5, **kwargs: Any) -> None:
     from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn

     rank_zero_deprecation(
@@ -74,7 +75,7 @@ def rank_zero_warn(*args, stacklevel: int = 5, **kwargs):
     return rank_zero_warn(*args, stacklevel=stacklevel, **kwargs)


-def rank_zero_deprecation(*args, stacklevel: int = 5, **kwargs):
+def rank_zero_deprecation(*args: Any, stacklevel: int = 5, **kwargs: Any) -> None:
     from pytorch_lightning.utilities.warnings import rank_zero_deprecation

     rank_zero_deprecation(
@@ -84,29 +85,29 @@ def rank_zero_deprecation(*args, stacklevel: int = 5, **kwargs):
     return rank_zero_deprecation(*args, stacklevel=stacklevel, **kwargs)


-def _info(*args, stacklevel: int = 2, **kwargs):
+def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
     if python_version() >= "3.8.0":
         kwargs["stacklevel"] = stacklevel
     log.info(*args, **kwargs)


-def _debug(*args, stacklevel: int = 2, **kwargs):
+def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
     if python_version() >= "3.8.0":
         kwargs["stacklevel"] = stacklevel
     log.debug(*args, **kwargs)


 @rank_zero_only
-def rank_zero_debug(*args, stacklevel: int = 4, **kwargs):
+def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
     _debug(*args, stacklevel=stacklevel, **kwargs)


 @rank_zero_only
-def rank_zero_info(*args, stacklevel: int = 4, **kwargs):
+def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
     _info(*args, stacklevel=stacklevel, **kwargs)


-def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None):
+def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
     """
     Function to gather all tensors from several ddp processes onto a list that
     is broadcasted to all processes
@@ -141,7 +142,7 @@ def distributed_available() -> bool:


 def sync_ddp_if_available(
-    result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
+    result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
 ) -> torch.Tensor:
     """
     Function to reduce a tensor across worker processes during distributed training
@@ -160,7 +161,7 @@ def sync_ddp_if_available(


 def sync_ddp(
-    result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
+    result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
 ) -> torch.Tensor:
     """
     Function to reduce the tensors from several ddp processes to one master process
@@ -196,7 +197,11 @@ def sync_ddp(

 class AllGatherGrad(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, tensor, group=group.WORLD):
+    def forward(
+        ctx: Any,
+        tensor: torch.Tensor,
+        group: Optional["torch.distributed.ProcessGroup"] = group.WORLD,
+    ) -> torch.Tensor:
         ctx.group = group

         gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
@@ -207,7 +212,7 @@ def forward(ctx, tensor, group=group.WORLD):
         return gathered_tensor

     @staticmethod
-    def backward(ctx, *grad_output):
+    def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
         grad_output = torch.cat(grad_output)

         torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
@@ -216,7 +221,7 @@ def backward(ctx, *grad_output):


 def all_gather_ddp_if_available(
-    tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False
+    tensor: torch.Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False
 ) -> torch.Tensor:
     """
     Function to gather a tensor from several distributed processes
@@ -241,8 +246,8 @@ def all_gather_ddp_if_available(
 def register_ddp_comm_hook(
     model: DistributedDataParallel,
     ddp_comm_state: Optional[object] = None,
-    ddp_comm_hook: Optional[callable] = None,
-    ddp_comm_wrapper: Optional[callable] = None,
+    ddp_comm_hook: Optional[Callable] = None,
+    ddp_comm_wrapper: Optional[Callable] = None,
 ) -> None:
     """
     Function to register communication hook for DDP model
@@ -322,6 +327,9 @@ def register_ddp_comm_hook(
         return
     if ddp_comm_hook is None:
         return
+    # inform mypy that ddp_comm_hook is callable
+    ddp_comm_hook: Callable = ddp_comm_hook
+
     if ddp_comm_wrapper is not None:
         if not _TORCH_GREATER_EQUAL_1_9:
             rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.")
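For reference, a minimal self-contained sketch of the two typing patterns this patch applies: a typed rank_zero_only decorator whose wrapper returns None explicitly so the Optional[Any] return annotation holds, and re-annotating an Optional[Callable] after a None guard so mypy treats it as callable (the same trick as the ddp_comm_hook: Callable = ddp_comm_hook line above). The register_hook helper and the rank value below are illustrative, not part of the patch.

from functools import wraps
from typing import Any, Callable, Optional


def rank_zero_only(fn: Callable) -> Callable:
    """Run ``fn`` only on rank 0; every other rank receives ``None``."""

    @wraps(fn)
    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
        if rank_zero_only.rank == 0:
            return fn(*args, **kwargs)
        # Explicit return keeps the Optional[Any] annotation consistent for mypy.
        return None

    return wrapped_fn


# In the library the attribute comes from _get_rank(); 0 stands in for a single process here.
rank_zero_only.rank = 0


def register_hook(hook: Optional[Callable] = None) -> None:
    # Hypothetical helper illustrating the narrowing pattern used for ddp_comm_hook.
    if hook is None:
        return
    # Re-annotating after the early return tells mypy the value is now a plain Callable.
    hook: Callable = hook
    hook()


@rank_zero_only
def log_once(msg: str) -> None:
    print(msg)


log_once("only rank 0 prints this")
register_hook(lambda: print("hook ran"))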