@@ -12,12 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import ExitStack
-from typing import Optional
+from typing import Any, Optional, Union
 
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
 
-from pytorch_lightning.accelerators.accelerator import Accelerator
+from pytorch_lightning.accelerators.accelerator import Accelerator, ReduceOp
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.utilities.distributed import rank_zero_only
 
@@ -161,3 +161,41 @@ def barrier(self, name: Optional[str] = None): |
     def broadcast(self, obj, src=0):
         obj = hvd.broadcast_object(obj, src)
         return obj
+
+    def gather_all_tensors(self, result: torch.Tensor, group: Optional[Any] = None):
+        if group is not None:
+            raise ValueError(
+                "Horovod does not support allgather using a subcommunicator at this time. "
+                "Unset `group`."
+            )
+
+        if len(result.shape) == 0:
+            # convert scalars to 1-element tensors so allgather has a dim to concatenate on
+            result = result.reshape(1)
+
+        # make sure all processes have reached this point, then gather across ranks
+        hvd.join()
+        gathered = hvd.allgather(result)
+        gathered_result = list(gathered.split(1, dim=0))
+        return gathered_result
+
+    def sync_tensor(self,
+                    tensor: torch.Tensor,
+                    group: Optional[Any] = None,
+                    reduce_op: Optional[Union[ReduceOp, str]] = None) -> torch.Tensor:
+        if group is not None:
+            raise ValueError(
+                "Horovod does not support allreduce using a subcommunicator at this time. "
+                "Unset `group`."
+            )
+
+        # map the generic reduce_op onto Horovod's ops; default to summation
+        if reduce_op is None or reduce_op == "sum":
+            reduce_op = hvd.Sum
+        elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
+            reduce_op = hvd.Average
+        else:
+            raise ValueError(f"unrecognized `reduce_op`: {reduce_op}")
+
+        # make sure all processes have reached this point before reducing
+        hvd.join()
+        return hvd.allreduce(tensor, op=reduce_op)
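
For anyone who wants to poke at the semantics outside of Lightning, here is a minimal standalone sketch of what these two hooks do, using only raw Horovod calls. The script name, launch command, and tensor values are illustrative assumptions, not part of this diff; launch it with something like `horovodrun -np 2 python sketch.py`.

# Standalone sketch of the allgather/allreduce behavior the diff wraps.
# Everything here is an assumption for illustration; run under Horovod,
# e.g.: horovodrun -np 2 python sketch.py
import horovod.torch as hvd
import torch

hvd.init()

# mirror of sync_tensor's default path: join, then sum-allreduce
local = torch.tensor(float(hvd.rank()))
hvd.join()
total = hvd.allreduce(local, op=hvd.Sum)       # same summed value on every rank
mean = hvd.allreduce(local, op=hvd.Average)

# mirror of gather_all_tensors: scalars become 1-element tensors, and the
# concatenated allgather result is split back into a list of tensors
result = local.reshape(1) if local.dim() == 0 else local
hvd.join()
gathered = list(hvd.allgather(result).split(1, dim=0))
assert len(gathered) == hvd.size()

if hvd.rank() == 0:
    print(f"sum={total.item()}, mean={mean.item()}, ranks={len(gathered)}")

Note that `hvd.join()` before each collective mirrors the diff's choice to synchronize stragglers before the gather/reduction, and that for scalar inputs the gathered list holds exactly one 1-element tensor per rank.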