Commit a2130b9

fsdp support in lite

1 parent 0caf973 commit a2130b9

File tree

3 files changed: +36 / -34 lines changed


src/lightning_lite/connector.py

Lines changed: 13 additions & 3 deletions
@@ -40,6 +40,7 @@
     TorchElasticEnvironment,
 )
 from lightning_lite.plugins.precision.double import DoublePrecision
+from lightning_lite.plugins.precision.fsdp import FSDPPrecision
 from lightning_lite.strategies import (
     DDPShardedStrategy,
     DDPSpawnShardedStrategy,
@@ -53,7 +54,7 @@
     XLAStrategy,
 )
 from lightning_lite.strategies.ddp_spawn import _DDP_FORK_ALIASES
-from lightning_lite.strategies.fsdp import FSDPStrategy
+from lightning_lite.strategies.fsdp import _FSDP_ALIASES, FSDPStrategy
 from lightning_lite.utilities import _StrategyType, rank_zero_info, rank_zero_warn
 from lightning_lite.utilities.device_parser import determine_root_gpu_device
 from lightning_lite.utilities.imports import _HPU_AVAILABLE, _IPU_AVAILABLE, _IS_INTERACTIVE, _TPU_AVAILABLE
@@ -409,6 +410,13 @@ def _check_strategy_and_fallback(self) -> None:
                 f"You selected `Lite(strategy='{strategy_flag}')` but process forking is not supported on this"
                 f" platform. We recommend `Lite(strategy='ddp_spawn')` instead."
             )
+        if (
+            strategy_flag in _FSDP_ALIASES or isinstance(self._strategy_flag, FSDPStrategy)
+        ) and self._accelerator_flag not in ("cuda", "gpu"):
+            raise ValueError(
+                f"You selected the FSDP strategy but FSDP is only available on GPU. Set `Lite(accelerator='gpu', ...)`"
+                " to continue or select a different strategy."
+            )
         if strategy_flag:
             self._strategy_flag = strategy_flag
 
@@ -457,9 +465,11 @@ def _check_and_init_precision(self) -> Precision:
                 if self._precision_flag == 16
                 else "Using bfloat16 Automatic Mixed Precision (AMP)"
             )
-
             device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
-            return NativeMixedPrecision(self._precision_flag, device)
+
+            if isinstance(self.strategy, FSDPStrategy):
+                return FSDPPrecision(precision=self._precision_flag, device=device)
+            return NativeMixedPrecision(precision=self._precision_flag, device=device)
 
         raise RuntimeError("No precision set")
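Taken together, these connector changes mean that choosing an FSDP alias now requires a CUDA/GPU accelerator, and 16-bit or bf16 precision is routed through FSDPPrecision whenever the FSDP strategy is active. A minimal usage sketch, assuming the LightningLite entry point exported by lightning_lite (the import path and example arguments are assumptions, not shown in this diff):

from lightning_lite import LightningLite  # assumed import path


class Lite(LightningLite):
    def run(self):
        # model and optimizer would be set up here via self.setup(...)
        ...


# "fsdp" is one of the _FSDP_ALIASES; with precision=16 the connector now
# returns FSDPPrecision instead of NativeMixedPrecision.
Lite(strategy="fsdp", accelerator="cuda", devices=2, precision=16).run()

# Any non-GPU accelerator now raises the new ValueError, e.g.:
# Lite(strategy="fsdp", accelerator="cpu")  # FSDP is only available on GPU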

src/lightning_lite/plugins/precision/fsdp.py

Lines changed: 2 additions & 2 deletions
@@ -11,12 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, TYPE_CHECKING, Literal
+from typing import Literal, Optional, TYPE_CHECKING
 
 import torch
 
-from lightning_lite.utilities.enums import PrecisionType
 from lightning_lite.plugins.precision import NativeMixedPrecision
+from lightning_lite.utilities.enums import PrecisionType
 from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 
 if TYPE_CHECKING:
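For context, FSDP's native mixed precision is configured through torch's ``MixedPrecision`` dataclass (PyTorch 1.12+). The helper below is only an illustrative sketch of the kind of mapping a plugin such as ``FSDPPrecision`` performs; the function name and the choice to use one dtype for parameters, gradient reduction, and buffers are assumptions, not taken from this commit.

from typing import Union

import torch
from torch.distributed.fsdp import MixedPrecision  # available in PyTorch 1.12+


def fsdp_mixed_precision_config(precision: Union[int, str]) -> MixedPrecision:
    # Hypothetical helper: map Lite's 16 / "bf16" precision flag to an FSDP config.
    if precision in (16, "16"):
        dtype = torch.float16
    elif precision == "bf16":
        dtype = torch.bfloat16
    else:
        raise ValueError(f"Unsupported precision for FSDP mixed precision: {precision!r}")
    # Use the same dtype for parameters, gradient reduction, and buffers.
    return MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)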

src/lightning_lite/strategies/fsdp.py

Lines changed: 21 additions & 29 deletions
@@ -13,26 +13,25 @@
 # limitations under the License.
 from contextlib import contextmanager
 from datetime import timedelta
-from typing import Any, Dict, Generator, List, Optional, Union, TYPE_CHECKING
+from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING, Union
 
 import torch
 from torch import Tensor
 from torch.distributed import default_pg_timeout
 from torch.nn import Module
 
 from lightning_lite.accelerators import Accelerator
-from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
+from lightning_lite.plugins import CheckpointIO, ClusterEnvironment, Precision
 from lightning_lite.plugins.precision.fsdp import FSDPPrecision
-from lightning_lite.utilities.distributed import get_default_process_group_backend_for_device, distributed_available
-from lightning_lite.utilities.distributed import group as _group
-from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
-from lightning_lite.utilities.seed import reset_seed
-from lightning_lite.plugins import Precision
 from lightning_lite.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from lightning_lite.strategies.parallel import ParallelStrategy
 from lightning_lite.strategies.strategy import TBroadcast
+from lightning_lite.utilities.distributed import distributed_available, get_default_process_group_backend_for_device
+from lightning_lite.utilities.distributed import group as _group
+from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
 from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12
 from lightning_lite.utilities.rank_zero import rank_zero_only
+from lightning_lite.utilities.seed import reset_seed
 
 if TYPE_CHECKING:
     from torch.distributed.fsdp.fully_sharded_data_parallel import (
@@ -43,11 +42,13 @@
     )
     from torch.distributed.fsdp.wrap import enable_wrap
 
+_FSDP_ALIASES = ("fsdp", "fsdp_full_shard_offload")
+
 
 class FSDPStrategy(ParallelStrategy):
     r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed.
 
-    .. warning:: ``DDPFullyShardedNativeStrategy`` is in BETA and subject to change. The interface can
+    .. warning:: ``FSDPStrategy`` is in BETA and subject to change. The interface can
         bring breaking changes and new features with the next release of PyTorch.
 
     Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
@@ -62,30 +63,20 @@ class FSDPStrategy(ParallelStrategy):
     `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ for more information.
 
     Arguments:
-        cpu_offload:
-            CPU offloading config. Currently, only parameter and gradient CPU
-            offload is supported. It can be enabled via passing in
-            ``cpu_offload=CPUOffload(offload_params=True)``. Note that this
-            currently implicitly enables gradient offloading to CPU in order for
-            params and grads to be on same device to work with optimizer. This
-            API is subject to change. Default is ``None`` in which case there
+        cpu_offload: CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It
+            can be enabled via passing in ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently
+            implicitly enables gradient offloading to CPU in order for parameters and gradients to be on the same
+            device to work with the optimizer. This API is subject to change. Default is ``None``, in which case
             will be no offloading.
-        backward_prefetch:
-            This is an experimental feature that is subject to change in the
-            the near future. It allows users to enable two different backward_prefetch
-            algorithms to help backward communication and computation overlapping.
-            The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.
-        mixed_precision:
-            Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
-            or BF16 if ``precision=bf16`` unless a config is passed in.
-            This is only available in PyTorch 1.12 and later.
-        \**kwargs: Passed to the FSDP context manager which will configure the FSDP class when wrapping modules.
-
+        backward_prefetch: This is an experimental feature that is subject to change in the near future. It allows
+            users to enable two different backward prefetching algorithms to help overlap backward communication
+            and computation. The pros and cons of each algorithm are explained in the class ``BackwardPrefetch``.
+        mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16
+            if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later.
+        \**kwargs: Optional keyword arguments passed to the FSDP context manager which will configure the FSDP class
+            when wrapping modules.
     """
 
-    strategy_name = "fsdp_native"
-    _registered_strategies: List[str] = []
-
     def __init__(
         self,
         accelerator: Optional[Accelerator] = None,
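The arguments documented above map directly onto torch FSDP's configuration objects. A short, hedged construction example (the specific prefetch variant is an illustrative choice, not prescribed by this commit):

from torch.distributed.fsdp import BackwardPrefetch, CPUOffload

from lightning_lite.strategies.fsdp import FSDPStrategy

# Offload parameters (and, implicitly, gradients) to CPU, and prefetch the next
# set of parameters before the current gradient computation finishes.
strategy = FSDPStrategy(
    cpu_offload=CPUOffload(offload_params=True),
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
)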
@@ -169,6 +160,7 @@ def setup_module(self, module: Module) -> FullyShardedDataParallel:
         """Wraps the model into a
         :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
         from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
+
         if (
             any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules())
             and "auto_wrap_policy" in self._ddp_kwargs
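Since ``setup_module`` checks for an ``auto_wrap_policy`` entry in the strategy's keyword arguments, a wrapping policy can be forwarded through ``**kwargs``. A sketch assuming torch's ``size_based_auto_wrap_policy`` (PyTorch 1.12+); the 100k-parameter threshold is an arbitrary illustration:

import functools

from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

from lightning_lite.strategies.fsdp import FSDPStrategy

# Forwarded to FSDP when Lite wraps the model in setup_module(); submodules with
# at least 100k parameters become their own FSDP units.
strategy = FSDPStrategy(
    auto_wrap_policy=functools.partial(size_based_auto_wrap_policy, min_num_params=100_000),
)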
