
Commit 0caf973

fsdp
1 parent 043783e commit 0caf973

File tree: 2 files changed, +35 -50 lines


src/lightning_lite/plugins/precision/fsdp.py

Lines changed: 12 additions & 11 deletions
@@ -11,37 +11,38 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional, Union
+from typing import Optional, TYPE_CHECKING, Literal

 import torch

 from lightning_lite.utilities.enums import PrecisionType
-from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from lightning_lite.plugins.precision import NativeMixedPrecision
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12

-if _TORCH_GREATER_EQUAL_1_12 and torch.distributed.is_available():
+if TYPE_CHECKING:
     from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
     from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-else:
-    MixedPrecision = None  # type: ignore[misc,assignment]
-    ShardedGradScaler = None  # type: ignore[misc,assignment]


-class FSDPPrecision(NativeMixedPrecisionPlugin):
+class FSDPPrecision(NativeMixedPrecision):
     """AMP for Fully Sharded Data Parallel training."""

-    def __init__(self, precision: Union[str, int], device: str, scaler: Optional[ShardedGradScaler] = None) -> None:
+    def __init__(self, precision: Literal[16, "bf16"], device: str, scaler: Optional[ShardedGradScaler] = None) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:
             raise RuntimeError("`FSDPPrecision` is supported from PyTorch v1.12.0 onwards.")
+
+        from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
+
         super().__init__(
             precision=precision,
             device=device,
             scaler=(ShardedGradScaler() if scaler is None and precision == 16 else None),
         )

     @property
-    def mixed_precision_config(self) -> Optional[MixedPrecision]:
+    def mixed_precision_config(self) -> MixedPrecision:
+        from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
+
         if self.precision == PrecisionType.HALF:
             dtype = torch.float16
         elif self.precision == PrecisionType.BFLOAT:
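
For context (not part of the commit): a minimal sketch of how the reworked plugin could be driven on its own, assuming PyTorch >= 1.12 with torch.distributed available. Only the constructor signature and the mixed_precision_config property shown above are taken from the diff; the surrounding script is illustrative.

# Hypothetical usage sketch, not part of this commit.
from lightning_lite.plugins.precision.fsdp import FSDPPrecision

# Requires PyTorch >= 1.12; otherwise the constructor raises RuntimeError.
precision = FSDPPrecision(precision=16, device="cuda")

# Built from the local import inside the property; for precision 16 the
# config is derived from torch.float16.
config = precision.mixed_precision_config
print(type(config))  # torch.distributed.fsdp.fully_sharded_data_parallel.MixedPrecision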

src/lightning_lite/strategies/fsdp.py

Lines changed: 23 additions & 39 deletions
@@ -11,56 +11,37 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 from contextlib import contextmanager
 from datetime import timedelta
-from typing import Any, Dict, Generator, List, Optional, Union
+from typing import Any, Dict, Generator, List, Optional, Union, TYPE_CHECKING

 import torch
 from torch import Tensor
 from torch.distributed import default_pg_timeout
 from torch.nn import Module

-import pytorch_lightning as pl
 from lightning_lite.accelerators import Accelerator
 from lightning_lite.plugins import CheckpointIO, ClusterEnvironment
+from lightning_lite.plugins.precision.fsdp import FSDPPrecision
 from lightning_lite.utilities.distributed import get_default_process_group_backend_for_device, distributed_available
 from lightning_lite.utilities.distributed import group as _group
 from lightning_lite.utilities.distributed import init_dist_connection, ReduceOp, sync_ddp_if_available
-from lightning_lite.utilities.optimizer import optimizers_to_device
 from lightning_lite.utilities.seed import reset_seed
 from lightning_lite.plugins import Precision
-from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
-from pytorch_lightning.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
+from lightning_lite.strategies.launchers.subprocess_script import _SubprocessScriptLauncher
 from lightning_lite.strategies.parallel import ParallelStrategy
-from pytorch_lightning.strategies.strategy import TBroadcast
-from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_12
-from pytorch_lightning.utilities.model_helpers import is_overridden
-from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only
-from pytorch_lightning.utilities.types import ProcessGroup, STEP_OUTPUT
-
-_distributed_available = torch.distributed.is_available()
-_fsdp_available = _TORCH_GREATER_EQUAL_1_12 and _distributed_available
-if _fsdp_available:
+from lightning_lite.strategies.strategy import TBroadcast
+from lightning_lite.utilities.imports import _TORCH_GREATER_EQUAL_1_12
+from lightning_lite.utilities.rank_zero import rank_zero_only
+
+if TYPE_CHECKING:
     from torch.distributed.fsdp.fully_sharded_data_parallel import (
         BackwardPrefetch,
         CPUOffload,
         FullyShardedDataParallel,
         MixedPrecision,
     )
     from torch.distributed.fsdp.wrap import enable_wrap
-else:
-    FullyShardedDataParallel = None  # type: ignore[misc,assignment]
-    MixedPrecision = None  # type: ignore[misc,assignment]
-    BackwardPrefetch = None  # type: ignore[misc,assignment]
-    CPUOffload = None  # type: ignore[misc,assignment]
-
-if _distributed_available:
-    from torch.distributed.distributed_c10d import _get_default_group
-
-log = logging.getLogger(__name__)


 class FSDPStrategy(ParallelStrategy):

@@ -120,9 +101,7 @@ def __init__(
         **kwargs: Any,
     ) -> None:
         if not _TORCH_GREATER_EQUAL_1_12:
-            raise MisconfigurationException(
-                "`FSDPStrategy` is supported from PyTorch v1.12.0 onwards."
-            )
+            raise RuntimeError("`FSDPStrategy` is supported from PyTorch v1.12.0 onwards.")

         super().__init__(
             accelerator=accelerator,

@@ -169,13 +148,13 @@ def distributed_sampler_kwargs(self) -> Dict:
     def process_group_backend(self) -> Optional[str]:
         return self._process_group_backend

-    # @property
-    # def mixed_precision_config(self) -> Optional[MixedPrecision]:
-    #     if self.mixed_precision:
-    #         return self.mixed_precision
-    #     plugin = self.precision_plugin
-    #     if isinstance(plugin, FullyShardedNativeNativeMixedPrecisionPlugin):
-    #         return plugin.mixed_precision_config
+    @property
+    def mixed_precision_config(self) -> Optional[MixedPrecision]:
+        if self.mixed_precision:
+            return self.mixed_precision
+        plugin = self.precision_plugin
+        if isinstance(plugin, FSDPPrecision):
+            return plugin.mixed_precision_config

     def _configure_launcher(self) -> None:
         assert self.cluster_environment is not None

@@ -189,6 +168,7 @@ def setup_environment(self) -> None:
     def setup_module(self, module: Module) -> FullyShardedDataParallel:
         """Wraps the model into a
         :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
+        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
         if (
             any(isinstance(mod, FullyShardedDataParallel) for mod in module.modules())
             and "auto_wrap_policy" in self._ddp_kwargs

@@ -209,12 +189,14 @@ def module_to_device(self, module: Module) -> None:

     @contextmanager
     def module_sharded_context(self) -> Generator:
+        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
+        from torch.distributed.fsdp.wrap import enable_wrap
+
         with enable_wrap(
             wrapper_cls=FullyShardedDataParallel,
-            # process_group=self.process_group,
             cpu_offload=self.cpu_offload,
             backward_prefetch=self.backward_prefetch,
-            mixed_precision=self.precision_plugin.mixed_precision_config,
+            mixed_precision=self.mixed_precision_config,
             device_id=self.root_device.index,
             **self._ddp_kwargs,
         ):

@@ -244,6 +226,8 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:

     @classmethod
     def register_strategies(cls, strategy_registry: Dict) -> None:
+        from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
+
         strategy_registry.register(
             "fsdp",
             cls,
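
Both files now apply the same deferred-import pattern: FSDP symbols are imported under TYPE_CHECKING for annotations only and re-imported locally inside the methods that use them, so importing the module no longer requires an FSDP-capable PyTorch build. A generic sketch of that pattern follows; the wrap_model helper is illustrative, not part of the commit.

# Illustrative sketch of the deferred-import pattern; wrap_model is hypothetical.
from typing import TYPE_CHECKING

from torch.nn import Module

if TYPE_CHECKING:
    # Seen only by static type checkers; never executed, so importing this
    # module does not require torch.distributed.fsdp to be importable.
    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel


def wrap_model(module: Module) -> "FullyShardedDataParallel":
    # Deferred runtime import: the FSDP requirement is enforced only when
    # wrapping actually happens, not at module import time.
    from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

    return FullyShardedDataParallel(module)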
