1 change: 0 additions & 1 deletion pyproject.toml
@@ -67,7 +67,6 @@ module = [
     "pytorch_lightning.strategies.deepspeed",
     "pytorch_lightning.strategies.dp",
     "pytorch_lightning.strategies.fully_sharded",
-    "pytorch_lightning.strategies.horovod",
     "pytorch_lightning.strategies.ipu",
     "pytorch_lightning.strategies.parallel",
     "pytorch_lightning.strategies.sharded",
2 changes: 2 additions & 0 deletions src/pytorch_lightning/strategies/hivemind.py
@@ -310,6 +310,8 @@ class HiveMindScheduler:
     This code ensures that we only step when the HiveMind optimizer reaches the global step.
     """
 
+    base_lrs: List[float]
+
     def __init__(self, optimizer: "hivemind.Optimizer", scheduler: _LRScheduler) -> None:
         # copy most of the `Scheduler` methods into this instance. `__del__` is skipped in case the scheduler has
         # implemented custom logic which we would not want to call on destruction of the `HiveMindScheduler`
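Note on the class-level `base_lrs: List[float]` annotation: `HiveMindScheduler.__init__` copies the wrapped scheduler's methods and attributes into the instance at runtime (see the comment in the hunk above), so mypy cannot infer that `base_lrs` exists. Declaring the annotation at class scope, without assigning a value, surfaces the dynamically-set attribute to the type checker. A minimal sketch of the pattern, using hypothetical names (`_WrappedScheduler`, `DynamicWrapper`) that are not part of this PR:

```python
from typing import List


class _WrappedScheduler:
    """Stand-in for a scheduler that sets `base_lrs` in its __init__."""

    def __init__(self) -> None:
        self.base_lrs = [0.1, 0.01]


class DynamicWrapper:
    # Annotation only, no assignment: it tells mypy that every instance has a
    # `base_lrs` attribute even though it is filled in dynamically below.
    base_lrs: List[float]

    def __init__(self, scheduler: _WrappedScheduler) -> None:
        # Copy the wrapped object's attributes into this instance, mirroring
        # the dynamic copying described in `HiveMindScheduler.__init__`.
        self.__dict__.update(scheduler.__dict__)


wrapper = DynamicWrapper(_WrappedScheduler())
print(wrapper.base_lrs)  # [0.1, 0.01] -- present at runtime and visible to mypy
```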
27 changes: 19 additions & 8 deletions src/pytorch_lightning/strategies/horovod.py
@@ -24,12 +24,14 @@
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.strategies.parallel import ParallelStrategy
+from pytorch_lightning.strategies.strategy import TBroadcast
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.distributed import group as dist_group
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _HOROVOD_AVAILABLE
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
+from pytorch_lightning.utilities.types import _LRScheduler
 
 if _HOROVOD_AVAILABLE:
     import horovod.torch as hvd
@@ -70,11 +72,11 @@ def world_size(self) -> int:
         return hvd.size()
 
     @property
-    def root_device(self):
+    def root_device(self) -> torch.device:
         return self.parallel_devices[self.local_rank]
 
     @property
-    def distributed_sampler_kwargs(self):
+    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
         distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
         return distributed_sampler_kwargs
 
@@ -95,7 +97,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
             # no need to setup optimizers
             return
 
-        def _unpack_lightning_optimizer(opt):
+        def _unpack_lightning_optimizer(opt: Optimizer) -> Optimizer:
             return opt._optimizer if isinstance(opt, LightningOptimizer) else opt
 
         optimizers = self.optimizers
@@ -111,8 +113,10 @@ def _unpack_lightning_optimizer(opt):
         lr_scheduler_configs = self.lr_scheduler_configs
         for config in lr_scheduler_configs:
             scheduler = config.scheduler
+            assert isinstance(scheduler, _LRScheduler)
             scheduler.base_lrs = [lr * self.world_size for lr in scheduler.base_lrs]
 
+        assert self.lightning_module is not None
         # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
         hvd.broadcast_parameters(self.lightning_module.state_dict(), root_rank=0)
         for optimizer in optimizers:
@@ -129,27 +133,33 @@ def _unpack_lightning_optimizer(opt):
             # Synchronization will be performed explicitly following backward()
             self._exit_stack.enter_context(optimizer.skip_synchronize())
 
-    def barrier(self, *args, **kwargs):
+    def barrier(self, *args: Any, **kwargs: Any) -> None:
         if distributed_available():
             self.join()
 
-    def broadcast(self, obj: object, src: int = 0) -> object:
+    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
         obj = hvd.broadcast_object(obj, src)
         return obj
 
-    def model_to_device(self):
+    def model_to_device(self) -> None:
         if self.root_device.type == "cuda":
             # this can potentially be removed after #8312. Not done due to lack of horovod testing
             torch.cuda.set_device(self.root_device)
+        assert self.model is not None
         self.model.to(self.root_device)
 
-    def join(self):
+    def join(self) -> None:
         if self.root_device.type == "cuda":
             hvd.join(self.local_rank)
         else:
             hvd.join()
 
-    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
+    def reduce(
+        self,
+        tensor: Union[Any, Tensor],
+        group: Optional[Any] = None,
+        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
+    ) -> Union[Any, Tensor]:
         """Reduces a tensor from several distributed processes to one aggregated tensor.
 
         Args:
@@ -196,6 +206,7 @@ def _wrap_optimizers(
         self, optimizers: List[Optimizer], accumulate_grad_batches: int
     ) -> List["hvd.DistributedOptimizer"]:
         """Wraps optimizers to perform gradient aggregation via allreduce."""
+        assert self.lightning_module is not None
         return [
             hvd.DistributedOptimizer(
                 opt,
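The new `assert` statements are there for the type checker rather than for runtime behaviour: `assert isinstance(scheduler, _LRScheduler)` narrows the scheduler to the runtime-checkable protocol before `base_lrs` is read, and `assert self.lightning_module is not None` / `assert self.model is not None` narrow `Optional` attributes before they are dereferenced. A minimal sketch of both narrowing patterns, with illustrative names (`SupportsBaseLrs`, `scale_base_lrs`, `FakeScheduler`, `FakeModule`, `require_module`) that are not part of this PR:

```python
from typing import List, Optional, Protocol, runtime_checkable


@runtime_checkable
class SupportsBaseLrs(Protocol):
    base_lrs: List[float]


class FakeScheduler:
    def __init__(self) -> None:
        self.base_lrs = [0.1]


def scale_base_lrs(scheduler: object, world_size: int) -> None:
    # Without the assert, mypy rejects `scheduler.base_lrs`: `object` declares
    # no such attribute. The isinstance check narrows to the protocol.
    assert isinstance(scheduler, SupportsBaseLrs)
    scheduler.base_lrs = [lr * world_size for lr in scheduler.base_lrs]


class FakeModule:
    ...


def require_module(module: Optional[FakeModule]) -> FakeModule:
    # `assert x is not None` narrows Optional[FakeModule] to FakeModule,
    # mirroring `assert self.lightning_module is not None` in the diff above.
    assert module is not None
    return module


scheduler = FakeScheduler()
scale_base_lrs(scheduler, world_size=2)
print(scheduler.base_lrs)  # [0.2]
print(require_module(FakeModule()))
```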
1 change: 1 addition & 0 deletions src/pytorch_lightning/utilities/types.py
@@ -65,6 +65,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
 @runtime_checkable
 class _LRScheduler(_Stateful, Protocol):
     optimizer: Optimizer
+    base_lrs: List[float]
 
     def __init__(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> None:
         ...
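With `base_lrs` added to the `_LRScheduler` protocol, accesses such as `scheduler.base_lrs` in `horovod.py` type-check once the scheduler has been narrowed to this protocol. Because the protocol is `@runtime_checkable`, the `isinstance` assert only verifies that the listed members exist on the object (it does not validate their types), and real torch schedulers set both `optimizer` and `base_lrs` in their `__init__`, so they pass structurally. A minimal, self-contained sketch using a trimmed-down protocol named `_LRSchedulerLike` (the real protocol above also inherits `_Stateful` and declares `__init__`):

```python
from typing import List, Protocol, runtime_checkable

import torch
from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import StepLR


# Trimmed-down sketch of the protocol above, reduced to the two data members.
@runtime_checkable
class _LRSchedulerLike(Protocol):
    optimizer: Optimizer
    base_lrs: List[float]


model = torch.nn.Linear(2, 2)
scheduler = StepLR(SGD(model.parameters(), lr=0.1), step_size=1)

# torch schedulers assign both `optimizer` and `base_lrs` in __init__, so the
# instance satisfies the protocol structurally. `isinstance` against a
# runtime-checkable protocol only checks that the members are present.
assert isinstance(scheduler, _LRSchedulerLike)
print(scheduler.base_lrs)  # [0.1]
```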