 import os
 from functools import wraps
 from platform import python_version
-from typing import Any, Optional, Union
+from typing import Any, Callable, List, Optional, Tuple, Type, Union

 import torch
 from torch.nn.parallel.distributed import DistributedDataParallel

 else:

-    class ReduceOp:
+    class ReduceOp:  # type: ignore # (see https://github.com/python/mypy/issues/1153)
         SUM = None

-    class group:
+    class group:  # type: ignore
         WORLD = None


 log = logging.getLogger(__name__)


-def rank_zero_only(fn):
+def rank_zero_only(fn: Callable) -> Callable:
     @wraps(fn)
-    def wrapped_fn(*args, **kwargs):
+    def wrapped_fn(*args: Any, **kwargs: Any) -> Optional[Any]:
         if rank_zero_only.rank == 0:
             return fn(*args, **kwargs)
+        return None

     return wrapped_fn

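# Usage sketch (illustration, not part of the diff): `rank_zero_only` turns the
# decorated function into a no-op on every process except global rank 0, and
# with the added `return None` the non-zero ranks now return explicitly.
# The function name and argument below are hypothetical.
@rank_zero_only
def print_run_summary(message: str) -> None:
    # Runs on rank 0 only; every other rank gets None back.
    print(message)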
@@ -64,7 +65,7 @@ def _get_rank() -> int:
 rank_zero_only.rank = getattr(rank_zero_only, "rank", _get_rank())


-def rank_zero_warn(*args, stacklevel: int = 5, **kwargs):
+def rank_zero_warn(*args: Any, stacklevel: int = 5, **kwargs: Any) -> None:
     from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn

     rank_zero_deprecation(
@@ -74,7 +75,7 @@ def rank_zero_warn(*args, stacklevel: int = 5, **kwargs):
     return rank_zero_warn(*args, stacklevel=stacklevel, **kwargs)


-def rank_zero_deprecation(*args, stacklevel: int = 5, **kwargs):
+def rank_zero_deprecation(*args: Any, stacklevel: int = 5, **kwargs: Any) -> None:
     from pytorch_lightning.utilities.warnings import rank_zero_deprecation

     rank_zero_deprecation(
@@ -84,29 +85,29 @@ def rank_zero_deprecation(*args, stacklevel: int = 5, **kwargs):
     return rank_zero_deprecation(*args, stacklevel=stacklevel, **kwargs)


-def _info(*args, stacklevel: int = 2, **kwargs):
+def _info(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
     if python_version() >= "3.8.0":
         kwargs["stacklevel"] = stacklevel
     log.info(*args, **kwargs)


-def _debug(*args, stacklevel: int = 2, **kwargs):
+def _debug(*args: Any, stacklevel: int = 2, **kwargs: Any) -> None:
     if python_version() >= "3.8.0":
         kwargs["stacklevel"] = stacklevel
     log.debug(*args, **kwargs)


 @rank_zero_only
-def rank_zero_debug(*args, stacklevel: int = 4, **kwargs):
+def rank_zero_debug(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
     _debug(*args, stacklevel=stacklevel, **kwargs)


 @rank_zero_only
-def rank_zero_info(*args, stacklevel: int = 4, **kwargs):
+def rank_zero_info(*args: Any, stacklevel: int = 4, **kwargs: Any) -> None:
     _info(*args, stacklevel=stacklevel, **kwargs)


-def gather_all_tensors(result: Union[torch.Tensor], group: Optional[Any] = None):
+def gather_all_tensors(result: torch.Tensor, group: Optional[Any] = None) -> List[torch.Tensor]:
     """
     Function to gather all tensors from several ddp processes onto a list that
     is broadcasted to all processes
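# Usage sketch (illustration, not part of the diff): with the tightened
# annotations, `gather_all_tensors` takes one tensor and returns a
# List[torch.Tensor] with one entry per process. Assumes torch.distributed is
# already initialised (e.g. running under DDP); the metric value is made up.
local_metric = torch.tensor([0.75])              # hypothetical per-process value
all_metrics = gather_all_tensors(local_metric)   # list of world_size tensors
mean_metric = torch.stack(all_metrics).mean()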
@@ -141,7 +142,7 @@ def distributed_available() -> bool:


 def sync_ddp_if_available(
-    result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
+    result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
 ) -> torch.Tensor:
     """
     Function to reduce a tensor across worker processes during distributed training
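# Usage sketch (illustration, not part of the diff): reduce a per-process
# scalar across all workers; when torch.distributed is unavailable or not
# initialised the input is returned unchanged. The default reduction is a sum;
# the tensor value here is made up.
global_batch_count = sync_ddp_if_available(torch.tensor(32.0))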
@@ -160,7 +161,7 @@ def sync_ddp_if_available(


 def sync_ddp(
-    result: Union[torch.Tensor], group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
+    result: torch.Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None
 ) -> torch.Tensor:
     """
     Function to reduce the tensors from several ddp processes to one master process
@@ -196,7 +197,11 @@ def sync_ddp(

 class AllGatherGrad(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, tensor, group=group.WORLD):
+    def forward(
+        ctx: Any,
+        tensor: torch.Tensor,
+        group: Optional["torch.distributed.ProcessGroup"] = group.WORLD,
+    ) -> torch.Tensor:
         ctx.group = group

         gathered_tensor = [torch.zeros_like(tensor) for _ in range(torch.distributed.get_world_size())]
@@ -207,7 +212,7 @@ def forward(ctx, tensor, group=group.WORLD):
         return gathered_tensor

     @staticmethod
-    def backward(ctx, *grad_output):
+    def backward(ctx: Any, *grad_output: torch.Tensor) -> Tuple[torch.Tensor, None]:
         grad_output = torch.cat(grad_output)

         torch.distributed.all_reduce(grad_output, op=torch.distributed.ReduceOp.SUM, async_op=False, group=ctx.group)
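# Illustration (not part of the diff): AllGatherGrad is the differentiable
# all_gather used by `all_gather_ddp_if_available(..., sync_grads=True)`.
# Forward gathers one tensor per process into a stacked result (shape assumed
# to be (world_size, *local_shape)); backward all-reduces the incoming
# gradients so each rank recovers the gradient of its own contribution.
# Assumes an initialised process group; `local_out` is hypothetical.
local_out = torch.randn(4, 16, requires_grad=True)
gathered = AllGatherGrad.apply(local_out)
gathered.sum().backward()   # local_out.grad is now populated on every rank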
@@ -216,7 +221,7 @@ def backward(ctx, *grad_output):


 def all_gather_ddp_if_available(
-    tensor: Union[torch.Tensor], group: Optional[Any] = None, sync_grads: bool = False
+    tensor: torch.Tensor, group: Optional["torch.distributed.ProcessGroup"] = None, sync_grads: bool = False
 ) -> torch.Tensor:
     """
     Function to gather a tensor from several distributed processes
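# Usage sketch (illustration, not part of the diff): the public helper returns
# the input tensor unchanged outside of distributed runs; sync_grads=True
# routes the gather through AllGatherGrad above so it stays differentiable.
# The feature tensor is made up.
features = torch.randn(8, 128)
all_features = all_gather_ddp_if_available(features, sync_grads=True)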
@@ -241,8 +246,8 @@ def all_gather_ddp_if_available(
 def register_ddp_comm_hook(
     model: DistributedDataParallel,
     ddp_comm_state: Optional[object] = None,
-    ddp_comm_hook: Optional[callable] = None,
-    ddp_comm_wrapper: Optional[callable] = None,
+    ddp_comm_hook: Optional[Callable] = None,
+    ddp_comm_wrapper: Optional[Callable] = None,
 ) -> None:
     """
     Function to register communication hook for DDP model
@@ -322,6 +327,9 @@ def register_ddp_comm_hook(
         return
     if ddp_comm_hook is None:
         return
+    # inform mypy that ddp_comm_hook is callable
+    ddp_comm_hook: Callable = ddp_comm_hook
+
     if ddp_comm_wrapper is not None:
         if not _TORCH_GREATER_EQUAL_1_9:
             rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.")
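# Usage sketch (illustration, not part of the diff), assuming PyTorch >= 1.9
# and an already-wrapped DistributedDataParallel model (`ddp_model` is
# hypothetical): register fp16 gradient compression as the DDP comm hook.
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default

register_ddp_comm_hook(
    model=ddp_model,
    ddp_comm_hook=default.fp16_compress_hook,
)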