2 changes: 2 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -271,6 +271,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed a bug where checkpoints were not properly saved for models trained with the `FSDP` strategies ([#13500](https://github.com/Lightning-AI/lightning/issues/13500))


- Improved support for custom `DataLoader`s when instantiated in `*_dataloader` hook ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))

23 changes: 23 additions & 0 deletions src/pytorch_lightning/strategies/fully_sharded.py
@@ -18,6 +18,7 @@
import torch

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -33,6 +34,16 @@
    from fairscale.nn import default_auto_wrap_policy, enable_wrap
    from fairscale.nn.data_parallel import FullyShardedDataParallel

    def unwrap_lightning_module_fully_sharded(wrapped_model: torch.nn.Module) -> "pl.LightningModule":
        model = wrapped_model
        if isinstance(model, FullyShardedDataParallel):
            model = model.module

        return unwrap_lightning_module(model)

else:
    unwrap_lightning_module_fully_sharded = ...  # type: ignore[assignment]

log = logging.getLogger(__name__)


@@ -209,6 +220,18 @@ def predict_step(self, *args, **kwargs) -> STEP_OUTPUT:
        with self.precision_plugin.predict_step_context():
            return self.model.predict_step(*args, **kwargs)

    @property
    def lightning_module(self) -> Optional["pl.LightningModule"]:
        # TODO unwrapping is eventually needed for checkpointing, but does this
        # slow down training? Maybe this should go somewhere else.
        # https://github.com/Lightning-AI/lightning/issues/13500
        if not _FAIRSCALE_FULLY_SHARDED_AVAILABLE:  # pragma: no cover
            raise MisconfigurationException(
                "`DDPFullyShardedStrategy` requires `fairscale>=0.3.4` to be installed."
                " Install it by running `pip install fairscale`."
            )
        return unwrap_lightning_module_fully_sharded(self.model) if self.model is not None else None

    def post_training_step(self):
        pass

19 changes: 19 additions & 0 deletions src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -20,6 +20,7 @@
from torch.distributed.distributed_c10d import _get_default_group, ProcessGroup

import pytorch_lightning as pl
from pytorch_lightning.overrides.base import unwrap_lightning_module
from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
@@ -47,6 +48,15 @@
    )
    from torch.distributed.fsdp.wrap import enable_wrap

    def unwrap_lightning_module_fully_sharded_native(wrapped_model: torch.nn.Module) -> "pl.LightningModule":
        model = wrapped_model
        if isinstance(model, FullyShardedDataParallel):
            model = model.module

        return unwrap_lightning_module(model)

else:
    unwrap_lightning_module_fully_sharded_native = ...  # type: ignore[assignment]

log = logging.getLogger(__name__)

@@ -230,6 +240,15 @@ def teardown(self) -> None:

        super().teardown()

    @property
    def lightning_module(self) -> Optional["pl.LightningModule"]:
        # TODO unwrapping is eventually needed for checkpointing, but does this
        # slow down training? Maybe this should go somewhere else.
        # https://github.com/Lightning-AI/lightning/issues/13500
        if not _TORCH_GREATER_EQUAL_1_11:  # pragma: no cover
            raise MisconfigurationException("`DDPFullyShardedNativeStrategy` requires `torch>=1.11.0` to be installed.")
        return unwrap_lightning_module_fully_sharded_native(self.model) if self.model is not None else None

    @classmethod
    def get_registered_strategies(cls) -> List[str]:
        return cls._registered_strategies
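Not part of the diff: a minimal usage sketch of the behaviour this change is meant to restore. It assumes fairscale is installed, a 2-GPU machine, and that the fairscale-backed strategy is registered under the name "fsdp"; the toy model, dataset, and checkpoint directory below are illustrative only and not taken from this PR.

import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    # Tiny synthetic dataset so the sketch is self-contained.
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return torch.randn(32)


class ToyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    # With the `lightning_module` property added above, `trainer.lightning_module`
    # resolves to the original LightningModule even while `strategy.model` is
    # wrapped in FullyShardedDataParallel, so checkpoint saving can reach the
    # module's state again.
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,
        strategy="fsdp",  # assumed registry name for DDPFullyShardedStrategy
        max_epochs=1,
        callbacks=[ModelCheckpoint(dirpath="checkpoints/")],  # illustrative path
    )
    trainer.fit(ToyModel(), DataLoader(RandomDataset(), batch_size=8))

The native variant works the same way once `torch>=1.11` is available; only the strategy name passed to the `Trainer` changes.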