
Commit 8656852

FSDP (native) support for LightningLite (#14967)

Authored by awaelchli and Borda
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 4654833 · commit 8656852

File tree: 21 files changed, +655 −21 lines

src/lightning_lite/CHANGELOG.md
Lines changed: 3 additions & 0 deletions

@@ -20,6 +20,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `LightningLite.setup_module()` and `LightningLite.setup_optimizers()` to support strategies that need to set up the model before an optimizer can be created ([#15185](https://github.com/Lightning-AI/lightning/pull/15185))
 
 
+- Added support for Fully Sharded Data Parallel (FSDP) training in Lightning Lite ([#14967](https://github.com/Lightning-AI/lightning/issues/14967))
+
+
 ### Changed
 
 - The `LightningLite.run()` method is no longer abstract ([#14992](https://github.com/Lightning-AI/lightning/issues/14992))
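
To make the changelog entry concrete, here is a minimal usage sketch (not part of this commit's diff), assuming the strategy is registered under the alias `"fsdp"` and that `setup_module()` / `setup_optimizers()` behave as described above; the model, data, and device count are placeholders:

```python
# Hypothetical sketch: enabling FSDP in Lightning Lite (alias "fsdp" assumed).
import torch
import torch.nn as nn

from lightning_lite import LightningLite


class Lite(LightningLite):
    def run(self):
        # Under FSDP the module is set up first, so the optimizer is created
        # from the already-wrapped (sharded) parameters.
        model = self.setup_module(nn.Linear(32, 2))  # placeholder model
        optimizer = self.setup_optimizers(torch.optim.SGD(model.parameters(), lr=0.1))

        batch = torch.randn(4, 32, device=self.device)  # placeholder data
        loss = model(batch).sum()
        self.backward(loss)
        optimizer.step()


# Requires CUDA GPUs; precision=16 selects the new FSDPPrecision plugin (see connector.py below).
Lite(strategy="fsdp", accelerator="cuda", devices=2, precision=16).run()
```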

src/lightning_lite/connector.py
Lines changed: 13 additions & 2 deletions

@@ -40,6 +40,7 @@
     TorchElasticEnvironment,
 )
 from lightning_lite.plugins.precision.double import DoublePrecision
+from lightning_lite.plugins.precision.fsdp import FSDPPrecision
 from lightning_lite.strategies import (
     DDPShardedStrategy,
     DDPSpawnShardedStrategy,
@@ -53,6 +54,7 @@
     XLAStrategy,
 )
 from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES
+from lightning_lite.strategies.fsdp import _FSDP_ALIASES, FSDPStrategy
 from lightning_lite.utilities import _StrategyType, rank_zero_info, rank_zero_warn
 from lightning_lite.utilities.device_parser import _determine_root_gpu_device
 from lightning_lite.utilities.imports import _IS_INTERACTIVE
@@ -417,6 +419,13 @@ def _check_strategy_and_fallback(self) -> None:
                 f"You selected `Lite(strategy='{strategy_flag}')` but process forking is not supported on this"
                 f" platform. We recommed `Lite(strategy='ddp_spawn')` instead."
             )
+        if (
+            strategy_flag in _FSDP_ALIASES or isinstance(self._strategy_flag, FSDPStrategy)
+        ) and self._accelerator_flag not in ("cuda", "gpu"):
+            raise ValueError(
+                "You selected the FSDP strategy but FSDP is only available on GPU. Set `Lite(accelerator='gpu', ...)`"
+                " to continue or select a different strategy."
+            )
         if strategy_flag:
             self._strategy_flag = strategy_flag
 
@@ -465,9 +474,11 @@ def _check_and_init_precision(self) -> Precision:
                 if self._precision_input == 16
                 else "Using bfloat16 Automatic Mixed Precision (AMP)"
             )
-
             device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
-            return NativeMixedPrecision(self._precision_input, device)
+
+            if isinstance(self.strategy, FSDPStrategy):
+                return FSDPPrecision(precision=self._precision_input, device=device)
+            return NativeMixedPrecision(precision=self._precision_input, device=device)
 
         raise RuntimeError("No precision set")
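
A hedged illustration of how these connector checks are expected to behave; the alias string `"fsdp"` is an assumption, since `_FSDP_ALIASES` is not shown in this diff:

```python
# Illustration only; assumes "fsdp" is contained in _FSDP_ALIASES.
from lightning_lite import LightningLite

# _check_strategy_and_fallback: FSDP with a non-GPU accelerator is rejected.
try:
    LightningLite(strategy="fsdp", accelerator="cpu")
except ValueError as err:
    print(err)  # "You selected the FSDP strategy but FSDP is only available on GPU. ..."

# _check_and_init_precision: with a CUDA accelerator and 16-bit precision, the connector
# is expected to pick FSDPPrecision (with a ShardedGradScaler) instead of NativeMixedPrecision.
lite = LightningLite(strategy="fsdp", accelerator="cuda", devices=1, precision=16)  # needs a CUDA GPU
```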

src/lightning_lite/lite.py
Lines changed: 9 additions & 2 deletions

@@ -35,6 +35,7 @@
     DDPShardedStrategy,
     DDPSpawnShardedStrategy,
     DeepSpeedStrategy,
+    FSDPStrategy,
     SingleDeviceStrategy,
     Strategy,
     XLAStrategy,
@@ -593,14 +594,20 @@ def _prepare_run_method(self) -> None:
         # wrap the run method, so we can inject setup logic or spawn processes for the user
         setattr(self, "run", partial(self._run_impl, self.run))
 
-    @staticmethod
-    def _validate_setup(module: nn.Module, optimizers: Sequence[Optimizer]) -> None:
+    def _validate_setup(self, module: nn.Module, optimizers: Sequence[Optimizer]) -> None:
         if isinstance(module, _LiteModule):
             raise ValueError("A model should be passed only once to the `setup` method.")
 
         if any(isinstance(opt, _LiteOptimizer) for opt in optimizers):
             raise ValueError("An optimizer should be passed only once to the `setup` method.")
 
+        if isinstance(self._strategy, FSDPStrategy):
+            raise RuntimeError(
+                f"The `{type(self).__name__}` requires the model and optimizer(s) to be set up separately."
+                " Create and set up the model first through `model = self.setup_model(model)`. Then create the"
+                " optimizer and set it up: `optimizer = self.setup_optimizer(optimizer)`."
+            )
+
     def _validate_setup_module(self, module: nn.Module) -> None:
         if isinstance(module, _LiteModule):
             raise ValueError("A model should be passed only once to the `setup_module` method.")
src/lightning_lite/plugins/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -17,6 +17,7 @@
 from lightning_lite.plugins.io.xla import XLACheckpointIO
 from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision
 from lightning_lite.plugins.precision.double import DoublePrecision
+from lightning_lite.plugins.precision.fsdp import FSDPPrecision
 from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision
 from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.plugins.precision.tpu import TPUPrecision
@@ -33,4 +34,5 @@
     "NativeMixedPrecision",
     "TPUPrecision",
     "TPUBf16Precision",
+    "FSDPPrecision",
 ]

src/lightning_lite/plugins/precision/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 from lightning_lite.plugins.precision.deepspeed import DeepSpeedPrecision
 from lightning_lite.plugins.precision.double import DoublePrecision
+from lightning_lite.plugins.precision.fsdp import FSDPPrecision
 from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision
 from lightning_lite.plugins.precision.precision import Precision
 from lightning_lite.plugins.precision.tpu import TPUPrecision
@@ -25,4 +26,5 @@
     "Precision",
     "TPUPrecision",
     "TPUBf16Precision",
+    "FSDPPrecision",
 ]
src/lightning_lite/plugins/precision/fsdp.py (new file)
Lines changed: 59 additions & 0 deletions

@@ -0,0 +1,59 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, TYPE_CHECKING
+
+import torch
+from typing_extensions import Literal
+
+from lightning_lite.plugins.precision.native_amp import NativeMixedPrecision
+from lightning_lite.utilities.enums import PrecisionType
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+
+if TYPE_CHECKING:
+    from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
+    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+
+class FSDPPrecision(NativeMixedPrecision):
+    """AMP for Fully Sharded Data Parallel training."""
+
+    def __init__(
+        self, precision: Literal[16, "bf16"], device: str, scaler: Optional["ShardedGradScaler"] = None
+    ) -> None:
+        if not _TORCH_GREATER_EQUAL_1_12:
+            raise NotImplementedError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.")
+
+        from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
+        super().__init__(
+            precision=precision,
+            device=device,
+            scaler=(ShardedGradScaler() if scaler is None and precision == 16 else None),
+        )
+
+    @property
+    def mixed_precision_config(self) -> "MixedPrecision":
+        from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
+
+        if self.precision == PrecisionType.HALF:
+            dtype = torch.float16
+        elif self.precision == PrecisionType.BFLOAT:
+            dtype = torch.bfloat16
+        else:
+            raise ValueError(f"Was unable to infer precision type, received {self.precision!r}.")
+        return MixedPrecision(
+            param_dtype=dtype,
+            reduce_dtype=dtype,
+            buffer_dtype=dtype,
+        )
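
A brief hedged sketch of what the plugin above produces, based only on the code shown (requires PyTorch >= 1.12; not part of the diff):

```python
# Sketch: instantiating the new plugin directly and inspecting its FSDP config.
from lightning_lite.plugins import FSDPPrecision

plugin = FSDPPrecision(precision=16, device="cuda")
# With precision=16 and no scaler passed, a ShardedGradScaler is created internally.
config = plugin.mixed_precision_config
# param_dtype, reduce_dtype and buffer_dtype are all torch.float16 here
# (they would all be torch.bfloat16 for precision="bf16").
print(config.param_dtype, config.reduce_dtype, config.buffer_dtype)
```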

src/lightning_lite/strategies/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@
 from lightning_lite.strategies.dp import DataParallelStrategy  # noqa: F401
 from lightning_lite.strategies.fairscale import DDPShardedStrategy  # noqa: F401
 from lightning_lite.strategies.fairscale import DDPSpawnShardedStrategy  # noqa: F401
+from lightning_lite.strategies.fsdp import FSDPStrategy  # noqa: F401
 from lightning_lite.strategies.parallel import ParallelStrategy  # noqa: F401
 from lightning_lite.strategies.registry import _call_register_strategies, _StrategyRegistry
 from lightning_lite.strategies.single_device import SingleDeviceStrategy  # noqa: F401

src/lightning_lite/strategies/ddp.py
Lines changed: 1 addition & 2 deletions

@@ -92,8 +92,7 @@ def num_processes(self) -> int:
 
     @property
     def distributed_sampler_kwargs(self) -> Dict[str, Any]:
-        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
-        return distributed_sampler_kwargs
+        return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
 
     @property
     def process_group_backend(self) -> Optional[str]:
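
These kwargs are typically forwarded to `torch.utils.data.DistributedSampler`; a small illustration with made-up values (2 nodes × 4 processes, global rank 3) follows:

```python
# Illustration: the returned kwargs match DistributedSampler's num_replicas/rank arguments.
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.randn(100, 8))
# e.g. num_nodes=2, num_processes=4 -> num_replicas=8; this process has global_rank=3
sampler = DistributedSampler(dataset, num_replicas=8, rank=3)
print(len(sampler))  # each replica sees ceil(100 / 8) = 13 indices with the default padding
```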

src/lightning_lite/strategies/ddp_spawn.py
Lines changed: 1 addition & 2 deletions

@@ -99,8 +99,7 @@ def num_processes(self) -> int:
 
     @property
     def distributed_sampler_kwargs(self) -> Dict[str, int]:
-        distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
-        return distributed_sampler_kwargs
+        return dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank)
 
     @property
     def process_group_backend(self) -> Optional[str]:

src/lightning_lite/strategies/deepspeed.py
Lines changed: 1 addition & 2 deletions

@@ -297,8 +297,7 @@ def zero_stage_3(self) -> bool:
 
     @property
     def distributed_sampler_kwargs(self) -> Dict[str, int]:
-        distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
-        return distributed_sampler_kwargs
+        return dict(num_replicas=self.world_size, rank=self.global_rank)
 
     @property
     def model(self) -> "deepspeed.DeepSpeedEngine":
