|
13 | 13 | # limitations under the License. |
14 | 14 | import contextlib |
15 | 15 | from abc import ABC, abstractmethod |
16 | | -from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, TypeVar, Union |
| 16 | +from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, TypeVar, Union |
17 | 17 |
|
18 | 18 | import torch |
19 | 19 | from torch import Tensor |
|
25 | 25 | from pytorch_lightning.overrides.base import unwrap_lightning_module |
26 | 26 | from pytorch_lightning.plugins import TorchCheckpointIO |
27 | 27 | from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO |
| 28 | +from pytorch_lightning.utilities.distributed import ReduceOp |
28 | 29 | from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT |
29 | 30 |
|
30 | 31 | TBroadcast = TypeVar("T") |
@@ -91,26 +92,53 @@ def is_global_zero(self) -> bool: |
91 | 92 | """Whether the current process is the rank zero process not only on the local node, but for all nodes.""" |
92 | 93 |
|
93 | 94 | @abstractmethod |
94 | | - def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]: |
| 95 | + def reduce( |
| 96 | + self, |
| 97 | + tensor: Union[torch.Tensor, Any], |
| 98 | + group: Optional[Any] = None, |
| 99 | + reduce_op: Optional[Union[ReduceOp, str]] = "mean", |
| 100 | + ) -> Union[torch.Tensor, Any]: |
95 | 101 | """Reduces the given tensor (e.g. across GPUs/processes). |
96 | 102 |
|
97 | 103 | Args: |
98 | 104 | tensor: the tensor to sync and reduce |
| 105 | + group: the process group to reduce across
| 106 | + reduce_op: the reduction operation. Defaults to 'mean'.
| 107 | + Can also be the string 'sum' or a ``ReduceOp``.
99 | 108 | *args: plugin-specific positional arguments |
100 | 109 | **kwargs: plugin-specific keyword arguments |
101 | 110 | """ |
102 | 111 |
|
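For readers of this diff, a minimal sketch of what a concrete strategy's `reduce` could look like on top of plain `torch.distributed`. The standalone `reduce_mean` helper below is illustrative only and not part of the PR; Lightning's own DDP plugin routes through a helper such as `sync_ddp_if_available` instead.

```python
from typing import Any, Optional

import torch
import torch.distributed as dist


def reduce_mean(tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
    """All-reduce ``tensor`` across ranks and divide by the world size (reduce_op='mean')."""
    if dist.is_available() and dist.is_initialized():
        # SUM followed by a divide reproduces a 'mean' reduction on backends
        # that have no native averaging op.
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
        tensor = tensor / dist.get_world_size(group)
    return tensor
```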
103 | 112 | @abstractmethod |
104 | 113 | def barrier(self, name: Optional[str] = None) -> None: |
105 | | - """Forces all possibly joined processes to wait for each other.""" |
| 114 | + """Synchronizes all processes which blocks processes until the whole group enters this function. |
| 115 | +
|
| 116 | + Args: |
| 117 | + name: an optional name to pass into the barrier. Only torch_xla respects this parameter.
| 118 | + """ |
106 | 119 |
|
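As a hedged illustration of why only the XLA backend cares about `name`: a `torch.distributed` barrier takes no tag at all, while a TPU plugin would forward the name to something like `xm.rendezvous`. The function below is a sketch, not the PR's code.

```python
import torch.distributed as dist


def barrier(name: str = "lightning-barrier") -> None:
    """Block until every initialized process reaches this point."""
    if dist.is_available() and dist.is_initialized():
        # torch.distributed has no use for the name; it exists for the XLA
        # backend, where a TPU plugin would call e.g. xm.rendezvous(name).
        dist.barrier()
```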
107 | 120 | @abstractmethod |
108 | | - def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: |
109 | | - """Broadcasts an object to all processes.""" |
| 121 | + def broadcast(self, obj: object, src: int = 0) -> object: |
| 122 | + """Broadcasts an object to all processes. |
| 123 | +
|
| 124 | + Args: |
| 125 | + obj: the object to broadcast |
| 126 | + src: source rank. |
| 127 | + """ |
110 | 128 |
|
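A sketch of one way `broadcast` can move an arbitrary picklable object rather than only tensors. The `broadcast_object` helper and its guard clause are assumptions for illustration, built on `torch.distributed.broadcast_object_list` (available in torch >= 1.8).

```python
from typing import Any

import torch.distributed as dist


def broadcast_object(obj: Any, src: int = 0) -> Any:
    """Send a picklable object from rank ``src`` to every other rank."""
    if not (dist.is_available() and dist.is_initialized()):
        return obj
    # broadcast_object_list pickles the payload, so this works for arbitrary
    # Python objects; non-source ranks provide a placeholder that gets overwritten.
    container = [obj if dist.get_rank() == src else None]
    dist.broadcast_object_list(container, src=src)
    return container[0]
```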
111 | 129 | @abstractmethod |
112 | | - def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor: |
113 | | - """Perform a all_gather on all processes.""" |
| 130 | + def all_gather( |
| 131 | + self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False |
| 132 | + ) -> Union[List[torch.Tensor], torch.Tensor]: |
| 133 | + """Perform a all_gather on all processes. |
| 134 | +
|
| 135 | + Args: |
| 136 | + tensor: the tensor to all_gather |
| 137 | + group: the process group to gather results from |
| 138 | + sync_grads: flag that allows users to synchronize gradients for all_gather op |
| 139 | +
|
| 140 | + Returns: a tensor (torch distributed) or a list of tensors (horovod)
| 141 | + """ |
114 | 142 |
|
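To make the return contract concrete, here is an illustrative `all_gather_tensor` that stacks the per-rank tensors into a single tensor with a new leading `world_size` dimension, matching the torch.distributed case described in the docstring. The helper name is an assumption, and gradient syncing (`sync_grads=True`) is deliberately out of scope for this sketch.

```python
from typing import Any, Optional

import torch
import torch.distributed as dist


def all_gather_tensor(tensor: torch.Tensor, group: Optional[Any] = None) -> torch.Tensor:
    """Gather ``tensor`` from every rank and stack along a new leading dimension."""
    if not (dist.is_available() and dist.is_initialized()):
        return tensor.unsqueeze(0)
    world_size = dist.get_world_size(group)
    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor.contiguous(), group=group)
    # Plain all_gather does not propagate gradients; sync_grads=True would need
    # an autograd-aware wrapper (e.g. Lightning's AllGatherGrad function).
    return torch.stack(gathered)  # shape: (world_size, *tensor.shape)
```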
115 | 143 | def reduce_boolean_decision(self, decision: bool) -> bool: |
116 | 144 | """Reduce the early stopping decision across all processes.""" |
|
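One plausible way `reduce_boolean_decision` can be composed from the primitives above; this exact helper is an assumption rather than the PR's implementation: cast the flag to a 0/1 tensor, sum it across ranks, and report True only when every rank agreed.

```python
import torch
import torch.distributed as dist


def reduce_boolean_decision(decision: bool) -> bool:
    """Return True only if every rank voted True (an all-reduce of a 0/1 flag)."""
    if not (dist.is_available() and dist.is_initialized()):
        return decision
    # Assumes a backend that accepts CPU tensors (e.g. gloo); with NCCL the
    # flag would need to live on the rank's CUDA device.
    flag = torch.tensor(int(decision), dtype=torch.long)
    dist.all_reduce(flag, op=dist.ReduceOp.SUM)
    return bool(flag.item() == dist.get_world_size())
```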