From ae3d6c38faf8b4deafe967f9a91fa63c918320cd Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 28 Jun 2021 19:32:39 +0200 Subject: [PATCH 01/12] Add typing for utilities.distributed.py --- pytorch_lightning/utilities/distributed.py | 48 +++++++++++----------- setup.cfg | 2 + 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 6ca2de7eb2ca2..cdd33d3310662 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -16,7 +16,7 @@ import os from functools import wraps from platform import python_version -from typing import Any, Optional, Union +from typing import Any, Callable, List, Optional, Tuple, Union import torch from torch.nn.parallel.distributed import DistributedDataParallel @@ -31,22 +31,23 @@ else: - class ReduceOp: + class ReduceOp: # type: ignore # (see https://github.com/python/mypy/issues/1153) SUM = None - class group: + class group: # type: ignore WORLD = None log = logging.getLogger(__name__) -def rank_zero_only(fn): +def rank_zero_only(fn: Callable) -> Callable: @wraps(fn) - def wrapped_fn(*args, **kwargs): + def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]: if rank_zero_only.rank == 0: return fn(*args, **kwargs) + return None return wrapped_fn @@ -65,7 +66,7 @@ def _get_rank() -> int: rank_zero_only.rank = getattr(rank_zero_only, 'rank', _get_rank()) -def rank_zero_warn(*args, stacklevel: int = 5, **kwargs): +def rank_zero_warn(*args: Any, stacklevel: int = 5, **kwargs: Any) -> None: from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn rank_zero_deprecation( '`pytorch_lightning.utilities.distributed.rank_zero_warn` has been moved to' @@ -74,7 +75,7 @@ def rank_zero_warn(*args, stacklevel: int = 5, **kwargs): return rank_zero_warn(*args, stacklevel=stacklevel, **kwargs) -def rank_zero_deprecation(*args, stacklevel: int = 5, **kwargs): +def rank_zero_deprecation(*args: Any, stacklevel: int = 5, **kwargs: Any) -> None: from pytorch_lightning.utilities.warnings import rank_zero_deprecation rank_zero_deprecation( '`pytorch_lightning.utilities.distributed.rank_zero_deprecation` has been moved to' @@ -83,29 +84,29 @@ def rank_zero_deprecation(*args, stacklevel: int = 5, **kwargs): return rank_zero_deprecation(*args, stacklevel=stacklevel, **kwargs) -def _info(*args, stacklevel: int = 2, **kwargs): +def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None: if python_version() >= "3.8.0": kwargs['stacklevel'] = stacklevel log.info(*args, **kwargs) -def _debug(*args, stacklevel: int = 2, **kwargs): +def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None: if python_version() >= "3.8.0": kwargs['stacklevel'] = stacklevel log.debug(*args, **kwargs) @rank_zero_only -def rank_zero_debug(*args, stacklevel: int = 4, **kwargs): +def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None: _debug(*args, stacklevel=stacklevel, **kwargs) @rank_zero_only -def rank_zero_info(*args, stacklevel: int = 4, **kwargs): +def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None: _info(*args, stacklevel=stacklevel, **kwargs) -def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None): +def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]: """ Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes @@ -140,9 +141,7 @@ def 
distributed_available() -> bool: def sync_ddp_if_available( - result: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None + result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None, ) -> torch.Tensor: """ Function to reduce a tensor across worker processes during distributed training @@ -161,9 +160,7 @@ def sync_ddp_if_available( def sync_ddp( - result: Union[torch.Tensor], - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None + result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None, ) -> torch.Tensor: """ Function to reduce the tensors from several ddp processes to one master process @@ -200,7 +197,9 @@ def sync_ddp( class AllGatherGrad(torch.autograd.Function): @staticmethod - def forward(ctx, tensor, group=group.WORLD): + def forward( + ctx: Any, tensor: torch.Tensor, group: Optional[torch.distributed.ProcessGroup] = group.WORLD + ) -> torch.Tensor: ctx.group = group gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] @@ -211,7 +210,7 @@ def forward(ctx, tensor, group=group.WORLD): return gathered_tensor @staticmethod - def backward(ctx, *grad_output): + def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: grad_output = torch.cat(grad_output) torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group) @@ -220,7 +219,7 @@ def backward(ctx, *grad_output): def all_gather_ddp_if_available( - tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False + tensor: torch.Tensor, group: Optional[torch.distributed.ProcessGroup] = None, sync_grads: bool = False ) -> torch.Tensor: """ Function to gather a tensor from several distributed processes @@ -245,8 +244,8 @@ def all_gather_ddp_if_available( def register_ddp_comm_hook( model: DistributedDataParallel, ddp_comm_state: Optional[object] = None, - ddp_comm_hook: Optional[callable] = None, - ddp_comm_wrapper: Optional[callable] = None, + ddp_comm_hook: Optional[Callable] = None, + ddp_comm_wrapper: Optional[Callable] = None, ) -> None: """ Function to register communication hook for DDP model @@ -325,6 +324,9 @@ def register_ddp_comm_hook( return if ddp_comm_hook is None: return + # inform mypy that ddp_comm_hook is callable + ddp_comm_hook: Callable = ddp_comm_hook + if ddp_comm_wrapper is not None: if not _TORCH_GREATER_EQUAL_1_9: rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.") diff --git a/setup.cfg b/setup.cfg index 74e02d932dc3c..d5332ddacf03d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -183,6 +183,8 @@ ignore_errors = True ignore_errors = True [mypy-pytorch_lightning.utilities.cli] ignore_errors = False +[mypy-pytorch_lightning.utilities.distributed] +ignore_errors = False # todo: add proper typing to this module... 
[mypy-pl_examples.*] From cdfa374e7523b36d9f9c791c876b23b025e9992e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 29 Jun 2021 16:25:33 +0000 Subject: [PATCH 02/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/distributed.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index cdd33d3310662..b73348ce1519c 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -141,7 +141,9 @@ def distributed_available() -> bool: def sync_ddp_if_available( - result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None, + result: torch.Tensor, + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None, ) -> torch.Tensor: """ Function to reduce a tensor across worker processes during distributed training @@ -160,7 +162,9 @@ def sync_ddp_if_available( def sync_ddp( - result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None, + result: torch.Tensor, + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = None, ) -> torch.Tensor: """ Function to reduce the tensors from several ddp processes to one master process @@ -219,7 +223,9 @@ def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: def all_gather_ddp_if_available( - tensor: torch.Tensor, group: Optional[torch.distributed.ProcessGroup] = None, sync_grads: bool = False + tensor: torch.Tensor, + group: Optional[torch.distributed.ProcessGroup] = None, + sync_grads: bool = False ) -> torch.Tensor: """ Function to gather a tensor from several distributed processes From cd4b93f761a995f80a8b61d18425328254bb3ff5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 29 Jul 2021 12:38:56 +0000 Subject: [PATCH 03/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/distributed.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 78f06becbc40a..7872a37306def 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -42,7 +42,6 @@ class group: # type: ignore def rank_zero_only(fn: Callable) -> Callable: - @wraps(fn) def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]: if rank_zero_only.rank == 0: @@ -143,9 +142,7 @@ def distributed_available() -> bool: def sync_ddp_if_available( - result: torch.Tensor, - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None, + result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None ) -> torch.Tensor: """ Function to reduce a tensor across worker processes during distributed training @@ -164,9 +161,7 @@ def sync_ddp_if_available( def sync_ddp( - result: torch.Tensor, - group: Optional[Any] = None, - reduce_op: Optional[Union[ReduceOp, str]] = None, + result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None ) -> torch.Tensor: """ Function to reduce the tensors from several ddp processes to one master process @@ -224,9 +219,7 @@ def 
backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: def all_gather_ddp_if_available( - tensor: torch.Tensor, - group: Optional[torch.distributed.ProcessGroup] = None, - sync_grads: bool = False + tensor: torch.Tensor, group: Optional[torch.distributed.ProcessGroup] = None, sync_grads: bool = False ) -> torch.Tensor: """ Function to gather a tensor from several distributed processes From 0e5c57cd6075a43fb5b52b816b25f0c9da7f8269 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Fri, 30 Jul 2021 16:29:50 +0200 Subject: [PATCH 04/12] Fix setup.cfg file --- setup.cfg | 110 ------------------------------------------------------ 1 file changed, 110 deletions(-) diff --git a/setup.cfg b/setup.cfg index 9850d71fcbfed..86890f08e2c68 100644 --- a/setup.cfg +++ b/setup.cfg @@ -87,113 +87,3 @@ convention = pep257 # D202: Ignore a blank line after docstring (collision with Python Black in decorators) add-ignore = D104,D107,D202 max-line-length = 120 - - -[yapf] -based_on_style = pep8 -spaces_before_comment = 2 -split_before_logical_operator = true -split_before_arithmetic_operator = true -COLUMN_LIMIT = 120 -COALESCE_BRACKETS = true -DEDENT_CLOSING_BRACKETS = true -ALLOW_SPLIT_BEFORE_DICT_VALUE = false -BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true -NO_SPACES_AROUND_SELECTED_BINARY_OPERATORS = false - - -[mypy] -files = pytorch_lightning, pl_examples, benchmarks, tests -disallow_untyped_defs = True -ignore_missing_imports = True -show_error_codes = True -warn_redundant_casts = True -warn_unused_configs = True -warn_unused_ignores = True -allow_redefinition = True -# disable this rule as the Trainer attributes are defined in the connectors, not in its __init__ -disable_error_code = attr-defined - -# todo: add proper typing to this module... -[mypy-pytorch_lightning.callbacks.*] -ignore_errors = True -# whitelist -[mypy-pytorch_lightning.callbacks.pruning] -ignore_errors = False - -# todo: add proper typing to this module... -[mypy-pytorch_lightning.core.*] -ignore_errors = True - -# todo: add proper typing to this module... -[mypy-pytorch_lightning.loggers.*] -ignore_errors = True - -# todo: add proper typing to this module... -[mypy-pytorch_lightning.loops.*] -ignore_errors = True - -# todo: add proper typing to this module... -[mypy-pytorch_lightning.metrics.*] -ignore_errors = True - -# todo: add proper typing to this module... -[mypy-pytorch_lightning.overrides.*] -ignore_errors = True - -# todo: add proper typing to this module... -[mypy-pytorch_lightning.plugins.environments.*] -ignore_errors = True - -# todo: add proper typing to this module... -[mypy-pytorch_lightning.plugins.training_type.*] -ignore_errors = True - -# todo: add proper typing to this module... -[mypy-pytorch_lightning.profiler.*] -ignore_errors = True - -# todo: add proper typing to this module... -[mypy-pytorch_lightning.pt_overrides.*] -ignore_errors = True - -# todo: add proper typing to this module... -[mypy-pytorch_lightning.root_module.*] -ignore_errors = True - -# todo: add proper typing to this module... -[mypy-pytorch_lightning.trainer.*] -ignore_errors = True -# whitelist -[mypy-pytorch_lightning.trainer.evaluation_loop] -ignore_errors = False -[mypy-pytorch_lightning.trainer.connectors.logger_connector] -ignore_errors = False - -# todo: add proper typing to this module... -[mypy-pytorch_lightning.distributed.*] -ignore_errors = True - -# todo: add proper typing to this module... -[mypy-pytorch_lightning.tuner.*] -ignore_errors = True - -# todo: add proper typing to this module... 
-[mypy-pytorch_lightning.utilities.*] -ignore_errors = True -[mypy-pytorch_lightning.utilities.cli] -ignore_errors = False -[mypy-pytorch_lightning.utilities.distributed] -ignore_errors = False - -# todo: add proper typing to this module... -[mypy-pl_examples.*] -ignore_errors = True - -# todo: add proper typing to this module... -[mypy-benchmarks.*] -ignore_errors = True - -# todo: add proper typing to this module... -[mypy-tests.*] -ignore_errors = True From 5f7b3f9b85ef13fb48962d10307dff5d42caf681 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Fri, 30 Jul 2021 16:34:18 +0200 Subject: [PATCH 05/12] Enable mypy testing for utilities/distributed.py --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 874781367ddd0..cf076c228972e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ module = [ "pytorch_lightning.trainer.evaluation_loop", "pytorch_lightning.trainer.connectors.logger_connector", "pytorch_lightning.utilities.cli", + "pytorch_lightning.utilities.distributed", "pytorch_lightning.utilities.device_dtype_mixin", "pytorch_lightning.utilities.device_parser", "pytorch_lightning.utilities.parsing", From 947b18b4e6cb3045021c409478aa86f4fffb19db Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Fri, 30 Jul 2021 16:56:39 +0200 Subject: [PATCH 06/12] Edit forward method of AllGatherGrad class to pass mypy checks --- pytorch_lightning/utilities/distributed.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 7872a37306def..7daaa7afb5a28 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -198,8 +198,14 @@ def sync_ddp( class AllGatherGrad(torch.autograd.Function): @staticmethod def forward( - ctx: Any, tensor: torch.Tensor, group: Optional[torch.distributed.ProcessGroup] = group.WORLD + ctx: Any, + *arg: Any, + tensor: torch.Tensor = None, + group: Optional[torch.distributed.ProcessGroup] = group.WORLD, + **kwargs: Any, ) -> torch.Tensor: + if tensor is None: + raise ValueError("`tensor` should be provided.") ctx.group = group gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] From bddff85174e2c9b45701227b43e26541d15c2e51 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 2 Aug 2021 07:21:02 +0200 Subject: [PATCH 07/12] Fix typing for register_ddp_comm_hook method --- pytorch_lightning/utilities/distributed.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 7daaa7afb5a28..920020936270a 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -16,7 +16,7 @@ import os from functools import wraps from platform import python_version -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Type, Union import torch from torch.nn.parallel.distributed import DistributedDataParallel @@ -248,7 +248,7 @@ def all_gather_ddp_if_available( def register_ddp_comm_hook( - model: DistributedDataParallel, + model: Type[DistributedDataParallel], ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[Callable] = None, ddp_comm_wrapper: Optional[Callable] = None, From d77c62974725d6e06c9904aa92a68c8165641675 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Wed, 4 Aug 2021 
18:13:10 +0200 Subject: [PATCH 08/12] Undo one mypy-related change --- pytorch_lightning/utilities/distributed.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 920020936270a..7dfc31ee1fc6f 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -199,13 +199,9 @@ class AllGatherGrad(torch.autograd.Function): @staticmethod def forward( ctx: Any, - *arg: Any, - tensor: torch.Tensor = None, + tensor: torch.Tensor, group: Optional[torch.distributed.ProcessGroup] = group.WORLD, - **kwargs: Any, ) -> torch.Tensor: - if tensor is None: - raise ValueError("`tensor` should be provided.") ctx.group = group gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())] From 6464c66e34cddfe0b6be4c1d8a4fb982389e923c Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Thu, 5 Aug 2021 09:10:52 +0200 Subject: [PATCH 09/12] Apply carmocca's suggestion --- pytorch_lightning/utilities/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 7dfc31ee1fc6f..03eaa32b93f95 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -221,7 +221,7 @@ def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: def all_gather_ddp_if_available( - tensor: torch.Tensor, group: Optional[torch.distributed.ProcessGroup] = None, sync_grads: bool = False + tensor: torch.Tensor, group: Optional['torch.distributed.ProcessGroup'] = None, sync_grads: bool = False ) -> torch.Tensor: """ Function to gather a tensor from several distributed processes From ba48abd5947cecf0d81e428b502143ff80530b15 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 5 Aug 2021 07:12:02 +0000 Subject: [PATCH 10/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/utilities/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 03eaa32b93f95..a74103af1d0ea 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -221,7 +221,7 @@ def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]: def all_gather_ddp_if_available( - tensor: torch.Tensor, group: Optional['torch.distributed.ProcessGroup'] = None, sync_grads: bool = False + tensor: torch.Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False ) -> torch.Tensor: """ Function to gather a tensor from several distributed processes From 94a5ab73bd0623bafe0b36a5d364dce49d92f519 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Thu, 5 Aug 2021 09:25:47 +0200 Subject: [PATCH 11/12] Put another torch.distributed.ProcessGroup into "" --- pytorch_lightning/utilities/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index a74103af1d0ea..4682844f806de 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -200,7 +200,7 @@ class AllGatherGrad(torch.autograd.Function): def forward( ctx: Any, 
tensor: torch.Tensor, - group: Optional[torch.distributed.ProcessGroup] = group.WORLD, + group: Optional["torch.distributed.ProcessGroup"] = group.WORLD, ) -> torch.Tensor: ctx.group = group From 876703a2db6c8a40e543900c2530a173e70ab79e Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Thu, 5 Aug 2021 11:15:02 +0200 Subject: [PATCH 12/12] Change type hint for DDP model --- pytorch_lightning/utilities/distributed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py index 4682844f806de..71292cf8a75b2 100644 --- a/pytorch_lightning/utilities/distributed.py +++ b/pytorch_lightning/utilities/distributed.py @@ -244,7 +244,7 @@ def all_gather_ddp_if_available( def register_ddp_comm_hook( - model: Type[DistributedDataParallel], + model: DistributedDataParallel, ddp_comm_state: Optional[object] = None, ddp_comm_hook: Optional[Callable] = None, ddp_comm_wrapper: Optional[Callable] = None,
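
Note (editorial illustration, not part of the patch series): the typed `rank_zero_only` decorator added in PATCH 01/12 is the pattern most of the other annotated signatures build on. Below is a minimal standalone sketch of that pattern; `greet` and the single `LOCAL_RANK` lookup are invented for the example — the real module resolves the rank through `_get_rank()` over several environment variables.

import os
from functools import wraps
from typing import Any, Callable, Optional


def rank_zero_only(fn: Callable) -> Callable:
    """Call `fn` only on rank 0; every other rank receives `None`."""

    @wraps(fn)
    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
        if rank_zero_only.rank == 0:
            return fn(*args, **kwargs)
        # Explicit `return None` so every code path matches the declared
        # Optional[Any] return type -- this is the line the patch adds.
        return None

    return wrapped_fn


# Simplified rank resolution for the sketch; Lightning uses _get_rank().
rank_zero_only.rank = int(os.environ.get("LOCAL_RANK", 0))


@rank_zero_only
def greet() -> None:
    print("hello from rank 0")


greet()  # prints only when LOCAL_RANK is unset or "0"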
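
Note (editorial illustration, not part of the patch series): the `backward(ctx, *grad_output) -> Tuple[torch.Tensor, None]` annotation on `AllGatherGrad` follows the autograd convention that `backward` returns one entry per `forward` argument after `ctx`, with `None` for non-differentiable inputs such as the process group. The toy `ScaleBy` function below shows the same convention on CPU without an initialized process group (the real `AllGatherGrad.backward` concatenates the incoming gradients and all-reduces them across ranks); `ScaleBy` and its `factor` argument are invented for the example.

from typing import Any, Tuple

import torch


class ScaleBy(torch.autograd.Function):

    @staticmethod
    def forward(ctx: Any, tensor: torch.Tensor, factor: float = 2.0) -> torch.Tensor:
        ctx.factor = factor
        return tensor * factor

    @staticmethod
    def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
        # One entry per forward input: a gradient for `tensor`,
        # None for the non-differentiable `factor`.
        return grad_output[0] * ctx.factor, None


x = torch.ones(3, requires_grad=True)
ScaleBy.apply(x, 3.0).sum().backward()
print(x.grad)  # tensor([3., 3., 3.])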
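
Note (editorial illustration, not part of the patch series): PATCH 09-11 quote the annotation as Optional["torch.distributed.ProcessGroup"], presumably so the annotation is not evaluated when the module is imported on torch builds that do not expose torch.distributed.ProcessGroup (e.g. builds without distributed support); mypy resolves the string form just the same. A minimal sketch of the quoted-annotation pattern, with `all_gather_sketch` as an invented placeholder:

from typing import Optional

import torch


def all_gather_sketch(
    tensor: torch.Tensor, group: Optional["torch.distributed.ProcessGroup"] = None
) -> torch.Tensor:
    # Placeholder body for the sketch; the real all_gather_ddp_if_available
    # roughly dispatches to AllGatherGrad.apply or torch.distributed.all_gather.
    return tensor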