Merged
Commits
46 commits
2752e9d
Add additional wrapping to handle sharded model
SpirinEgor Jan 30, 2023
d3ad5e7
Merge branch 'master' into master
SpirinEgor Jan 30, 2023
f070509
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2023
191f876
Use tuple to define yield type in Iterator
SpirinEgor Jan 30, 2023
0b8a680
Merge remote-tracking branch 'origin/master'
SpirinEgor Jan 30, 2023
f5f2158
Use tuple to define yield type in Iterator
SpirinEgor Jan 30, 2023
79431b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2023
7dbf02d
Add tests for checking state_dict extraction
SpirinEgor Jan 31, 2023
a6d80e2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 31, 2023
53bef88
Merge branch 'master' into master
SpirinEgor Feb 1, 2023
b39b5e0
Merge remote-tracking branch 'upstream/master'
SpirinEgor Feb 2, 2023
d775d46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 2, 2023
321821f
More accurate handle of auto wrapped model
SpirinEgor Feb 2, 2023
c3574e3
Save hyperparameters for correct checkpoint loading
SpirinEgor Feb 2, 2023
8f78724
Merge branch 'master' into master
SpirinEgor Feb 2, 2023
5cec7ef
Merge branch 'master' into master
SpirinEgor Feb 6, 2023
b0c4a14
Merge branch 'master' into master
SpirinEgor Feb 7, 2023
6df3903
Merge branch 'master' into master
SpirinEgor Feb 7, 2023
c57f7d8
Merge branch 'master' into master
SpirinEgor Feb 13, 2023
ec56da1
Merge branch 'master' into master
Borda Mar 3, 2023
450a28a
Add additional wrapping to handle sharded model
SpirinEgor Jan 30, 2023
39e898e
Use tuple to define yield type in Iterator
SpirinEgor Jan 30, 2023
8527827
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 30, 2023
d2aeefb
Add tests for checking state_dict extraction
SpirinEgor Jan 31, 2023
158e0de
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 31, 2023
8368140
More accurate handle of auto wrapped model
SpirinEgor Feb 2, 2023
4ab6ecd
Save hyperparameters for correct checkpoint loading
SpirinEgor Feb 2, 2023
c2175f6
Merge remote-tracking branch 'origin/master'
SpirinEgor Apr 13, 2023
54d9740
Wrap model state dict retrieve with FSDP context
SpirinEgor Apr 13, 2023
8393d2d
Use multiple GPUs in tests, correct precision name
SpirinEgor Apr 13, 2023
a09ee14
Merge branch 'master' into master
SpirinEgor Apr 13, 2023
f763c09
Revert unnecessary changes in tests
SpirinEgor Apr 13, 2023
ed82d00
Always offload checkpoint to CPU
SpirinEgor Apr 14, 2023
bbd684b
Validate state_dict only on zero rank
SpirinEgor Apr 14, 2023
8e4804e
Offload to CPU only if trainer uses offload
SpirinEgor Apr 14, 2023
14d3d8b
Merge branch 'master' into master
SpirinEgor Apr 14, 2023
26e9ea7
Merge branch 'master' into master
SpirinEgor Apr 14, 2023
d998776
Merge branch 'master' into master
SpirinEgor Apr 17, 2023
0c1568a
Always offload checkpoint to CPU
SpirinEgor Apr 17, 2023
7a1b40f
update tests
awaelchli Apr 17, 2023
6b107bc
move function to bottom
awaelchli Apr 17, 2023
638a374
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2023
62441e9
update documentation
awaelchli Apr 17, 2023
2fc3834
Merge branch 'master' of github.com:SpirinEgor/lightning into SpirinE…
awaelchli Apr 17, 2023
99c0f39
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2023
9d16758
add changelog
awaelchli Apr 17, 2023
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for multiple optimizer parameter groups when using the FSDP strategy ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))

- Enabled saving the full model state dict when using the `FSDPStrategy` ([#16558](https://github.com/Lightning-AI/lightning/pull/16558))


### Changed

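For context, a minimal sketch of what this changelog entry enables from the user side. This is illustrative only and not part of the PR; it assumes a 2-GPU machine and uses the `BoringModel` demo class, but any `LightningModule` would do:

```python
import lightning.pytorch as pl
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.strategies import FSDPStrategy

# Train a model with FSDP and save a checkpoint. With this change, the
# checkpoint contains the full, unsharded state dict gathered on rank 0.
model = BoringModel()
trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,
    strategy=FSDPStrategy(),
    precision="16-mixed",
    max_epochs=1,
)
trainer.fit(model)
trainer.save_checkpoint("model.ckpt")
```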
29 changes: 28 additions & 1 deletion src/lightning/pytorch/strategies/fsdp.py
@@ -55,7 +55,13 @@
_distributed_available = torch.distributed.is_available()
_fsdp_available = _TORCH_GREATER_EQUAL_1_12 and _distributed_available
if _fsdp_available:
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel, MixedPrecision
from torch.distributed.fsdp import (
CPUOffload,
FullStateDictConfig,
FullyShardedDataParallel,
MixedPrecision,
StateDictType,
)
from torch.distributed.fsdp.wrap import enable_wrap
else:
FullyShardedDataParallel = None # type: ignore[misc,assignment]
@@ -139,6 +145,22 @@ def __init__(
# `self.trainer.model.parameters()` and enables support for multiple parameter groups.
self.kwargs.setdefault("use_orig_params", True)

def lightning_module_state_dict(self) -> Dict[str, Any]:
"""Gathers the full state dict by unsharding all the parameters.

To avoid OOM, the unsharded parameters are returned only on rank 0, offloaded to CPU when running on multiple
devices. All other ranks get an empty dict.
"""
assert self.model is not None

with FullyShardedDataParallel.state_dict_type(
module=self.model,
state_dict_type=StateDictType.FULL_STATE_DICT,
state_dict_config=FullStateDictConfig(offload_to_cpu=(self.world_size > 1), rank0_only=True),
):
state_dict = self.model.state_dict()
return _strip_prefix_from_state_dict(state_dict, prefix="_forward_module.")

@property
def root_device(self) -> torch.device:
assert self.parallel_devices is not None
@@ -390,3 +412,8 @@ def register_strategies(cls, strategy_registry: Dict) -> None:
cpu_offload=True,
)
cls._registered_strategies.append("fsdp_cpu_offload")


def _strip_prefix_from_state_dict(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]:
prefix_len = len(prefix)
return {k[prefix_len:]: v for k, v in state_dict.items()}
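For readers unfamiliar with the PyTorch API used here: the new `lightning_module_state_dict` relies on FSDP's `state_dict_type` context manager together with `FullStateDictConfig`. A rough standalone sketch of the same gathering pattern, assuming `model` is already a `FullyShardedDataParallel`-wrapped module and the default process group is initialized:

```python
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel,
    StateDictType,
)


def gather_full_state_dict(model: FullyShardedDataParallel) -> dict:
    """Unshard the parameters and return the full state dict on rank 0 only."""
    config = FullStateDictConfig(
        offload_to_cpu=dist.get_world_size() > 1,  # keep rank 0's GPU from running out of memory
        rank0_only=True,  # every other rank returns an empty dict
    )
    with FullyShardedDataParallel.state_dict_type(
        module=model,
        state_dict_type=StateDictType.FULL_STATE_DICT,
        state_dict_config=config,
    ):
        return model.state_dict()
```

Inside Lightning, the gathered keys carry a `_forward_module.` prefix from the strategy's wrapper, which is why the PR strips it with `_strip_prefix_from_state_dict` before returning the dict.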
37 changes: 35 additions & 2 deletions tests/tests_pytorch/strategies/test_fsdp.py
@@ -86,9 +86,11 @@ def _assert_layer_fsdp_instance(self) -> None:


class TestFSDPModelAutoWrapped(BoringModel):
def __init__(self):
def __init__(self, wrap_min_params: int = 2):
super().__init__()
self.save_hyperparameters()
self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
self.should_be_wrapped = [(32 * 32 + 32) > wrap_min_params, None, (32 * 2 + 2) > wrap_min_params]

def configure_optimizers(self):
parameters = self.parameters() if _TORCH_GREATER_EQUAL_2_0 else self.trainer.model.parameters()
@@ -112,6 +114,10 @@ def _assert_layer_fsdp_instance(self) -> None:

precision = torch.float16 if self.trainer.precision == "16-mixed" else torch.bfloat16
for layer_num in [0, 2]:
if not self.should_be_wrapped[layer_num]:
# this layer is not wrapped
assert not isinstance(self.layer[layer_num], FullyShardedDataParallel)
continue
assert isinstance(self.layer[layer_num], FullyShardedDataParallel)
assert self.layer[layer_num].mixed_precision.param_dtype == precision
assert self.layer[layer_num].mixed_precision.reduce_dtype == precision
@@ -224,7 +230,6 @@ def policy(self):

custom_fsdp_policy = CustomWrapPolicy(min_num_params=2)


if _TORCH_GREATER_EQUAL_2_0:

def custom_auto_wrap_policy(
@@ -244,6 +249,34 @@ def custom_auto_wrap_policy(
return unwrapped_params >= 2


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
@pytest.mark.parametrize("wrap_min_params", (2, 1024, 100000000))
def test_fsdp_strategy_full_state_dict(tmpdir, wrap_min_params):
"""Test to ensure that the full state dict is extracted when using FSDP strategy.

Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all.
"""
model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params)
correct_state_dict = model.state_dict() # State dict before wrapping

strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params))
trainer = Trainer(
default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy=strategy, precision="16-mixed", max_epochs=1
)
trainer.fit(model)

full_state_dict = trainer.strategy.lightning_module_state_dict()

if trainer.global_rank != 0:
assert len(full_state_dict) == 0
return

# State dict should contain same number of keys
assert len(correct_state_dict) == len(full_state_dict)
# OrderedDict should return the same keys in the same order
assert all(_ex == _co for _ex, _co in zip(full_state_dict.keys(), correct_state_dict.keys()))


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12")
@pytest.mark.parametrize(
"model, strategy, strategy_cfg",
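The `wrap_min_params` values in the new test are chosen around the layer sizes of `TestFSDPModelAutoWrapped`: its two `Linear` layers hold 1056 and 66 parameters (the `ReLU` has none), so the size-based policy wraps both layers, only the first, or neither. A rough sketch of that arithmetic, for illustration only:

```python
import torch

# Parameter counts of the layers in TestFSDPModelAutoWrapped's Sequential model.
layers = [torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)]
param_counts = [sum(p.numel() for p in layer.parameters()) for layer in layers]
print(param_counts)  # [1056, 0, 66]

# size_based_auto_wrap_policy wraps a module once its parameter count reaches
# min_num_params, which is what the test's `should_be_wrapped` list encodes:
for wrap_min_params in (2, 1024, 100_000_000):
    print(wrap_min_params, [count >= wrap_min_params for count in param_counts])
# 2          -> [True, False, True]    both Linear layers wrapped
# 1024       -> [True, False, False]   only the first Linear wrapped
# 100000000  -> [False, False, False]  nothing wrapped
```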