Commit 3c9b7cb

Lite: Flatten XLAStrategy (#15838)
1 parent 5595166 commit 3c9b7cb

3 files changed: +27 -8 lines

src/lightning_lite/CHANGELOG.md
Lines changed: 2 additions & 1 deletion

@@ -28,7 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - The `LightningLite.run()` method is no longer abstract ([#14992](https://github.com/Lightning-AI/lightning/issues/14992))


--
+- The `XLAStrategy` now inherits from `ParallelStrategy` instead of `DDPSpawnStrategy` ([#15838](https://github.com/Lightning-AI/lightning/issues/15838))
+


 ### Deprecated
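
The changelog entry above captures the core of this commit: XLAStrategy no longer goes through DDPSpawnStrategy. A minimal, self-contained sketch of what that flattening looks like, using stand-in classes rather than the real lightning_lite ones (class bodies and printed messages here are purely illustrative):

class ParallelStrategy:  # stand-in for lightning_lite.strategies.ParallelStrategy
    def setup_environment(self) -> None:
        print("ParallelStrategy.setup_environment")


class DDPSpawnStrategy(ParallelStrategy):  # the intermediate layer this commit removes
    def _setup_distributed(self) -> None:
        print("DDPSpawnStrategy._setup_distributed")


# Before: class XLAStrategy(DDPSpawnStrategy): ...
# After this commit, XLAStrategy derives directly from ParallelStrategy:
class XLAStrategy(ParallelStrategy):
    def setup_environment(self) -> None:
        # XLA-specific bookkeeping happens before delegating to the base class,
        # mirroring the setup_environment() override added in strategies/xla.py below
        print("XLAStrategy: mark launched, set world ranks")
        super().setup_environment()


if __name__ == "__main__":
    XLAStrategy().setup_environment()

Dropping the DDPSpawnStrategy layer means the XLA strategy no longer inherits spawn-specific behavior it has to work around; the diffs below show it re-implementing the few pieces it actually needs (launcher wiring, rank bookkeeping, environment setup).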

src/lightning_lite/strategies/launchers/xla.py
Lines changed: 5 additions & 3 deletions

@@ -18,14 +18,15 @@
 from torch.multiprocessing import get_context

 from lightning_lite.accelerators.tpu import _XLA_AVAILABLE
-from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
+from lightning_lite.strategies.launchers.base import _Launcher
+from lightning_lite.strategies.launchers.multiprocessing import _GlobalStateSnapshot
 from lightning_lite.utilities.apply_func import move_data_to_device

 if TYPE_CHECKING:
     from lightning_lite.strategies import XLAStrategy


-class _XLALauncher(_MultiProcessingLauncher):
+class _XLALauncher(_Launcher):
     r"""Launches processes that run a given function in parallel on XLA supported hardware, and joins them all at the
     end.

@@ -44,7 +45,8 @@ class _XLALauncher(_MultiProcessingLauncher):
     def __init__(self, strategy: "XLAStrategy") -> None:
         if not _XLA_AVAILABLE:
             raise ModuleNotFoundError(str(_XLA_AVAILABLE))
-        super().__init__(strategy=strategy, start_method="fork")
+        self._strategy = strategy
+        self._start_method = "fork"

     @property
     def is_interactive_compatible(self) -> bool:
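
In the diff above, the launcher no longer reuses _MultiProcessingLauncher's constructor; it derives from the abstract launcher base and stores its strategy and the "fork" start method itself. A rough, runnable sketch with stand-in types (this is not the lightning_lite API; the real launch() hands the function to torch_xla's multiprocessing spawn):

from abc import ABC, abstractmethod
from typing import Any, Callable


class _Launcher(ABC):  # stand-in for lightning_lite.strategies.launchers.base._Launcher
    @property
    @abstractmethod
    def is_interactive_compatible(self) -> bool:
        ...

    @abstractmethod
    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
        ...


class _XLALauncher(_Launcher):  # mirrors the __init__ shown in the diff above
    def __init__(self, strategy: Any) -> None:
        self._strategy = strategy
        self._start_method = "fork"  # XLA workers are forked rather than spawned

    @property
    def is_interactive_compatible(self) -> bool:
        return True  # forked processes also work in interactive environments

    def launch(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
        # placeholder: run in-process so the sketch stays self-contained;
        # the real launcher dispatches to torch_xla multiprocessing instead
        return function(*args, **kwargs)


if __name__ == "__main__":
    print(_XLALauncher(strategy=None).launch(lambda: "launched"))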

src/lightning_lite/strategies/xla.py
Lines changed: 20 additions & 4 deletions

@@ -26,7 +26,7 @@
 from lightning_lite.plugins.io.checkpoint_io import CheckpointIO
 from lightning_lite.plugins.io.xla import XLACheckpointIO
 from lightning_lite.plugins.precision import Precision
-from lightning_lite.strategies.ddp_spawn import DDPSpawnStrategy
+from lightning_lite.strategies import ParallelStrategy
 from lightning_lite.strategies.launchers.xla import _XLALauncher
 from lightning_lite.strategies.strategy import TBroadcast
 from lightning_lite.utilities.apply_func import apply_to_collection
@@ -38,7 +38,7 @@
     from torch_xla.distributed.parallel_loader import MpDeviceLoader


-class XLAStrategy(DDPSpawnStrategy):
+class XLAStrategy(ParallelStrategy):
    """Strategy for training multiple TPU devices using the :func:`torch_xla.distributed.xla_multiprocessing.spawn`
    method."""

@@ -55,11 +55,11 @@ def __init__(
             cluster_environment=XLAEnvironment(),
             checkpoint_io=checkpoint_io,
             precision=precision,
-            start_method="fork",
         )
         self._checkpoint_io: Optional[CheckpointIO]
         self._backward_sync_control = None  # XLA synchronizes gradients in the optimizer.step() call
         self._launched = False
+        self._local_rank = 0

     @property
     def root_device(self) -> torch.device:
@@ -69,6 +69,14 @@ def root_device(self) -> torch.device:

         return xm.xla_device()

+    @property
+    def num_processes(self) -> int:
+        return len(self.parallel_devices) if self.parallel_devices is not None else 0
+
+    @property
+    def local_rank(self) -> int:
+        return self._local_rank
+
     @property
     def checkpoint_io(self) -> CheckpointIO:
         if self._checkpoint_io is None:
@@ -93,10 +101,11 @@ def is_distributed(self) -> bool:
     def _configure_launcher(self) -> None:
         self._launcher = _XLALauncher(self)

-    def _setup_distributed(self) -> None:
+    def setup_environment(self) -> None:
         self._launched = True
         self._set_world_ranks()
         rank_zero_only.rank = self.global_rank
+        super().setup_environment()

     def setup_module(self, module: Module) -> Module:
         return module
@@ -201,6 +210,13 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
         strategy_registry.register("tpu_spawn", cls, description=cls.__class__.__name__)
         strategy_registry.register("xla", cls, description=cls.__class__.__name__)

+    def _set_world_ranks(self) -> None:
+        if self.cluster_environment is None:
+            return
+        self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank)
+        self.cluster_environment.set_world_size(self.num_processes)
+        rank_zero_only.rank = self.cluster_environment.global_rank()
+
     @staticmethod
     def _validate_dataloader(dataloaders: DataLoader) -> None:
         def check_has_len(dataloader: DataLoader) -> None:
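
The _set_world_ranks() helper added at the bottom of this file derives each process's global rank from its node rank, the number of processes per node, and its local rank. A tiny sketch of that arithmetic (the helper function and the 2-node, 8-core values are illustrative, not part of lightning_lite):

def global_rank(node_rank: int, num_processes: int, local_rank: int) -> int:
    # same formula passed to cluster_environment.set_global_rank(...) in the diff
    return node_rank * num_processes + local_rank


if __name__ == "__main__":
    num_processes = 8  # e.g. 8 TPU cores per host (assumed for illustration)
    for node_rank in range(2):
        for local_rank in range(num_processes):
            g = global_rank(node_rank, num_processes, local_rank)
            print(f"node {node_rank}, local {local_rank} -> global {g}")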
