 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.parallel import ParallelStrategy
+from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as dist_group
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
+from pytorch_lightning.utilities.types import _LRScheduler
 
 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
@@ -70,11 +72,11 @@ def world_size(self) -> int:
         return hvd.size()
 
     @property
-    def root_device(self):
+    def root_device(self) -> torch.device:
         return self.parallel_devices[self.local_rank]
 
     @property
-    def distributed_sampler_kwargs(self):
+    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
         distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs
 
@@ -95,7 +97,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
             # no need to setup optimizers
             return
 
-        def _unpack_lightning_optimizer(opt):
+        def _unpack_lightning_optimizer(opt: Optimizer) -> Optimizer:
             return opt._optimizer if isinstance(opt, LightningOptimizer) else opt
 
         optimizers = self.optimizers
@@ -111,8 +113,10 @@ def _unpack_lightning_optimizer(opt):
         lr_scheduler_configs = self.lr_scheduler_configs
         for config in lr_scheduler_configs:
             scheduler = config.scheduler
+            assert isinstance(scheduler, _LRScheduler)
             scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]
 
+        assert self.lightning_module is not None
         # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
         hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
         for optimizer in optimizers:
@@ -129,27 +133,33 @@ def _unpack_lightning_optimizer(opt):
             # Synchronization will be performed explicitly following backward()
             self._exit_stack.enter_context(optimizer.skip_synchronize())
 
-    def barrier(self, *args, **kwargs):
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
         if distributed_available():
             self.join()
 
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         obj = hvd.broadcast_object(obj, src)
         return obj
 
-    def model_to_device(self):
+    def model_to_device(self) -> None:
         if self.root_device.type == "cuda":
             # this can potentially be removed after #8312. Not done due to lack of horovod testing
             torch.cuda.set_device(self.root_device)
+        assert self.model is not None
         self.model.to(self.root_device)
 
-    def join(self):
+    def join(self) -> None:
         if self.root_device.type == "cuda":
             hvd.join(self.local_rank)
         else:
             hvd.join()
 
-    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
+    def reduce(
+        self,
+        tensor: Union[Any, Tensor],
+        group: Optional[Any] = None,
+        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
+    ) -> Union[Any, Tensor]:
         """Reduces a tensor from several distributed processes to one aggregated tensor.
 
         Args:
@@ -196,6 +206,7 @@ def _wrap_optimizers(
         self, optimizers: List[Optimizer], accumulate_grad_batches: int
    ) -> List["hvd.DistributedOptimizer"]:
         """Wraps optimizers to perform gradient aggregation via allreduce."""
+        assert self.lightning_module is not None
         return [
             hvd.DistributedOptimizer(
                 opt,
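The most consequential typing change above is `broadcast` going from `object -> object` to `TBroadcast -> TBroadcast`. A minimal sketch of that pattern, assuming `TBroadcast` is a plain `TypeVar` as imported from `strategies/strategy.py` (the `broadcast` below is a single-process stand-in for illustration, not Horovod's implementation):

from typing import TypeVar

TBroadcast = TypeVar("TBroadcast")


def broadcast(obj: TBroadcast, src: int = 0) -> TBroadcast:
    # Stand-in for hvd.broadcast_object(obj, src): in a real run the object
    # held by rank ``src`` is sent to every worker; here it is just returned.
    return obj


best_path = broadcast("checkpoints/best.ckpt", src=0)
# With the old ``-> object`` signature a type checker rejects ``.endswith``;
# with the TypeVar the return type is inferred as ``str``.
assert best_path.endswith(".ckpt")

The same idea applies to the new `reduce` and `model_to_device` annotations: the asserts added in the diff narrow `Optional` attributes so the stricter signatures type-check without changing runtime behavior.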