
Commit eacdbf2

Refactor collective functions, call training_type_plugin directly
1 parent 2b2537d commit eacdbf2

12 files changed: +732 additions, −50 deletions


pytorch_lightning/callbacks/timer.py

Lines changed: 1 addition & 1 deletion
@@ -165,7 +165,7 @@ def on_load_checkpoint(
 
     def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
         should_stop = self.time_elapsed() >= self._duration
-        should_stop = trainer.accelerator.broadcast(should_stop)
+        should_stop = trainer.accelerator.training_type_plugin.broadcast(should_stop)
         trainer.should_stop = trainer.should_stop or should_stop
         if should_stop and self._verbose:
             elapsed = timedelta(seconds=int(self.time_elapsed(RunningStage.TRAINING)))
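Usage note (not part of the commit): a minimal sketch of the call pattern this hunk switches to, where downstream code reaches the collective through trainer.accelerator.training_type_plugin; the helper name sync_stop_decision is hypothetical.

    import pytorch_lightning as pl

    def sync_stop_decision(trainer: "pl.Trainer", should_stop: bool) -> bool:
        # Rank 0's decision is broadcast so every process ORs the same value
        # into trainer.should_stop, mirroring Timer._check_time_remaining above.
        should_stop = trainer.accelerator.training_type_plugin.broadcast(should_stop)
        trainer.should_stop = trainer.should_stop or should_stop
        return trainer.should_stop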

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 1 deletion
@@ -594,7 +594,7 @@ def all_gather(
             the output will also be a collection with tensors of this shape.
         """
         group = group if group is not None else torch.distributed.group.WORLD
-        all_gather = self.trainer.accelerator.all_gather
+        all_gather = self.trainer.accelerator.training_type_plugin.all_gather
         data = convert_to_tensors(data, device=self.device)
         return apply_to_collection(data, torch.Tensor, all_gather, group=group, sync_grads=sync_grads)
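Usage note (not part of the commit): LightningModule.all_gather is the user-facing entry point this hunk re-routes; a minimal sketch with a placeholder metric, assuming a hypothetical module named GatherExample.

    import torch
    import pytorch_lightning as pl

    class GatherExample(pl.LightningModule):
        # Hypothetical module: self.all_gather(...) now resolves to
        # trainer.accelerator.training_type_plugin.all_gather under the hood.
        def validation_step(self, batch, batch_idx):
            metric = torch.rand(1, device=self.device)  # placeholder per-rank value
            # Tensor of shape (world_size, *metric.shape) on torch.distributed
            # backends; Horovod may return a list of tensors instead.
            gathered = self.all_gather(metric)
            return gathered.mean() if isinstance(gathered, torch.Tensor) else metric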

pytorch_lightning/plugins/training_type/ddp.py

Lines changed: 13 additions & 12 deletions
@@ -31,9 +31,9 @@
 
 import pytorch_lightning as pl
 from pytorch_lightning.core.optimizer import LightningOptimizer
-from pytorch_lightning.distributed import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -48,13 +48,9 @@
     rank_zero_deprecation,
     rank_zero_warn,
 )
-from pytorch_lightning.utilities.distributed import (
-    distributed_available,
-    init_ddp_connection,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -116,7 +112,6 @@ def __init__(
                 " Notice that it will be overriden by the trainer setting."
             )
         self._sync_batchnorm = sync_batchnorm or False
-        self.dist = LightningDistributed()
         self.num_processes = len(self.parallel_devices) if self.parallel_devices is not None else 0
         self._ddp_kwargs = kwargs
         self._task_idx = None
@@ -270,8 +265,6 @@ def setup_distributed(self):
         init_ddp_connection(self.cluster_environment, self.torch_distributed_backend)
 
         # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
 
     def _check_can_spawn_children(self):
         if self.local_rank != 0:
@@ -403,7 +396,15 @@ def barrier(self, *args, **kwargs) -> None:
             torch.distributed.barrier()
 
     def broadcast(self, obj: object, src: int = 0) -> object:
-        return self.dist.broadcast(obj)
+        if not distributed_available():
+            raise RuntimeError(
+                "DDPSpawn is not initialized and torch.distributed is not avalible, can not broadcast object"
+            )
+        obj = [obj]
+        if self.global_rank != 0:
+            obj = [None] * len(obj)
+        broadcast_object_list(obj, src, group=_group.WORLD)
+        return obj[0]
 
     def pre_backward(self, closure_loss: torch.Tensor) -> None:
         """Run before precision plugin executes backward."""

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 9 additions & 14 deletions
@@ -24,9 +24,9 @@
 from torch.nn.parallel.distributed import DistributedDataParallel
 
 import pytorch_lightning as pl
-from pytorch_lightning.distributed.dist import LightningDistributed
 from pytorch_lightning.overrides import LightningDistributedModule
 from pytorch_lightning.overrides.distributed import prepare_for_backward
+from pytorch_lightning.overrides.torch_distributed import broadcast_object_list
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
@@ -40,13 +40,9 @@
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.cloud_io import load as pl_load
-from pytorch_lightning.utilities.distributed import (
-    distributed_available,
-    init_ddp_connection,
-    rank_zero_only,
-    ReduceOp,
-    sync_ddp_if_available,
-)
+from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.distributed import group as _group
+from pytorch_lightning.utilities.distributed import init_ddp_connection, rank_zero_only, ReduceOp, sync_ddp_if_available
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.seed import reset_seed
 from pytorch_lightning.utilities.types import STEP_OUTPUT
@@ -93,7 +89,6 @@ def __init__(
         )
         self._sync_batchnorm = sync_batchnorm or False
         self._ddp_kwargs = kwargs
-        self.dist = LightningDistributed()
         self.num_processes = len(parallel_devices) if parallel_devices is not None else 0
         self.mp_queue = None
         self._ddp_comm_state = ddp_comm_state
@@ -193,10 +188,6 @@ def new_process(self, process_idx: int, trainer: "pl.Trainer", mp_queue: SimpleQ
         # ... need to double check that it is the correct place
         # self.trainer.call_setup_hook(self.model)
 
-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
-
         # move the model to the correct device
         self.model_to_device()
 
@@ -324,7 +315,11 @@ def barrier(self, *args, **kwargs) -> None:
     def broadcast(self, obj: object, src: int = 0) -> object:
         if not distributed_available():
             return obj
-        return self.dist.broadcast(obj)
+        obj = [obj]
+        if self.global_rank != 0:
+            obj = [None] * len(obj)
+        broadcast_object_list(obj, src, group=_group.WORLD)
+        return obj[0]
 
     def model_to_device(self):
         if self.root_device.type == "cuda":

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 0 additions & 3 deletions
@@ -342,9 +342,6 @@ def setup_distributed(self):
 
         self._init_deepspeed_distributed()
 
-        # set the ranks and devices
-        self.dist.rank = self.global_rank
-        self.dist.device = self.root_device
         if not self._config_initialized:
             self._format_config()
             self._config_initialized = True

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 35 additions & 7 deletions
@@ -13,7 +13,7 @@
 # limitations under the License.
 import contextlib
 from abc import ABC, abstractmethod
-from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, TypeVar, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, TypeVar, Union
 
 import torch
 from torch import Tensor
@@ -25,6 +25,7 @@
 from pytorch_lightning.overrides.base import unwrap_lightning_module
 from pytorch_lightning.plugins import TorchCheckpointIO
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
+from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PATH, _PREDICT_OUTPUT
 
 TBroadcast = TypeVar("T")
@@ -91,26 +92,53 @@ def is_global_zero(self) -> bool:
         """Whether the current process is the rank zero process not only on the local node, but for all nodes."""
 
     @abstractmethod
-    def reduce(self, tensor: Union[torch.Tensor, Any], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Any]:
+    def reduce(
+        self,
+        tensor: Union[torch.Tensor, Any],
+        group: Optional[Any] = None,
+        reduce_op: Optional[Union[ReduceOp, str]] = "mean",
+    ) -> Union[torch.Tensor, Any]:
         """Reduces the given tensor (e.g. across GPUs/processes).
 
         Args:
             tensor: the tensor to sync and reduce
+            group: the process group to reduce
+            reduce_op: the reduction operation. Defaults to 'mean'.
+                Can also be a string 'sum' or ReduceOp.
             *args: plugin-specific positional arguments
             **kwargs: plugin-specific keyword arguments
         """
 
     @abstractmethod
     def barrier(self, name: Optional[str] = None) -> None:
-        """Forces all possibly joined processes to wait for each other."""
+        """Synchronizes all processes which blocks processes until the whole group enters this function.
+
+        Args:
+            name: a str pass into barrier. Only torch xla respect this param
+        """
 
     @abstractmethod
-    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
-        """Broadcasts an object to all processes."""
+    def broadcast(self, obj: object, src: int = 0) -> object:
+        """Broadcasts an object to all processes.
+
+        Args:
+            obj: the object to broadcast
+            src: source rank.
+        """
 
     @abstractmethod
-    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> torch.Tensor:
-        """Perform a all_gather on all processes."""
+    def all_gather(
+        self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads: bool = False
+    ) -> Union[List[torch.Tensor], torch.Tensor]:
+        """Perform a all_gather on all processes.
+
+        Args:
+            tensor: the tensor to all_gather
+            group: the process group to gather results from
+            sync_grads: flag that allows users to synchronize gradients for all_gather op
+
+        Returns: a tensor (torch distributed) or a list of tensor (horovod)
+        """
 
     def reduce_boolean_decision(self, decision: bool) -> bool:
         """Reduce the early stopping decision across all processes."""

pytorch_lightning/trainer/data_loading.py

Lines changed: 1 addition & 1 deletion
@@ -531,7 +531,7 @@ def request_dataloader(
         dataloader = self.call_hook(hook, pl_module=model)
         if isinstance(dataloader, tuple):
             dataloader = list(dataloader)
-        self.accelerator.barrier("get_dataloaders")
+        self.accelerator.training_type_plugin.barrier("get_dataloaders")
        return dataloader
 
     @staticmethod
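Usage note (not part of the commit): the barrier reached through the training type plugin is what keeps ranks in lock-step around rank-zero-only work; a minimal sketch with a hypothetical helper name.

    import pytorch_lightning as pl

    def prepare_data_once(trainer: "pl.Trainer", download_fn) -> None:
        # Only the global-zero process does the slow, non-idempotent work;
        # every other rank blocks on the barrier until it has finished.
        if trainer.is_global_zero:
            download_fn()
        trainer.accelerator.training_type_plugin.barrier("prepare_data_once")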
