-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Standalone Lite: DDP Spawn Strategy Family #14675
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
fe59302
7271f94
b6de11f
2ef04e6
9bbaf4f
48bc1e8
0cf9651
dc09055
6a14975
e6d619c
9055717
f016626
084bc6f
9c19b48
3d09dac
7a5a740
de78087
48ef646
6d60b96
4e018c4
983a6d7
2220350
cfce27e
4ba5809
231d8c3
6e1f03a
1505eb4
334e3cf
e832e67
a90ef22
8bf889b
1195cec
f4dd9a5
c1f029e
4cc08fe
9b8572d
06bf069
6dbe465
be60f9a
f325117
6a8812d
90466e0
b829e90
81d1b1a
b8de59f
2dcabd6
60a7479
9479bad
ab62924
a2b00bd
f917cb4
e12f142
20808b8
ae55eff
690a87f
dbf0730
cbd26d5
fb2522b
07a2956
3685ba1
d4f3a54
7d4f5b6
2fd9d73
5236eef
b5dd25d
0b473a6
78c96b1
d7e5db9
04f3f78
185dd64
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,214 @@ | ||
| # Copyright The PyTorch Lightning team. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from datetime import timedelta | ||
| from typing import Any, Dict, List, Optional, Union | ||
|
|
||
| import torch | ||
| import torch.distributed | ||
| from torch import Tensor | ||
| from torch.distributed.constants import default_pg_timeout | ||
| from torch.nn import Module | ||
| from torch.nn.parallel.distributed import DistributedDataParallel | ||
| from typing_extensions import Literal | ||
|
|
||
| from lightning_lite.accelerators.accelerator import Accelerator | ||
| from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment | ||
| from lightning_lite.plugins.io.checkpoint_plugin import CheckpointIO | ||
| from lightning_lite.plugins.precision import Precision | ||
| from lightning_lite.strategies.launchers.multiprocessing import _MultiProcessingLauncher | ||
| from lightning_lite.strategies.parallel import ParallelStrategy | ||
| from lightning_lite.strategies.strategy import TBroadcast | ||
| from lightning_lite.utilities.distributed import distributed_available, get_default_process_group_backend_for_device | ||
| from lightning_lite.utilities.distributed import group as _group | ||
| from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available | ||
| from lightning_lite.utilities.rank_zero import rank_zero_only | ||
|
|
||
# Strategy-registry aliases that start worker processes with the "fork" start
# method (see the ``register_strategies`` entries that map ``ddp_fork`` and
# ``ddp_notebook`` to ``start_method="fork"``). Presumably consumed elsewhere
# (e.g. by a connector) to special-case fork-based launching — TODO confirm
# against the caller.
_DDP_FORK_ALIASES = (
    "ddp_fork",
    "ddp_fork_find_unused_parameters_false",
    "ddp_notebook",
    "ddp_notebook_find_unused_parameters_false",
)
|
|
||
|
|
||
class DDPSpawnStrategy(ParallelStrategy):
    """Distributed-data-parallel strategy that launches one worker per device through
    :func:`torch.multiprocessing.spawn` and joins the workers when training finishes.
    """

    def __init__(
        self,
        accelerator: Optional[Accelerator] = None,
        parallel_devices: Optional[List[torch.device]] = None,
        cluster_environment: Optional[ClusterEnvironment] = None,
        checkpoint_io: Optional[CheckpointIO] = None,
        precision_plugin: Optional[Precision] = None,
        process_group_backend: Optional[str] = None,
        timeout: Optional[timedelta] = default_pg_timeout,
        start_method: Literal["spawn", "fork", "forkserver"] = "spawn",
        **kwargs: Any,
    ):
        """
        Args:
            accelerator: The accelerator providing the devices.
            parallel_devices: Devices to run the worker processes on.
            cluster_environment: Environment describing ranks and world size.
            checkpoint_io: Plugin handling checkpoint saving/loading.
            precision_plugin: Plugin handling numerical precision.
            process_group_backend: Backend for ``init_process_group``; derived from the
                root device when not given.
            timeout: Timeout passed to the process-group initialization.
            start_method: Multiprocessing start method used by the launcher.
            **kwargs: Forwarded verbatim to :class:`~torch.nn.parallel.DistributedDataParallel`.
        """
        super().__init__(
            accelerator=accelerator,
            parallel_devices=parallel_devices,
            cluster_environment=cluster_environment,
            checkpoint_io=checkpoint_io,
            precision_plugin=precision_plugin,
        )
        # Extra keyword arguments are handed straight to the DDP wrapper in ``setup_module``.
        self._ddp_kwargs = kwargs
        self._start_method = start_method
        self._timeout: Optional[timedelta] = timeout
        self._process_group_backend: Optional[str] = process_group_backend
        self._num_nodes = 1
        self._local_rank = 0

    @property
    def root_device(self) -> torch.device:
        """The device assigned to this process, indexed by local rank."""
        assert self.parallel_devices is not None
        return self.parallel_devices[self.local_rank]

    @property
    def num_nodes(self) -> int:
        """Number of nodes participating in training."""
        return self._num_nodes

    @num_nodes.setter
    def num_nodes(self, num_nodes: int) -> None:
        # World ranks depend on the node count; callers resetting this must also
        # reset the world ranks.
        self._num_nodes = num_nodes

    @property
    def num_processes(self) -> int:
        """Number of worker processes on this node (one per parallel device)."""
        devices = self.parallel_devices
        return 0 if devices is None else len(devices)

    @property
    def distributed_sampler_kwargs(self) -> Dict[str, int]:
        """Keyword arguments for constructing a ``DistributedSampler``."""
        return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)

    @property
    def process_group_backend(self) -> Optional[str]:
        """The configured process-group backend, or ``None`` if it will be derived."""
        return self._process_group_backend

    @property
    def local_rank(self) -> int:
        """Rank of this process within its node."""
        return self._local_rank

    def _configure_launcher(self) -> None:
        # The multiprocessing launcher drives process creation with the chosen start method.
        self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method)

    def setup_environment(self) -> None:
        """Set up ranks and the process group before the base-class environment setup."""
        self._setup_distributed()
        super().setup_environment()

    def setup_module(self, module: Module) -> Module:
        """Wrap the module in :class:`~torch.nn.parallel.DistributedDataParallel`."""
        device_ids = self._determine_ddp_device_ids()
        return DistributedDataParallel(module=module, device_ids=device_ids, **self._ddp_kwargs)

    def module_to_device(self, module: Module) -> None:
        """Move the module to this process's root device."""
        if self.root_device.type == "cuda":
            # TODO(lite): This should be handled outside module_to_device, by a call to accelerator.setup_device()
            # set the device on the spawned subprocesses
            torch.cuda.set_device(self.root_device)
        module.to(self.root_device)

    def reduce(
        self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean"
    ) -> Tensor:
        """Reduce a tensor across all distributed processes into one aggregated tensor.

        Args:
            tensor: The tensor to sync and reduce.
            group: The process group to gather results from. Defaults to all processes (world).
            reduce_op: The reduction operation. Defaults to ``'mean'``/``'avg'``; pass ``'sum'``
                to sum during reduction.

        Return:
            The reduced value; non-tensor inputs are returned unchanged.
        """
        if not isinstance(tensor, Tensor):
            return tensor
        return sync_ddp_if_available(tensor, group, reduce_op=reduce_op)

    def barrier(self, *args: Any, **kwargs: Any) -> None:
        """Block until all processes reach this point; no-op when not distributed."""
        if not distributed_available():
            return
        if torch.distributed.get_backend() != "nccl":
            torch.distributed.barrier()
        else:
            # NCCL needs explicit device ids for the barrier.
            torch.distributed.barrier(device_ids=self._determine_ddp_device_ids())

    def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
        """Broadcast ``obj`` from the ``src`` rank to every process and return it."""
        if not distributed_available():
            return obj
        # broadcast_object_list mutates the container in place on receiving ranks.
        container = [obj if self.global_rank == src else None]
        torch.distributed.broadcast_object_list(container, src, group=_group.WORLD)
        return container[0]

    @classmethod
    def register_strategies(cls, strategy_registry: Dict) -> None:
        """Register the spawn/fork/notebook strategy aliases in the registry."""
        start_methods = {"ddp_spawn": "spawn", "ddp_fork": "fork", "ddp_notebook": "fork"}
        for name, method in start_methods.items():
            strategy_registry.register(
                name,
                cls,
                description=f"DDP strategy with `start_method` '{method}'",
                start_method=method,
            )
        for name, method in start_methods.items():
            strategy_registry.register(
                f"{name}_find_unused_parameters_false",
                cls,
                description=f"DDP strategy with `find_unused_parameters` as False and `start_method` '{method}'",
                find_unused_parameters=False,
                start_method=method,
            )

    def _setup_distributed(self) -> None:
        # Establish ranks, pin rank-zero-only logging, and initialize the process group.
        self._set_world_ranks()
        rank_zero_only.rank = self.global_rank
        self._process_group_backend = self._get_process_group_backend()
        assert self.cluster_environment is not None
        init_dist_connection(
            self.cluster_environment,
            self._process_group_backend,
            self.global_rank,
            self.world_size,
            timeout=self._timeout,
        )

    def _get_process_group_backend(self) -> str:
        # Explicit user choice wins; otherwise pick the default for the root device.
        return self._process_group_backend or get_default_process_group_backend_for_device(self.root_device)

    def _set_world_ranks(self, process_idx: int = 0) -> None:
        # Record the local rank, then derive global rank / world size from the cluster environment.
        self._local_rank = process_idx
        env = self.cluster_environment
        if env is None:
            return
        env.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
        env.set_world_size(self.num_nodes * self.num_processes)
        rank_zero_only.rank = env.global_rank()

    def _determine_ddp_device_ids(self) -> Optional[List[int]]:
        # DDP on CPU takes no device ids; otherwise pass the single root-device index.
        return None if self.root_device.type == "cpu" else [self.root_device.index]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -86,8 +86,7 @@ def _wrapping_function( | |
| return_queue: SimpleQueue, | ||
| global_states: Optional[_GlobalStateSnapshot] = None, | ||
| ) -> None: | ||
| # TODO(lite): Update worker setup once TPUSpawn strategy is in Lite | ||
| self._strategy._worker_setup(process_idx) | ||
| self._strategy._local_rank = process_idx | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @awaelchli I believe this is not correct, as this no longer sets the world ranks before spawning, which makes the XLA tests fail: File "/home/runner/work/lightning/tests/tests_lite/strategies/test_xla.py", line 17, in broadcast_on_tpu_fn
result = strategy.broadcast(obj)
File "/home/runner/work/lightning/src/lightning_lite/strategies/xla.py", line 146, in broadcast
data_tensor = torch.tensor(data, device=self.root_device, dtype=torch.float)
File "/home/runner/work/lightning/src/lightning_lite/strategies/xla.py", line 72, in root_device
raise RuntimeError("Accessing the XLA device before processes have spawned is not allowed.") — I'm a bit confused about what's the best way to do this. Do the worker setup steps still run after spawning? If yes:
Also, why did Lite remove This blocks #14926
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was done in reaction to your comment: #11073 (comment) I commented on #14926 that maybe all that is missing in the test is a strategy.setup_environment().
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. My motivation with #11073 was always to simplify these things so that these questions wouldn't come up in the first place. But nobody wants to merge it lol, already posted 3x times in waiting pr over the last 5 months or so.
For the multiprocessing launcher, the information of local rank can only come from the launcher directly. So the answer here is no.
If #11073 lands both codes would be identical in this regard.
If #11073 lands both codes would be identical in this regard. |
||
| results = function(*args, **kwargs) | ||
|
|
||
| if self._strategy.local_rank == 0: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.