1313# limitations under the License.
1414from contextlib import contextmanager
1515from datetime import timedelta
16- from typing import Any, Dict, Generator, List, Optional, Union, TYPE_CHECKING
16+ from typing import Any, Dict, Generator, List, Optional, TYPE_CHECKING, Union
1717
1818import torch
1919from torch import Tensor
2020from torch.distributed import default_pg_timeout
2121from torch.nn import Module
2222
2323from lightning_lite.accelerators import Accelerator
24- from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
24+ from lightning_lite.plugins import CheckpointIO, ClusterEnvironment, Precision
2525from lightning_lite.plugins.precision.fsdp import FSDPPrecision
26- from lightning_lite.utilities.distributed import get_default_process_group_backend_for_device, distributed_available
27- from lightning_lite.utilities.distributed import group as _group
28- from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
29- from lightning_lite.utilities.seed import reset_seed
30- from lightning_lite.plugins import Precision
3126from lightning_lite.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
3227from lightning_lite.strategies.parallel import ParallelStrategy
3328from lightning_lite.strategies.strategy import TBroadcast
29+ from lightning_lite.utilities.distributed import distributed_available, get_default_process_group_backend_for_device
30+ from lightning_lite.utilities.distributed import group as _group
31+ from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
3432from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12
3533from lightning_lite.utilities.rank_zero import rank_zero_only
34+ from lightning_lite.utilities.seed import reset_seed
3635
3736if TYPE_CHECKING:
3837 from torch.distributed.fsdp.fully_sharded_data_parallel import (
4342 )
4443 from torch.distributed.fsdp.wrap import enable_wrap
4544
45+ _FSDP_ALIASES = ("fsdp", "fsdp_full_shard_offload")
46+
4647
4748class FSDPStrategy(ParallelStrategy):
4849 r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed.
4950
50- .. warning:: ``DDPFullyShardedNativeStrategy`` is in BETA and subject to change. The interface can
51+ .. warning:: ``FSDPStrategy`` is in BETA and subject to change. The interface can
5152 bring breaking changes and new features with the next release of PyTorch.
5253
5354 Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model
@@ -62,30 +63,20 @@ class FSDPStrategy(ParallelStrategy):
6263 `this tutorial <https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html>`__ for more information.
6364
6465 Arguments:
65- cpu_offload:
66- CPU offloading config. Currently, only parameter and gradient CPU
67- offload is supported. It can be enabled via passing in
68- ``cpu_offload=CPUOffload(offload_params=True)``. Note that this
69- currently implicitly enables gradient offloading to CPU in order for
70- params and grads to be on same device to work with optimizer. This
71- API is subject to change. Default is ``None`` in which case there
66+ cpu_offload: CPU offloading config. Currently, only parameter and gradient CPU offload is supported. It
67+ can be enabled via passing in ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently
68+ implicitly enables gradient offloading to CPU in order for parameters and gradients to be on the same device
69+ to work with the optimizer. This API is subject to change. Default is ``None`` in which case there
7270 will be no offloading.
73- backward_prefetch:
74- This is an experimental feature that is subject to change in the
75- the near future. It allows users to enable two different backward_prefetch
76- algorithms to help backward communication and computation overlapping.
77- The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``.
78- mixed_precision:
79- Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16``
80- or BF16 if ``precision=bf16`` unless a config is passed in.
81- This is only available in PyTorch 1.12 and later.
82- \**kwargs: Passed to the FSDP context manager which will configure the FSDP class when wrapping modules.
83-
71+ backward_prefetch: This is an experimental feature that is subject to change in the near future. It allows
72+ users to enable two different backward prefetching algorithms to help backward communication and
73+ computation overlapping. The pros and cons of each algorithm are explained in the class ``BackwardPrefetch``.
74+ mixed_precision: Mixed Precision config. By default, Lightning will enable FP16 if ``precision=16`` or BF16
75+ if ``precision=bf16`` unless a config is passed in. This is only available in PyTorch 1.12 and later.
76+ \**kwargs: Optional keyword arguments passed to the FSDP context manager which will configure the FSDP class
77+ when wrapping modules.
8478 """
8579
86- strategy_name = "fsdp_native"
87- _registered_strategies: List[str] = []
88-
8980def __init__(
9081 self,
9182 accelerator: Optional[Accelerator] = None,
@@ -169,6 +160,7 @@ def setup_module(self, module: Module) -> FullyShardedDataParallel:
169160 """Wraps the model into a
170161 :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
171162from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
163+
172164 if (
173165any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules())
174166and "auto_wrap_policy" in self._ddp_kwargs
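
For reference, a minimal usage sketch (not part of this diff) of how the options documented in the docstring above fit together. ``CPUOffload`` and ``BackwardPrefetch`` come from ``torch.distributed.fsdp`` (PyTorch 1.12+); the exact import path of ``FSDPStrategy`` and the availability of a ``strategy=`` argument on the Lite entry point are assumptions here, not taken from this commit.

# Sketch only: assumes FSDPStrategy is importable from lightning_lite.strategies
# and that the Lite entry point accepts a `strategy=` argument.
from torch.distributed.fsdp import BackwardPrefetch, CPUOffload  # requires PyTorch >= 1.12

from lightning_lite.strategies import FSDPStrategy  # assumed import path

strategy = FSDPStrategy(
    # Offload parameters (and, implicitly, gradients) to CPU, as the docstring describes.
    cpu_offload=CPUOffload(offload_params=True),
    # Experimental prefetching of the next parameter shard during the backward pass.
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
)

# Alternatively, the string aliases registered by this module ("fsdp",
# "fsdp_full_shard_offload") can be passed instead of an instance, e.g. strategy="fsdp".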