diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 46501a1b459cc..c65ef27bb93e1 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for multiple optimizer parameter groups when using the FSDP strategy ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
 
+- Enabled saving the full model state dict when using the `FSDPStrategy` ([#16558](https://github.com/Lightning-AI/lightning/pull/16558))
+
 
 ### Changed
 
diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py
index 319cecca69ca2..0de1f5852d3af 100644
--- a/src/lightning/pytorch/strategies/fsdp.py
+++ b/src/lightning/pytorch/strategies/fsdp.py
@@ -55,7 +55,13 @@
 _distributed_available = torch.distributed.is_available()
 _fsdp_available = _TORCH_GREATER_EQUAL_1_12 and _distributed_available
 if _fsdp_available:
-    from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel, MixedPrecision
+    from torch.distributed.fsdp import (
+        CPUOffload,
+        FullStateDictConfig,
+        FullyShardedDataParallel,
+        MixedPrecision,
+        StateDictType,
+    )
     from torch.distributed.fsdp.wrap import enable_wrap
 else:
     FullyShardedDataParallel = None  # type: ignore[misc,assignment]
@@ -139,6 +145,22 @@ def __init__(
             # `self.trainer.model.parameters()` and enables support for multiple parameter groups.
             self.kwargs.setdefault("use_orig_params", True)
 
+    def lightning_module_state_dict(self) -> Dict[str, Any]:
+        """Gathers the full state dict by unsharding all the parameters.
+
+        To avoid OOM, the unsharded parameters are returned on rank 0 only and are offloaded to CPU. All other ranks
+        get an empty dict.
+ """ + assert self.model is not None + + with FullyShardedDataParallel.state_dict_type( + module=self.model, + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig(offload_to_cpu=(self.world_size > 1), rank0_only=True), + ): + state_dict = self.model.state_dict() + return _strip_prefix_from_state_dict(state_dict, prefix="_forward_module.") + @property def root_device(self) -> torch.device: assert self.parallel_devices is not None @@ -390,3 +412,8 @@ def register_strategies(cls, strategy_registry: Dict) -> None: cpu_offload=True, ) cls._registered_strategies.append("fsdp_cpu_offload") + + +def _strip_prefix_from_state_dict(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]: + prefix_len = len(prefix) + return {k[prefix_len:]: v for k, v in state_dict.items()} diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 65990ac9cb849..58e4c1a4123f7 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -86,9 +86,11 @@ def _assert_layer_fsdp_instance(self) -> None: class TestFSDPModelAutoWrapped(BoringModel): - def __init__(self): + def __init__(self, wrap_min_params: int = 2): super().__init__() + self.save_hyperparameters() self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + self.should_be_wrapped = [(32 * 32 + 32) > wrap_min_params, None, (32 * 2 + 2) > wrap_min_params] def configure_optimizers(self): parameters = self.parameters() if _TORCH_GREATER_EQUAL_2_0 else self.trainer.model.parameters() @@ -112,6 +114,10 @@ def _assert_layer_fsdp_instance(self) -> None: precision = torch.float16 if self.trainer.precision == "16-mixed" else torch.bfloat16 for layer_num in [0, 2]: + if not self.should_be_wrapped[layer_num]: + # this layer is not wrapped + assert not isinstance(self.layer[layer_num], FullyShardedDataParallel) + continue assert isinstance(self.layer[layer_num], FullyShardedDataParallel) assert self.layer[layer_num].mixed_precision.param_dtype == precision assert self.layer[layer_num].mixed_precision.reduce_dtype == precision @@ -224,7 +230,6 @@ def policy(self): custom_fsdp_policy = CustomWrapPolicy(min_num_params=2) - if _TORCH_GREATER_EQUAL_2_0: def custom_auto_wrap_policy( @@ -244,6 +249,34 @@ def custom_auto_wrap_policy( return unwrapped_params >= 2 +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12") +@pytest.mark.parametrize("wrap_min_params", (2, 1024, 100000000)) +def test_fsdp_strategy_full_state_dict(tmpdir, wrap_min_params): + """Test to ensure that the full state dict is extracted when using FSDP strategy. + + Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. 
+ """ + model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params) + correct_state_dict = model.state_dict() # State dict before wrapping + + strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params)) + trainer = Trainer( + default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy=strategy, precision="16-mixed", max_epochs=1 + ) + trainer.fit(model) + + full_state_dict = trainer.strategy.lightning_module_state_dict() + + if trainer.global_rank != 0: + assert len(full_state_dict) == 0 + return + + # State dict should contain same number of keys + assert len(correct_state_dict) == len(full_state_dict) + # OrderedDict should return the same keys in the same order + assert all(_ex == _co for _ex, _co in zip(full_state_dict.keys(), correct_state_dict.keys())) + + @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12") @pytest.mark.parametrize( "model, strategy, strategy_cfg",