Commit e179a58

fix mypy typing errors in pytorch_lightning/strategies/horovod.py (#13570)
1 parent: c1cc112

File tree: 4 files changed (+22 -9 lines changed)

pyproject.toml
src/pytorch_lightning/strategies/hivemind.py
src/pytorch_lightning/strategies/horovod.py
src/pytorch_lightning/utilities/types.py

pyproject.toml

Lines changed: 0 additions & 1 deletion

@@ -72,7 +72,6 @@ module = [
     "pytorch_lightning.strategies.ddp_spawn",
     "pytorch_lightning.strategies.deepspeed",
     "pytorch_lightning.strategies.fully_sharded",
-    "pytorch_lightning.strategies.horovod",
     "pytorch_lightning.strategies.ipu",
     "pytorch_lightning.strategies.parallel",
     "pytorch_lightning.strategies.sharded",

src/pytorch_lightning/strategies/hivemind.py

Lines changed: 2 additions & 0 deletions

@@ -310,6 +310,8 @@ class HiveMindScheduler:
     This code ensures that we only step when the HiveMind optimizer reaches the global step.
     """

+    base_lrs: List[float]
+
     def __init__(self, optimizer: "hivemind.Optimizer", scheduler: _LRScheduler) -> None:
         # copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has
         # implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler`
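
Background on the added `base_lrs: List[float]` class attribute: the `__init__` comment above shows that `HiveMindScheduler` copies the wrapped scheduler's attributes over dynamically, so a static type checker never sees `base_lrs` being assigned; declaring it at class level (without a value) makes the attribute visible to mypy. It is presumably needed so that `HiveMindScheduler` keeps matching the `_LRScheduler` protocol after `base_lrs` was added to it (see src/pytorch_lightning/utilities/types.py below). A minimal sketch of the pattern, with hypothetical names rather than the real Lightning classes:

from typing import Any, List


class WrappedScheduler:
    # Class-level annotation: tells mypy the attribute exists even though it is
    # never assigned here; the value only shows up at runtime via the copy below.
    base_lrs: List[float]

    def __init__(self, scheduler: Any) -> None:
        # copy the wrapped scheduler's public attributes dynamically,
        # which is invisible to static analysis
        self.__dict__.update({k: v for k, v in vars(scheduler).items() if not k.startswith("_")})


class DummyScheduler:
    def __init__(self) -> None:
        self.base_lrs = [0.1, 0.01]


print(WrappedScheduler(DummyScheduler()).base_lrs)  # [0.1, 0.01]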

src/pytorch_lightning/strategies/horovod.py

Lines changed: 19 additions & 8 deletions

@@ -24,12 +24,14 @@
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.parallel import ParallelStrategy
+from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as dist_group
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
+from pytorch_lightning.utilities.types import _LRScheduler

 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd

@@ -70,11 +72,11 @@ def world_size(self) -> int:
         return hvd.size()

     @property
-    def root_device(self):
+    def root_device(self) -> torch.device:
         return self.parallel_devices[self.local_rank]

     @property
-    def distributed_sampler_kwargs(self):
+    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
         distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs

@@ -95,7 +97,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
             # no need to setup optimizers
             return

-        def _unpack_lightning_optimizer(opt):
+        def _unpack_lightning_optimizer(opt: Optimizer) -> Optimizer:
             return opt._optimizer if isinstance(opt, LightningOptimizer) else opt

         optimizers = self.optimizers

@@ -111,8 +113,10 @@ def _unpack_lightning_optimizer(opt):
         lr_scheduler_configs = self.lr_scheduler_configs
         for config in lr_scheduler_configs:
             scheduler = config.scheduler
+            assert isinstance(scheduler, _LRScheduler)
             scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]

+        assert self.lightning_module is not None
         # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
         hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
         for optimizer in optimizers:

@@ -129,27 +133,33 @@ def _unpack_lightning_optimizer(opt):
             # Synchronization will be performed explicitly following backward()
             self._exit_stack.enter_context(optimizer.skip_synchronize())

-    def barrier(self, *args, **kwargs):
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
         if distributed_available():
             self.join()

-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         obj = hvd.broadcast_object(obj, src)
         return obj

-    def model_to_device(self):
+    def model_to_device(self) -> None:
         if self.root_device.type == "cuda":
             # this can potentially be removed after #8312. Not done due to lack of horovod testing
             torch.cuda.set_device(self.root_device)
+        assert self.model is not None
         self.model.to(self.root_device)

-    def join(self):
+    def join(self) -> None:
         if self.root_device.type == "cuda":
             hvd.join(self.local_rank)
         else:
             hvd.join()

-    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
+    def reduce(
+        self,
+        tensor: Union[Any, Tensor],
+        group: Optional[Any] = None,
+        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
+    ) -> Union[Any, Tensor]:
         """Reduces a tensor from several distributed processes to one aggregated tensor.

         Args:

@@ -196,6 +206,7 @@ def _wrap_optimizers(
         self, optimizers: List[Optimizer], accumulate_grad_batches: int
     ) -> List["hvd.DistributedOptimizer"]:
         """Wraps optimizers to perform gradient aggregation via allreduce."""
+        assert self.lightning_module is not None
         return [
             hvd.DistributedOptimizer(
                 opt,
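
Two typing patterns recur in this file. First, `broadcast` now takes and returns `TBroadcast`, which appears to be a plain TypeVar imported from pytorch_lightning.strategies.strategy, so the inferred return type follows whatever object the caller passes in instead of collapsing to `object`. Second, the new `assert ... is not None` lines narrow Optional attributes such as `self.lightning_module` and `self.model` before they are dereferenced, which is how mypy is told the None case cannot occur at that point. A small self-contained sketch of both ideas, using toy stand-ins rather than the real HorovodStrategy:

from typing import Optional, TypeVar

import torch
from torch import nn

# assumed shape of the real definition in strategies/strategy.py
TBroadcast = TypeVar("TBroadcast")


def broadcast_identity(obj: TBroadcast, src: int = 0) -> TBroadcast:
    """Toy stand-in for hvd.broadcast_object: the TypeVar preserves the caller's type."""
    return obj


class ToyStrategy:
    def __init__(self, model: Optional[nn.Module] = None) -> None:
        self.model = model  # Optional attribute, the kind mypy complains about

    def model_to_device(self, device: torch.device) -> None:
        # narrows Optional[nn.Module] to nn.Module, mirroring the asserts in the diff above
        assert self.model is not None
        self.model.to(device)


strategy = ToyStrategy(nn.Linear(2, 2))
strategy.model_to_device(torch.device("cpu"))
state = broadcast_identity({"epoch": 3})  # inferred as Dict[str, int], not object

The asserts cost almost nothing at runtime (and disappear under python -O); their only job here is to state an invariant the strategy already relies on.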

src/pytorch_lightning/utilities/types.py

Lines changed: 1 addition & 0 deletions

@@ -65,6 +65,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
 @runtime_checkable
 class _LRScheduler(_Stateful, Protocol):
     optimizer: Optimizer
+    base_lrs: List[float]

     def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
         ...
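
`_LRScheduler` here is a runtime-checkable structural Protocol, so adding `base_lrs: List[float]` as a member does two things: it lets mypy accept `scheduler.base_lrs` wherever a value is typed as `_LRScheduler`, and it lets the `assert isinstance(scheduler, _LRScheduler)` added in horovod.py narrow the scheduler's type (isinstance against a runtime-checkable protocol only checks that the members exist, not their types). A small sketch of that narrowing pattern, with a hypothetical protocol rather than the real one:

from typing import List, Protocol, runtime_checkable


@runtime_checkable
class SupportsBaseLrs(Protocol):
    base_lrs: List[float]

    def step(self) -> None:
        ...


class PlainScheduler:
    # no inheritance from the protocol is needed; matching its shape is enough
    def __init__(self) -> None:
        self.base_lrs = [0.1]

    def step(self) -> None:
        self.base_lrs = [lr * 0.5 for lr in self.base_lrs]


def double_base_lrs(scheduler: object) -> None:
    # without this check mypy rejects `scheduler.base_lrs` on a plain `object`;
    # after it, `scheduler` is narrowed to SupportsBaseLrs
    assert isinstance(scheduler, SupportsBaseLrs)
    scheduler.base_lrs = [lr * 2 for lr in scheduler.base_lrs]


double_base_lrs(PlainScheduler())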
