Commit 4cf7d3b

2/n Consolidate collective functions - collective base and subclasses
1 parent 089ae9b commit 4cf7d3b

5 files changed: +376 −0 lines changed

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Optional, Union

import torch


class Collective(ABC):
    """Base class for collective functions for training type plugins."""

    @abstractmethod
    def barrier(self, name: Optional[str] = None, *args, **kwargs) -> None:
        """Forces all possibly joined processes to wait for each other."""

    @abstractmethod
    def broadcast(self, obj: object, src: int = 0) -> object:
        """Broadcasts an object to all processes."""

    @abstractmethod
    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """Performs an all_gather on all processes."""

    @abstractmethod
    def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
        """Reduces the given tensor (e.g. across GPUs/processes).

        Args:
            tensor: the tensor to sync and reduce
            *args: plugin-specific positional arguments
            **kwargs: plugin-specific keyword arguments
        """

    def reduce_boolean_decision(self, decision: bool) -> bool:
        """Reduces the early stopping decision across all processes."""
        return decision
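
The base class above only defines the interface; this diff does not show how the training type plugins consume it. As a minimal sketch (not part of this commit; the class name MyPlugin is purely illustrative), a plugin would hold a concrete Collective and delegate its collective calls to it:

# Hypothetical sketch: a training type plugin delegating to an injected Collective.
# Only `Collective` itself comes from this commit; everything else is illustrative.
import torch

from pytorch_lightning.plugins.collective import Collective


class MyPlugin:
    def __init__(self, collective: Collective):
        # Concrete subclass: TorchCollective, TPUCollective, HorovodCollective or SingleNodeCollective
        self.collective = collective

    def barrier(self, name=None) -> None:
        self.collective.barrier(name=name)

    def broadcast(self, obj, src: int = 0):
        return self.collective.broadcast(obj, src=src)

    def reduce(self, tensor: torch.Tensor) -> torch.Tensor:
        return self.collective.reduce(tensor)
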
Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from typing import Any, Optional, Union

import torch

from pytorch_lightning.plugins.collective import Collective
from pytorch_lightning.utilities import _HOROVOD_AVAILABLE, _TPU_AVAILABLE
from pytorch_lightning.utilities.distributed import ReduceOp

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm
    from torch_xla.core.xla_model import rendezvous
else:
    xm, rendezvous = None, None

if _HOROVOD_AVAILABLE:
    import horovod.torch as hvd


class HorovodCollective(Collective):
    """Collective functions for the Horovod training type plugin."""

    def __init__(
        self,
        on_gpu: Optional[bool] = False,
        local_rank: Optional[int] = 0,
    ):
        self._on_gpu = on_gpu
        self._local_rank = local_rank

    def join(self):
        """Horovod function that indicates that the rank finished processing data.

        All ranks that did not call join() continue to process allreduce operations. This function blocks the Python
        thread until all ranks join.
        """
        if self._on_gpu:
            hvd.join(self._local_rank)
        else:
            hvd.join()

    def barrier(self, name: Optional[str] = None) -> None:
        if self.is_distributed:
            rendezvous(name)

    def broadcast(self, obj: object, src: int = 0) -> object:
        if not self.is_distributed:
            return obj
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = bytearray(buffer.getbuffer())
        data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float)
        data = xm.all_gather(data_tensor)
        buffer = io.BytesIO(data.cpu().byte().numpy())
        obj = torch.load(buffer)
        return obj

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """Gathers a tensor from several distributed processes.

        Args:
            tensor: tensor of shape (batch, ...)
            group: not available with TPUs
            sync_grads: not available with TPUs

        Return:
            A tensor of shape (world_size, batch, ...)
        """
        if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
            tensor = tensor.unsqueeze(0)
        return xm.all_gather(tensor)

    def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"):
        """Reduces a tensor from several distributed processes to one aggregated tensor.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            The reduced value; if the input was not a tensor, the output remains unchanged.
        """
        if group is not None:
            raise ValueError("Horovod does not support allreduce using a subcommunicator at this time. Unset `group`.")

        if reduce_op in (None, "avg", "mean"):
            reduce_op = hvd.Average
        elif reduce_op in ("sum", ReduceOp.SUM):
            reduce_op = hvd.Sum
        else:
            raise ValueError(f"Unrecognized `reduce_op`: {reduce_op}")

        # sync all processes before reduction
        self.join()
        return hvd.allreduce(tensor, op=reduce_op)
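
A brief usage sketch for the Horovod path (illustrative only: it assumes Horovod is installed, the script was launched with horovodrun, and HorovodCollective is importable from pytorch_lightning.plugins.collective, which this diff does not show):

# Illustrative usage sketch, e.g. under `horovodrun -np 2 python script.py`;
# the import path of the subclass is an assumption.
import horovod.torch as hvd
import torch

from pytorch_lightning.plugins.collective import HorovodCollective

hvd.init()
collective = HorovodCollective(on_gpu=False, local_rank=hvd.local_rank())

t = torch.ones(3)
reduced = collective.reduce(t)  # default reduce_op="mean" maps to hvd.Average
# collective.reduce(t, reduce_op="sum") would map to hvd.Sum;
# a non-None `group` raises ValueError (no subcommunicator support).
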
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

import torch

from pytorch_lightning.plugins.collective import Collective


class SingleNodeCollective(Collective):
    """Collective functions for single-device (non-distributed) training type plugins."""

    def barrier(self, name: Optional[str] = None, *args, **kwargs) -> None:
        """Forces all possibly joined processes to wait for each other."""
        pass

    def broadcast(self, obj: object, src: int = 0) -> object:
        """Broadcasts an object to all processes."""
        return obj

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """Performs an all_gather on all processes."""
        return tensor

    def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
        """Reduces the given tensor (e.g. across GPUs/processes).

        Args:
            tensor: the tensor to sync and reduce
            *args: plugin-specific positional arguments
            **kwargs: plugin-specific keyword arguments
        """
        return tensor
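
Since every method above is a passthrough, a short sanity check makes the single-device contract explicit (illustrative only; the import path of the subclass is assumed, as the diff does not show it):

# Illustrative check of the passthrough contract; the import path is an assumption.
import torch

from pytorch_lightning.plugins.collective import SingleNodeCollective

collective = SingleNodeCollective()
t = torch.tensor([1.0, 2.0])
assert collective.broadcast(t) is t
assert collective.all_gather(t) is t
assert collective.reduce(t) is t
collective.barrier()  # no-op on a single device
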
Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional, Union

import torch
import torch.distributed

from pytorch_lightning.plugins.collective import Collective
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_8
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import (
    all_gather_ddp_if_available,
    distributed_available,
    ReduceOp,
    sync_ddp_if_available,
)
from pytorch_lightning.utilities.types import _METRIC_COLLECTION


class TorchCollective(Collective):
    """Collective interface for DDP, DDPSpawn, DP and DDP2."""

    def __init__(self, local_reduce: bool = False):
        """
        .. note::
            DDP and DDPSpawn sync across multiple nodes/devices, so ``local_reduce = False``.
            DP runs the reduce on a single node, so ``local_reduce = True``.
            DDP2 behaves like DP within one node, so ``local_reduce = True``.

            ``local_reduce`` is set in the plugins' ``setup()`` functions.
        """
        self.local_reduce = local_reduce

    def barrier(self, *args, **kwargs) -> None:
        if not distributed_available():
            return
        if _TORCH_GREATER_EQUAL_1_8 and torch.distributed.get_backend() == "nccl":
            torch.distributed.barrier(device_ids=self.determine_ddp_device_ids())
        else:
            torch.distributed.barrier()

    def broadcast(self, obj: object, src: int = 0) -> object:
        if not distributed_available():
            return obj
        return self.dist.broadcast(obj)

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """Performs an all_gather on all processes."""
        return all_gather_ddp_if_available(tensor, group=group, sync_grads=sync_grads)

    def reduce(
        self, tensor: _METRIC_COLLECTION, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean"
    ) -> torch.Tensor:
        """Reduces the given tensor (e.g. across GPUs/processes).

        If ``local_reduce = True`` (DP and DDP2), reduces the tensor across local processes only.
        If ``local_reduce = False`` (DDP, DDPSpawn and extensions), reduces the tensor across all distributed
        processes.

        Args:
            tensor: the tensor to sync and reduce
            group: the process group to gather results from. Defaults to all processes (world)
            reduce_op: the reduction operation. Defaults to 'mean'/'avg'.
                Can also be a string 'sum' to calculate the sum during reduction.

        Return:
            The reduced value; if the input was not a tensor, the output remains unchanged.
        """
        if self.local_reduce:

            def mean(t: torch.Tensor) -> torch.Tensor:
                original_dtype = t.dtype
                return t.float().mean().to(original_dtype)

            return apply_to_collection(tensor, torch.Tensor, mean)

        if isinstance(tensor, torch.Tensor):
            tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
        return tensor

    def reduce_boolean_decision(self, decision: bool) -> bool:
        decision = torch.tensor(int(decision), device=self.lightning_module.device)
        decision = self.reduce(decision, reduce_op=ReduceOp.SUM)
        decision = bool(decision == self.world_size)
        return decision
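
The two branches of reduce above behave quite differently; a minimal sketch (illustrative only, assuming the subclass is importable from pytorch_lightning.plugins.collective, which this diff does not show):

# Illustrative sketch of TorchCollective.reduce; the import path is an assumption.
import torch

from pytorch_lightning.plugins.collective import TorchCollective

# DP/DDP2 style: local_reduce=True averages each tensor in the (possibly nested) collection locally.
local = TorchCollective(local_reduce=True)
print(local.reduce({"loss": torch.tensor([1.0, 3.0])}))  # {'loss': tensor(2.)}

# DDP style: local_reduce=False delegates to sync_ddp_if_available, which is a no-op when
# torch.distributed is not initialized and an all-reduce across the group otherwise.
dist = TorchCollective(local_reduce=False)
print(dist.reduce(torch.tensor([1.0, 3.0])))  # unchanged unless a process group is initialized
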
Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
from typing import Any, Optional, Union

import torch

from pytorch_lightning.plugins.collective import Collective
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm
    from torch_xla.core.xla_model import rendezvous
else:
    xm, rendezvous = None, None


class TPUCollective(Collective):
    """Collective functions for the TPU training type plugins."""

    def barrier(self, name: Optional[str] = None) -> None:
        if self.is_distributed:
            rendezvous(name)

    def broadcast(self, obj: object, src: int = 0) -> object:
        if not self.is_distributed:
            return obj
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data = bytearray(buffer.getbuffer())
        data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float)
        data = xm.all_gather(data_tensor)
        buffer = io.BytesIO(data.cpu().byte().numpy())
        obj = torch.load(buffer)
        return obj

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
        """Gathers a tensor from several distributed processes.

        Args:
            tensor: tensor of shape (batch, ...)
            group: not available with TPUs
            sync_grads: not available with TPUs

        Return:
            A tensor of shape (world_size, batch, ...)
        """
        if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
            tensor = tensor.unsqueeze(0)
        return xm.all_gather(tensor)

    def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
        if not isinstance(output, torch.Tensor):
            output = torch.tensor(output, device=self.lightning_module.device)

        _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
        _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
        if _invalid_reduce_op or _invalid_reduce_op_str:
            raise MisconfigurationException(
                "Currently, the TPUSpawn TrainingTypePlugin only supports `sum`, `mean`, `avg` reduce operations."
            )

        output = xm.mesh_reduce("reduce", output, sum)

        if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
            output = output / self.world_size

        return output
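
A brief worked note on the TPU reduce semantics above (hedged, since this diff does not show how `world_size` gets attached to the collective): `xm.mesh_reduce("reduce", output, sum)` always sums across processes, so with a per-process value of 2.0 and a world size of 8, `reduce_op="sum"` yields 16.0 while `reduce_op="mean"` yields 16.0 / 8 = 2.0; any `ReduceOp` other than SUM, or a string outside "sum"/"mean"/"avg", raises `MisconfigurationException`.
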
