Description
Proposed refactoring or deprecation
- The current Horovod training type plugin's collective function all_gather() calls Horovod's allgather() and converts the result to a list.
- Horovod supports allgather_object(), which returns a list of tensors (see the sketch after this list):
  https://horovod.readthedocs.io/en/stable/_modules/horovod/torch/functions.html#allgather_object
- Revisit the use cases to decide whether we need to return a list of tensors here, since the training_type_plugin all_gather() API is currently defined to return a tensor.
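For context, a minimal sketch of the difference between the two Horovod primitives (assuming a Horovod-enabled run launched e.g. with horovodrun; the tensor contents are for illustration only):

import torch
import horovod.torch as hvd

hvd.init()

# Each rank contributes a one-element tensor holding its own rank id.
local = torch.tensor([float(hvd.rank())])

# hvd.allgather concatenates along dim 0 and returns a single tensor,
# e.g. shape (world_size,) here.
gathered_tensor = hvd.allgather(local)

# hvd.allgather_object serializes the input and returns a Python list
# with one entry per rank, e.g. [tensor([0.]), tensor([1.]), ...].
gathered_list = hvd.allgather_object(local)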
Motivation
Ensure correct and consistent collective behavior.
Pitch
def all_gather(
    self, result: Union[torch.Tensor], group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False
) -> torch.Tensor:
    if group is not None and group != dist_group.WORLD:
        raise ValueError("Horovod does not support allgather using a subcommunicator at this time. Unset `group`.")

    if len(result.shape) == 0:
        # Convert scalars to single dimension tensors
        result = result.reshape(1)

    # sync and gather all
    self.join()
    gathered = hvd.allgather(result)
    gathered_result = list(gathered.split(1, dim=0))
    return gathered_result
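Note that despite the -> torch.Tensor annotation, this implementation returns a Python list of one-row tensors; that mismatch is what the two options below try to resolve.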
[RFC]
Option 1
def all_gather(
    self, result: Union[torch.Tensor], group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False
) -> torch.Tensor:
    if group is not None and group != dist_group.WORLD:
        raise ValueError("Horovod does not support allgather using a subcommunicator at this time. Unset `group`.")

    if len(result.shape) == 0:
        # Convert scalars to single dimension tensors
        result = result.reshape(1)

    # sync and gather all
    self.join()
    return hvd.allgather_object(result)
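Option 1 keeps the list-of-results behavior by delegating to hvd.allgather_object(), but the result is still a Python list rather than a tensor, so the -> torch.Tensor annotation would need to change; and because allgather_object serializes and deserializes its input, the gathered tensors are detached from the autograd graph, which likely matters when sync_grads=True.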
Option 2
def all_gather(
    self, result: Union[torch.Tensor], group: Optional[Any] = dist_group.WORLD, sync_grads: bool = False
) -> torch.Tensor:
    if group is not None and group != dist_group.WORLD:
        raise ValueError("Horovod does not support allgather using a subcommunicator at this time. Unset `group`.")

    if len(result.shape) == 0:
        # Convert scalars to single dimension tensors
        result = result.reshape(1)

    # sync and gather all
    self.join()
    return hvd.allgather(result)
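Option 2 returns the concatenated tensor from hvd.allgather() directly, which matches the declared torch.Tensor return type. Callers that relied on the old list behavior could recover it themselves; a hypothetical caller-side sketch (the plugin attribute name is assumed for illustration):

gathered = trainer.training_type_plugin.all_gather(loss)  # tensor, dim 0 indexed by rank
per_rank = list(gathered.split(1, dim=0))  # former list-of-tensors behavior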
Additional context
If you enjoy Lightning, check out our other projects! ⚡
- Metrics: Machine learning metrics for distributed, scalable PyTorch applications.
- Flash: The fastest way to get a Lightning baseline! A collection of tasks for fast prototyping, baselining, finetuning and solving problems with deep learning.
- Bolts: Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch.
- Lightning Transformers: Flexible interface for high performance research using SOTA Transformers, leveraging PyTorch Lightning, Transformers, and Hydra.