From 2752e9da7a8c8d9f0de0c42bed068ae1f48c18e0 Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Mon, 30 Jan 2023 16:02:39 +0300 Subject: [PATCH 01/30] Add additional wrapping to handle sharded model --- src/pytorch_lightning/strategies/fsdp.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/strategies/fsdp.py b/src/pytorch_lightning/strategies/fsdp.py index 58af2568350a5..7fc5d845e20ab 100644 --- a/src/pytorch_lightning/strategies/fsdp.py +++ b/src/pytorch_lightning/strategies/fsdp.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib import logging -from typing import Any, Dict, Generator, List, Optional, Type, Union +from typing import Any, Dict, Generator, List, Optional, Type, Union, Iterator import torch from torch import Tensor @@ -71,6 +71,20 @@ log = logging.getLogger(__name__) +class _FSDPStrategyModuleWrapper(_LightningModuleWrapperBase): + def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: # type: ignore[override] + # this is required because with FSDP lightning_module is empty because weights are sharded. + # So we need to call self.trainer.model.state_dict (wrapped version) and use this wraper to + # avoid extra keys `_forward_module.layer.weight.` since we want `layer.weight.` in state_dict. + return self._forward_module.state_dict(*args, **kwargs) + + def named_modules(self, *args: Any, **kwargs: Any) -> Iterator[str, Module]: + # This is required because FSDP explicitly checks that each flatted parameter in state_dict. + # Since we are wrapping the model, all flatted parameters will have `_forward_module.` prefix. + # This redirect avoids adding this prefix. + return self._forward_module.named_modules() + + class FSDPStrategy(ParallelStrategy): r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed. @@ -151,6 +165,11 @@ def __init__( ) self.kwargs = kwargs + def lightning_module_state_dict(self) -> Dict[str, Any]: + """Returns model state.""" + assert self.model is not None + return self.model.state_dict() + @property def root_device(self) -> torch.device: assert self.parallel_devices is not None @@ -254,7 +273,7 @@ def setup(self, trainer: "pl.Trainer") -> None: self.lightning_module._device = self.root_device assert isinstance(self.model, pl.LightningModule) - self.model = _LightningModuleWrapperBase(self.model) + self.model = _FSDPStrategyModuleWrapper(self.model) if is_overridden("configure_sharded_model", self.lightning_module): rank_zero_info( "You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers" From f070509f8004faeb57480ba5ed5f0424238f282f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 30 Jan 2023 13:12:53 +0000 Subject: [PATCH 02/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/strategies/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/fsdp.py b/src/pytorch_lightning/strategies/fsdp.py index 7fc5d845e20ab..a88cbe52a501b 100644 --- a/src/pytorch_lightning/strategies/fsdp.py +++ b/src/pytorch_lightning/strategies/fsdp.py @@ -13,7 +13,7 @@ # limitations under the License. 
import contextlib import logging -from typing import Any, Dict, Generator, List, Optional, Type, Union, Iterator +from typing import Any, Dict, Generator, Iterator, List, Optional, Type, Union import torch from torch import Tensor From 191f8762a6336aa31df12311d365950290751477 Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Mon, 30 Jan 2023 16:25:54 +0300 Subject: [PATCH 03/30] Use tuple to define yield type in Iterator --- src/pytorch_lightning/strategies/fsdp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/strategies/fsdp.py b/src/pytorch_lightning/strategies/fsdp.py index 7fc5d845e20ab..3dd0af71e0720 100644 --- a/src/pytorch_lightning/strategies/fsdp.py +++ b/src/pytorch_lightning/strategies/fsdp.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib import logging -from typing import Any, Dict, Generator, List, Optional, Type, Union, Iterator +from typing import Any, Dict, Generator, List, Optional, Type, Union, Iterator, Tuple import torch from torch import Tensor @@ -78,7 +78,7 @@ def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: # type: igno # avoid extra keys `_forward_module.layer.weight.` since we want `layer.weight.` in state_dict. return self._forward_module.state_dict(*args, **kwargs) - def named_modules(self, *args: Any, **kwargs: Any) -> Iterator[str, Module]: + def named_modules(self, *args: Any, **kwargs: Any) -> Iterator[Tuple[str, Module]]: # This is required because FSDP explicitly checks that each flatted parameter in state_dict. # Since we are wrapping the model, all flatted parameters will have `_forward_module.` prefix. # This redirect avoids adding this prefix. From f5f2158d573e18257f0cecf534ed658bef62d0ee Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Mon, 30 Jan 2023 16:26:59 +0300 Subject: [PATCH 04/30] Use tuple to define yield type in Iterator --- src/pytorch_lightning/strategies/fsdp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/strategies/fsdp.py b/src/pytorch_lightning/strategies/fsdp.py index a88cbe52a501b..26bc891de3607 100644 --- a/src/pytorch_lightning/strategies/fsdp.py +++ b/src/pytorch_lightning/strategies/fsdp.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib import logging -from typing import Any, Dict, Generator, Iterator, List, Optional, Type, Union +from typing import Any, Dict, Generator, Iterator, List, Optional, Type, Union, Tuple import torch from torch import Tensor @@ -78,7 +78,7 @@ def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: # type: igno # avoid extra keys `_forward_module.layer.weight.` since we want `layer.weight.` in state_dict. return self._forward_module.state_dict(*args, **kwargs) - def named_modules(self, *args: Any, **kwargs: Any) -> Iterator[str, Module]: + def named_modules(self, *args: Any, **kwargs: Any) -> Iterator[Tuple[str, Module]]: # This is required because FSDP explicitly checks that each flatted parameter in state_dict. # Since we are wrapping the model, all flatted parameters will have `_forward_module.` prefix. # This redirect avoids adding this prefix. 
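
The delegation pattern introduced by the patches above can be seen in isolation with the following minimal sketch. It is not the Lightning implementation: `_ToyWrapper` and `inner` are hypothetical names, and the point is only that forwarding `state_dict` and `named_modules` to the wrapped module keeps checkpoint keys free of the `_forward_module.` prefix.

from typing import Any, Dict, Iterator, Tuple

from torch.nn import Linear, Module


class _ToyWrapper(Module):
    """Hypothetical stand-in for `_FSDPStrategyModuleWrapper`, for illustration only."""

    def __init__(self, inner: Module) -> None:
        super().__init__()
        # Registering the user module as an attribute means that, by default, all of
        # its state-dict keys would be reported with a `_forward_module.` prefix.
        self._forward_module = inner

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        return self._forward_module(*args, **kwargs)

    def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        # Delegate so keys stay `weight`/`bias` instead of `_forward_module.weight`/...
        return self._forward_module.state_dict(*args, **kwargs)

    def named_modules(self, *args: Any, **kwargs: Any) -> Iterator[Tuple[str, Module]]:
        # Delegate so callers iterating over modules also see unprefixed names.
        return self._forward_module.named_modules()


if __name__ == "__main__":
    wrapped = _ToyWrapper(Linear(4, 2))
    # Without the overrides these keys would be `_forward_module.weight` and
    # `_forward_module.bias`; with the overrides the original names are preserved.
    assert list(wrapped.state_dict().keys()) == ["weight", "bias"]
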
From 79431b70bd4186438ab3c7953b908aa70497b14f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 30 Jan 2023 13:28:00 +0000 Subject: [PATCH 05/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/strategies/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/fsdp.py b/src/pytorch_lightning/strategies/fsdp.py index 26bc891de3607..62e475d4311e9 100644 --- a/src/pytorch_lightning/strategies/fsdp.py +++ b/src/pytorch_lightning/strategies/fsdp.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib import logging -from typing import Any, Dict, Generator, Iterator, List, Optional, Type, Union, Tuple +from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union import torch from torch import Tensor From 7dbf02d7410cefb3f68f84424b1101e4d0952b45 Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Tue, 31 Jan 2023 14:36:15 +0300 Subject: [PATCH 06/30] Add tests for checking state_dict extraction --- tests/tests_pytorch/strategies/test_fsdp.py | 30 ++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 8a46bddac4c76..9fa6db9a3730b 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -1,4 +1,5 @@ import os +from functools import partial from typing import Any, Dict, Optional from unittest import mock from unittest.mock import ANY, Mock @@ -18,7 +19,7 @@ if _TORCH_GREATER_EQUAL_1_12: from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision - from torch.distributed.fsdp.wrap import wrap + from torch.distributed.fsdp.wrap import wrap, size_based_auto_wrap_policy class TestFSDPModel(BoringModel): @@ -102,6 +103,10 @@ def _assert_layer_fsdp_instance(self) -> None: precision = torch.float16 if self.trainer.precision == "16" else torch.bfloat16 for layer_num in [0, 2]: + num_params = sum(p.numel() for p in self.layer[layer_num].parameters()) + if not custom_auto_wrap_policy(self.layer[layer_num], False, num_params): + # This layer is not wrapped + continue assert isinstance(self.layer[layer_num], FullyShardedDataParallel) assert self.layer[layer_num].mixed_precision.param_dtype == precision assert self.layer[layer_num].mixed_precision.reduce_dtype == precision @@ -209,6 +214,29 @@ def test_fsdp_strategy_checkpoint(tmpdir, precision): _run_multiple_stages(trainer, model, os.path.join(tmpdir, "last.ckpt")) +@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") +@pytest.mark.parametrize("wrap_min_params", (2, 1024, 1048576)) +def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): + """Test to ensure that state dict is extracted correctly when using FSDP strategy. + Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. 
+ """ + model = TestFSDPModelAutoWrapped() + correct_state_dict = model.state_dict() # State dict before wrapping + + strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params)) + trainer = Trainer( + default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy=strategy, precision=16, max_epochs=1 + ) + trainer.fit(model) + # CheckpointConnector use this to extract state dict + extracted_state_dict = trainer.strategy.lightning_module_state_dict() + + # State dict should contain same number of keys + assert len(correct_state_dict) == len(extracted_state_dict) + # OrderedDict should return the same keys in the same order + assert all(_ex == _co for _ex, _co in zip(list(extracted_state_dict.keys()), list(correct_state_dict.keys()))) + + @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12") @pytest.mark.parametrize( "model, strategy", From a6d80e28c0be0f46596e0db1bc055f5fcc0e9343 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 31 Jan 2023 11:37:37 +0000 Subject: [PATCH 07/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_pytorch/strategies/test_fsdp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 9fa6db9a3730b..9ae3783d66d2d 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -19,7 +19,7 @@ if _TORCH_GREATER_EQUAL_1_12: from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision - from torch.distributed.fsdp.wrap import wrap, size_based_auto_wrap_policy + from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, wrap class TestFSDPModel(BoringModel): @@ -218,6 +218,7 @@ def test_fsdp_strategy_checkpoint(tmpdir, precision): @pytest.mark.parametrize("wrap_min_params", (2, 1024, 1048576)) def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): """Test to ensure that state dict is extracted correctly when using FSDP strategy. + Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. 
""" model = TestFSDPModelAutoWrapped() From d775d46d1f5c4c8f894ce486a1c4c94ed398751d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Feb 2023 07:52:01 +0000 Subject: [PATCH 08/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .github/CONTRIBUTING.md | 1 + docs/source-pytorch/accelerators/ipu_advanced.rst | 1 - docs/source-pytorch/advanced/post_training_quantization.rst | 1 + docs/source-pytorch/advanced/training_tricks.rst | 1 + docs/source-pytorch/advanced/transfer_learning.rst | 1 - docs/source-pytorch/common/precision_intermediate.rst | 1 + docs/source-pytorch/data/custom_data_iterables.rst | 2 -- docs/source-pytorch/data/datamodule.rst | 2 -- docs/source-pytorch/debug/debugging_advanced.rst | 1 - docs/source-pytorch/ecosystem/asr_nlp_tts.rst | 1 + docs/source-pytorch/fabric/advanced/gradient_accumulation.rst | 1 - docs/source-pytorch/fabric/fundamentals/notebooks.rst | 1 - docs/source-pytorch/fabric/guide/logging.rst | 1 - docs/source-pytorch/fabric/guide/multi_node/cloud.rst | 2 +- docs/source-pytorch/guides/data.rst | 4 ---- docs/source-pytorch/starter/introduction.rst | 1 + src/lightning/app/components/serve/auto_scaler.py | 1 + 17 files changed, 8 insertions(+), 15 deletions(-) diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 04ec953ad5e4e..811014fcd817c 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -243,6 +243,7 @@ Here is the process to create a new test # TEST SHOULD BE IN YOUR FILE: tests/.../test_file.py # TEST CODE TEMPLATE + # [OPTIONAL] pytest decorator # @RunIf(min_cuda_gpus=1) def test_explain_what_is_being_tested(tmpdir): diff --git a/docs/source-pytorch/accelerators/ipu_advanced.rst b/docs/source-pytorch/accelerators/ipu_advanced.rst index 1dc4e71ee4d7a..fb10ef1d60a1f 100644 --- a/docs/source-pytorch/accelerators/ipu_advanced.rst +++ b/docs/source-pytorch/accelerators/ipu_advanced.rst @@ -120,7 +120,6 @@ You can also use the block context manager within the forward function, or any o self.softmax = torch.nn.Softmax(dim=1) def forward(self, x): - with poptorch.Block(ipu_id=0): x = self.act(self.layer1(x)) diff --git a/docs/source-pytorch/advanced/post_training_quantization.rst b/docs/source-pytorch/advanced/post_training_quantization.rst index c474b3a307044..0e0373f92d215 100644 --- a/docs/source-pytorch/advanced/post_training_quantization.rst +++ b/docs/source-pytorch/advanced/post_training_quantization.rst @@ -71,6 +71,7 @@ Load the pretrained model with PyTorch Lightning: from pytorch_lightning import LightningModule from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer + # BERT Model definition class GLUETransformer(LightningModule): def __init__(self): diff --git a/docs/source-pytorch/advanced/training_tricks.rst b/docs/source-pytorch/advanced/training_tricks.rst index 6d39c1550ee20..89b2135752c83 100644 --- a/docs/source-pytorch/advanced/training_tricks.rst +++ b/docs/source-pytorch/advanced/training_tricks.rst @@ -133,6 +133,7 @@ search for batch sizes larger than the size of the training dataset. 
tuner = Tuner(trainer) tuner.scale_batch_size(model) + # using LightningDataModule class LitDataModule(LightningDataModule): def __init__(self, batch_size): diff --git a/docs/source-pytorch/advanced/transfer_learning.rst b/docs/source-pytorch/advanced/transfer_learning.rst index 2d221cf6f7f3e..102555b94a455 100644 --- a/docs/source-pytorch/advanced/transfer_learning.rst +++ b/docs/source-pytorch/advanced/transfer_learning.rst @@ -120,7 +120,6 @@ Here's a model that uses `Huggingface transformers int: From 321821f704a0e7b7f2db43dbff9d7fa38d35d8ba Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Thu, 2 Feb 2023 12:21:11 +0300 Subject: [PATCH 09/30] More accurate handle of auto wrapped model --- tests/tests_pytorch/strategies/test_fsdp.py | 24 ++++++++------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 3ddee74246546..1353f4ce52fd2 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -78,9 +78,10 @@ def _assert_layer_fsdp_instance(self) -> None: class TestFSDPModelAutoWrapped(BoringModel): - def __init__(self): + def __init__(self, wrap_min_params: int): super().__init__() self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + self.wrap_min_params = wrap_min_params def configure_optimizers(self): return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1) @@ -104,8 +105,9 @@ def _assert_layer_fsdp_instance(self) -> None: precision = torch.float16 if self.trainer.precision == "16" else torch.bfloat16 for layer_num in [0, 2]: num_params = sum(p.numel() for p in self.layer[layer_num].parameters()) - if not custom_auto_wrap_policy(self.layer[layer_num], False, num_params): - # This layer is not wrapped + if num_params < self.wrap_min_params: + # this layer is not wrapped + assert not isinstance(self.layer[layer_num], FullyShardedDataParallel) continue assert isinstance(self.layer[layer_num], FullyShardedDataParallel) assert self.layer[layer_num].mixed_precision.param_dtype == precision @@ -147,15 +149,6 @@ def _assert_save_equality(trainer, ckpt_path, cls=TestFSDPModel): assert torch.equal(ddp_param.float().cpu(), shard_param) -def custom_auto_wrap_policy( - module, - recurse, - unwrapped_params: int, - min_num_params: int = int(1e8), -) -> bool: - return unwrapped_params >= 2 - - @RunIf(min_torch="1.12") def test_invalid_on_cpu(tmpdir): """Test to ensure that we raise Misconfiguration for FSDP on CPU.""" @@ -221,7 +214,7 @@ def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. 
""" - model = TestFSDPModelAutoWrapped() + model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params) correct_state_dict = model.state_dict() # State dict before wrapping strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params)) @@ -243,7 +236,7 @@ def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): "model, strategy", [ (TestFSDPModel(), "fsdp"), - (TestFSDPModelAutoWrapped(), FSDPStrategy), + (TestFSDPModelAutoWrapped(wrap_min_params=2), FSDPStrategy), ], ) def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy): @@ -252,7 +245,8 @@ def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy): ck = ModelCheckpoint(save_last=True) if not isinstance(strategy, str): - strategy = strategy(auto_wrap_policy=custom_auto_wrap_policy) + # So every layer is wrapped + strategy = strategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=2)) trainer = Trainer( default_root_dir=tmpdir, From c3574e33d3983a0678fad75f0a558e56ec7c912e Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Thu, 2 Feb 2023 12:36:45 +0300 Subject: [PATCH 10/30] Save hyperparameters for correct checkpoint loading --- tests/tests_pytorch/strategies/test_fsdp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 1353f4ce52fd2..a72a7c9ba793f 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -80,6 +80,7 @@ def _assert_layer_fsdp_instance(self) -> None: class TestFSDPModelAutoWrapped(BoringModel): def __init__(self, wrap_min_params: int): super().__init__() + self.save_hyperparameters() self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) self.wrap_min_params = wrap_min_params From 450a28af8d3b2943f362661189860ac075c6a827 Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Mon, 30 Jan 2023 16:02:39 +0300 Subject: [PATCH 11/30] Add additional wrapping to handle sharded model --- src/lightning/pytorch/strategies/fsdp.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 42a1702d7fbd6..85661f059ef81 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib import logging -from typing import Any, Dict, Generator, List, Optional, Type, Union +from typing import Any, Dict, Generator, List, Optional, Type, Union, Iterator import torch from torch import Tensor @@ -68,6 +68,20 @@ log = logging.getLogger(__name__) +class _FSDPStrategyModuleWrapper(_LightningModuleWrapperBase): + def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: # type: ignore[override] + # this is required because with FSDP lightning_module is empty because weights are sharded. + # So we need to call self.trainer.model.state_dict (wrapped version) and use this wraper to + # avoid extra keys `_forward_module.layer.weight.` since we want `layer.weight.` in state_dict. + return self._forward_module.state_dict(*args, **kwargs) + + def named_modules(self, *args: Any, **kwargs: Any) -> Iterator[str, Module]: + # This is required because FSDP explicitly checks that each flatted parameter in state_dict. + # Since we are wrapping the model, all flatted parameters will have `_forward_module.` prefix. + # This redirect avoids adding this prefix. 
+ return self._forward_module.named_modules() + + class FSDPStrategy(ParallelStrategy): r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed. @@ -139,6 +153,11 @@ def __init__( # `self.trainer.model.parameters()` and enables support for multiple parameter groups. self.kwargs.setdefault("use_orig_params", True) + def lightning_module_state_dict(self) -> Dict[str, Any]: + """Returns model state.""" + assert self.model is not None + return self.model.state_dict() + @property def root_device(self) -> torch.device: assert self.parallel_devices is not None @@ -241,7 +260,7 @@ def setup(self, trainer: "pl.Trainer") -> None: self.lightning_module._device = self.root_device assert isinstance(self.model, pl.LightningModule) - self.model = _LightningModuleWrapperBase(self.model) + self.model = _FSDPStrategyModuleWrapper(self.model) if is_overridden("configure_sharded_model", self.lightning_module): rank_zero_info( "You have overridden `LightningModule.configure_sharded_model` hook. It will assume that all the layers" From 39e898ea5229c45baf7a8f9b7e2ef4f71dc0a84d Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Mon, 30 Jan 2023 16:25:54 +0300 Subject: [PATCH 12/30] Use tuple to define yield type in Iterator --- src/lightning/pytorch/strategies/fsdp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 85661f059ef81..26809562485af 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib import logging -from typing import Any, Dict, Generator, List, Optional, Type, Union, Iterator +from typing import Any, Dict, Generator, List, Optional, Type, Union, Iterator, Tuple import torch from torch import Tensor @@ -75,7 +75,7 @@ def state_dict(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: # type: igno # avoid extra keys `_forward_module.layer.weight.` since we want `layer.weight.` in state_dict. return self._forward_module.state_dict(*args, **kwargs) - def named_modules(self, *args: Any, **kwargs: Any) -> Iterator[str, Module]: + def named_modules(self, *args: Any, **kwargs: Any) -> Iterator[Tuple[str, Module]]: # This is required because FSDP explicitly checks that each flatted parameter in state_dict. # Since we are wrapping the model, all flatted parameters will have `_forward_module.` prefix. # This redirect avoids adding this prefix. From 8527827105ce0fe71fbd2bdfb47eec1def0efd28 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 30 Jan 2023 13:12:53 +0000 Subject: [PATCH 13/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/strategies/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 26809562485af..42ace8f8527ab 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -1,4 +1,4 @@ -# Copyright The Lightning AI team. +# Copyright The Lightning team. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
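
The `wrap_min_params` values used by the tests in this series (2, 1024, 1048576) map directly onto the per-layer parameter counts of `TestFSDPModelAutoWrapped`: with `size_based_auto_wrap_policy`, a submodule becomes its own FSDP unit only when its not-yet-wrapped parameter count reaches `min_num_params`. The following standalone sketch is not part of the test suite and simplifies the policy to a per-layer comparison, but it shows why the three values correspond to "everything wrapped", "only the larger linear wrapped", and "nothing wrapped".

import torch

# The test model is Sequential(Linear(32, 32), ReLU(), Linear(32, 2)):
# 32 * 32 + 32 = 1056 parameters and 32 * 2 + 2 = 66 parameters respectively.
layers = {"layer.0": torch.nn.Linear(32, 32), "layer.2": torch.nn.Linear(32, 2)}

for min_num_params in (2, 1024, 1048576):
    decisions = {
        name: sum(p.numel() for p in module.parameters()) >= min_num_params
        for name, module in layers.items()
    }
    print(min_num_params, decisions)

# 2       -> {'layer.0': True,  'layer.2': True}   both layers wrapped
# 1024    -> {'layer.0': True,  'layer.2': False}  only the 1056-parameter layer wrapped
# 1048576 -> {'layer.0': False, 'layer.2': False}  the model is left unwrapped
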
From d2aeefb82d35cb9980ba4cca733d01757750fdcc Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Tue, 31 Jan 2023 14:36:15 +0300 Subject: [PATCH 14/30] Add tests for checking state_dict extraction --- tests/tests_pytorch/strategies/test_fsdp.py | 29 +++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 8c004e93d587b..02c5ce15dc3be 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -2,6 +2,8 @@ from contextlib import nullcontext from functools import partial from typing import Any, Callable, Dict, Optional +from functools import partial +from typing import Any, Dict, Optional from unittest import mock from unittest.mock import ANY, Mock @@ -112,6 +114,10 @@ def _assert_layer_fsdp_instance(self) -> None: precision = torch.float16 if self.trainer.precision == "16-mixed" else torch.bfloat16 for layer_num in [0, 2]: + num_params = sum(p.numel() for p in self.layer[layer_num].parameters()) + if not custom_auto_wrap_policy(self.layer[layer_num], False, num_params): + # This layer is not wrapped + continue assert isinstance(self.layer[layer_num], FullyShardedDataParallel) assert self.layer[layer_num].mixed_precision.param_dtype == precision assert self.layer[layer_num].mixed_precision.reduce_dtype == precision @@ -244,6 +250,29 @@ def custom_auto_wrap_policy( return unwrapped_params >= 2 +@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") +@pytest.mark.parametrize("wrap_min_params", (2, 1024, 1048576)) +def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): + """Test to ensure that state dict is extracted correctly when using FSDP strategy. + Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. 
+ """ + model = TestFSDPModelAutoWrapped() + correct_state_dict = model.state_dict() # State dict before wrapping + + strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params)) + trainer = Trainer( + default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy=strategy, precision=16, max_epochs=1 + ) + trainer.fit(model) + # CheckpointConnector use this to extract state dict + extracted_state_dict = trainer.strategy.lightning_module_state_dict() + + # State dict should contain same number of keys + assert len(correct_state_dict) == len(extracted_state_dict) + # OrderedDict should return the same keys in the same order + assert all(_ex == _co for _ex, _co in zip(list(extracted_state_dict.keys()), list(correct_state_dict.keys()))) + + @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12") @pytest.mark.parametrize( "model, strategy, strategy_cfg", From 158e0deaa261f41861da79fa34d307f4e8a79e45 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 31 Jan 2023 11:37:37 +0000 Subject: [PATCH 15/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_pytorch/strategies/test_fsdp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 02c5ce15dc3be..31dbaf904baa9 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -254,6 +254,7 @@ def custom_auto_wrap_policy( @pytest.mark.parametrize("wrap_min_params", (2, 1024, 1048576)) def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): """Test to ensure that state dict is extracted correctly when using FSDP strategy. + Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. 
""" model = TestFSDPModelAutoWrapped() From 8368140f96fb8231748e3831670140d329d931c8 Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Thu, 2 Feb 2023 12:21:11 +0300 Subject: [PATCH 16/30] More accurate handle of auto wrapped model --- tests/tests_pytorch/strategies/test_fsdp.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 31dbaf904baa9..245fc9c0da3f9 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -88,9 +88,10 @@ def _assert_layer_fsdp_instance(self) -> None: class TestFSDPModelAutoWrapped(BoringModel): - def __init__(self): + def __init__(self, wrap_min_params: int): super().__init__() self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) + self.wrap_min_params = wrap_min_params def configure_optimizers(self): parameters = self.parameters() if _TORCH_GREATER_EQUAL_2_0 else self.trainer.model.parameters() @@ -115,8 +116,9 @@ def _assert_layer_fsdp_instance(self) -> None: precision = torch.float16 if self.trainer.precision == "16-mixed" else torch.bfloat16 for layer_num in [0, 2]: num_params = sum(p.numel() for p in self.layer[layer_num].parameters()) - if not custom_auto_wrap_policy(self.layer[layer_num], False, num_params): - # This layer is not wrapped + if num_params < self.wrap_min_params: + # this layer is not wrapped + assert not isinstance(self.layer[layer_num], FullyShardedDataParallel) continue assert isinstance(self.layer[layer_num], FullyShardedDataParallel) assert self.layer[layer_num].mixed_precision.param_dtype == precision @@ -257,7 +259,7 @@ def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. 
""" - model = TestFSDPModelAutoWrapped() + model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params) correct_state_dict = model.state_dict() # State dict before wrapping strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params)) @@ -300,6 +302,8 @@ def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): marks=RunIf(min_torch="2.0.0"), id="autowrap_use_orig_params", ), + (TestFSDPModel(), "fsdp"), + (TestFSDPModelAutoWrapped(wrap_min_params=2), FSDPStrategy), ], ) def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg): @@ -309,6 +313,8 @@ def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg): strategy_cfg = strategy_cfg or {} if not isinstance(strategy, str): + # So every layer is wrapped + strategy = strategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=2)) strategy = strategy(**strategy_cfg) trainer = Trainer( From 4ab6ecd91917c997e15212cb001c8c424190eca2 Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Thu, 2 Feb 2023 12:36:45 +0300 Subject: [PATCH 17/30] Save hyperparameters for correct checkpoint loading --- tests/tests_pytorch/strategies/test_fsdp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 245fc9c0da3f9..d42cbe704fc5f 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -90,6 +90,7 @@ def _assert_layer_fsdp_instance(self) -> None: class TestFSDPModelAutoWrapped(BoringModel): def __init__(self, wrap_min_params: int): super().__init__() + self.save_hyperparameters() self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) self.wrap_min_params = wrap_min_params From 54d974095a34bef45bdb20c74919a6e9874b09f3 Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Thu, 13 Apr 2023 21:18:25 +0300 Subject: [PATCH 18/30] Wrap model state dict retrieve with FSDP context --- src/lightning/pytorch/strategies/fsdp.py | 49 +++++++++++++++--------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index b845d1a4d7b19..1bc86f0e20e02 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -55,7 +55,13 @@ _distributed_available = torch.distributed.is_available() _fsdp_available = _TORCH_GREATER_EQUAL_1_12 and _distributed_available if _fsdp_available: - from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel, MixedPrecision + from torch.distributed.fsdp import ( + CPUOffload, + FullStateDictConfig, + FullyShardedDataParallel, + MixedPrecision, + StateDictType, + ) from torch.distributed.fsdp.wrap import enable_wrap else: FullyShardedDataParallel = None # type: ignore[misc,assignment] @@ -68,42 +74,37 @@ log = logging.getLogger(__name__) +def _clean_up_state_dict(state_dict: Dict[str, Any], prefix: str = "_forward_module.") -> Dict[str, Any]: + prefix_len = len(prefix) + clean_state_dict = {k[prefix_len:]: v for k, v in state_dict.items()} + return clean_state_dict + + class FSDPStrategy(ParallelStrategy): r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed. - .. warning:: ``FSDPStrategy`` is in BETA and subject to change. The interface can - bring breaking changes and new features with the next release of PyTorch. + .. warning:: This is an :ref:`experimental ` feature. 
Fully Sharded Training shards the entire model across all available GPUs, allowing you to scale model size, whilst using efficient communication to reduce overhead. In practice, this means we can remain at parity with PyTorch DDP, whilst scaling our model sizes dramatically. The technique is similar to ZeRO-Stage 3. - For more information `check out `__. + For more information check out + `this blogpost `__. Defaults have been set and options have been exposed, but may require configuration based on your level of memory/speed efficiency. We suggest having a look at `this tutorial `__ for more information. Arguments: - cpu_offload: Enable offloading parameters and gradients to CPU to save GPU memory at the cost of speed. - You can also pass a config: ``cpu_offload=CPUOffload(offload_params=True)``. Note that this currently - implicitly enables gradient offloading to CPU in order for parameters and gradients to be on same device - to work with the optimizer. This API is subject to change. Default: no offloading - backward_prefetch: - This is an experimental feature that is subject to change in the - the near future. It allows users to enable two different backward_prefetch - algorithms to help backward communication and computation overlapping. - The pros and cons of each algorithm is explained in the class ``BackwardPrefetch``. - mixed_precision: - Mixed Precision config. By default, Lightning will enable FP16 if ``precision="16-mixed"`` - or BF16 if ``precision="bf16-mixed"`` unless a config is passed in. - This is only available in PyTorch 1.12 and later. + cpu_offload: See ``cpu_offload`` parameter in :class:`torch.distributed.fsdp.FullyShardedDataParallel`. + mixed_precision: See ``mixed_precision`` parameter in :class:`torch.distributed.fsdp.FullyShardedDataParallel`. activation_checkpointing: A single layer or a list of layer classes for which you want to enable activation checkpointing. This is typically your transformer block (including attention + feed-forward). Enabling this can free up a significant amount of memory at the cost of speed since activations in these layers need to be recomputed during backpropagation. - \**kwargs: Passed to the FSDP context manager which will configure the FSDP class when wrapping modules. + \**kwargs: See available parameters in :class:`torch.distributed.fsdp.FullyShardedDataParallel`. """ @@ -150,6 +151,18 @@ def __init__( # `self.trainer.model.parameters()` and enables support for multiple parameter groups. 
self.kwargs.setdefault("use_orig_params", True) + def lightning_module_state_dict(self) -> Dict[str, Any]: + """Returns model state.""" + assert self.model is not None + + with FullyShardedDataParallel.state_dict_type( + module=self.model, + state_dict_type=StateDictType.FULL_STATE_DICT, + state_dict_config=FullStateDictConfig(offload_to_cpu=self.cpu_offload.offload_params, rank0_only=True), + ): + state_dict = self.model.state_dict() + return _clean_up_state_dict(state_dict) + @property def root_device(self) -> torch.device: assert self.parallel_devices is not None From 8393d2de82badb550c2185198102832fb556fa67 Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Thu, 13 Apr 2023 21:19:40 +0300 Subject: [PATCH 19/30] Use miltiple GPUs in tests, correct precision name --- tests/tests_pytorch/strategies/test_fsdp.py | 42 ++++----------------- 1 file changed, 7 insertions(+), 35 deletions(-) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 4bb650bf0690f..14bd05c7219f5 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -1,9 +1,6 @@ import os -from contextlib import nullcontext from functools import partial from typing import Any, Callable, Dict, Optional -from functools import partial -from typing import Any, Dict, Optional from unittest import mock from unittest.mock import ANY, Mock @@ -233,31 +230,6 @@ def policy(self): custom_fsdp_policy = CustomWrapPolicy(min_num_params=2) - -@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") -@pytest.mark.parametrize("wrap_min_params", (2, 1024, 1048576)) -def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): - """Test to ensure that state dict is extracted correctly when using FSDP strategy. - - Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. - """ - model = TestFSDPModelAutoWrapped(wrap_min_params=wrap_min_params) - correct_state_dict = model.state_dict() # State dict before wrapping - - strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params)) - trainer = Trainer( - default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy=strategy, precision=16, max_epochs=1 - ) - trainer.fit(model) - # CheckpointConnector use this to extract state dict - extracted_state_dict = trainer.strategy.lightning_module_state_dict() - - # State dict should contain same number of keys - assert len(correct_state_dict) == len(extracted_state_dict) - # OrderedDict should return the same keys in the same order - assert all(_ex == _co for _ex, _co in zip(list(extracted_state_dict.keys()), list(correct_state_dict.keys()))) - - if _TORCH_GREATER_EQUAL_2_0: def custom_auto_wrap_policy( @@ -277,7 +249,7 @@ def custom_auto_wrap_policy( return unwrapped_params >= 2 -@RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") +@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12") @pytest.mark.parametrize("wrap_min_params", (2, 1024, 1048576)) def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): """Test to ensure that state dict is extracted correctly when using FSDP strategy. 
@@ -289,7 +261,7 @@ def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): strategy = FSDPStrategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=wrap_min_params)) trainer = Trainer( - default_root_dir=tmpdir, accelerator="gpu", devices=1, strategy=strategy, precision=16, max_epochs=1 + default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy=strategy, precision="16-mixed", max_epochs=1 ) trainer.fit(model) # CheckpointConnector use this to extract state dict @@ -307,28 +279,28 @@ def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): [ pytest.param(TestFSDPModel(), "fsdp", None, id="manually_wrapped"), pytest.param( - TestFSDPModelAutoWrapped(), + TestFSDPModelAutoWrapped(wrap_min_params=2), FSDPStrategy, {"auto_wrap_policy": custom_auto_wrap_policy}, marks=RunIf(max_torch="2.0.0"), id="autowrap_1x", ), pytest.param( - TestFSDPModelAutoWrapped(), + TestFSDPModelAutoWrapped(wrap_min_params=2), FSDPStrategy, {"auto_wrap_policy": custom_auto_wrap_policy}, marks=RunIf(min_torch="2.0.0"), id="autowrap_2x", ), pytest.param( - TestFSDPModelAutoWrapped(), + TestFSDPModelAutoWrapped(wrap_min_params=2), FSDPStrategy, {"auto_wrap_policy": custom_fsdp_policy, "use_orig_params": True}, marks=RunIf(min_torch="2.0.0"), id="autowrap_use_orig_params", ), - (TestFSDPModel(), "fsdp"), - (TestFSDPModelAutoWrapped(wrap_min_params=2), FSDPStrategy), + (TestFSDPModel(), "fsdp", None), + (TestFSDPModelAutoWrapped(wrap_min_params=2), FSDPStrategy, None), ], ) def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg): From f763c09fd30792dfd37d961b700690d29b0fc745 Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Thu, 13 Apr 2023 21:27:19 +0300 Subject: [PATCH 20/30] Revert unnecessary changes in tests --- tests/tests_pytorch/strategies/test_fsdp.py | 22 +++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 14bd05c7219f5..d7d6695afae67 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -1,4 +1,5 @@ import os +from contextlib import nullcontext from functools import partial from typing import Any, Callable, Dict, Optional from unittest import mock @@ -299,8 +300,6 @@ def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): marks=RunIf(min_torch="2.0.0"), id="autowrap_use_orig_params", ), - (TestFSDPModel(), "fsdp", None), - (TestFSDPModelAutoWrapped(wrap_min_params=2), FSDPStrategy, None), ], ) def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg): @@ -310,8 +309,6 @@ def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg): strategy_cfg = strategy_cfg or {} if not isinstance(strategy, str): - # So every layer is wrapped - strategy = strategy(auto_wrap_policy=partial(size_based_auto_wrap_policy, min_num_params=2)) strategy = strategy(**strategy_cfg) trainer = Trainer( @@ -326,21 +323,30 @@ def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg): limit_test_batches=2, limit_predict_batches=2, callbacks=[ck], - inference_mode=not _TORCH_GREATER_EQUAL_2_0, # TODO(carmocca): inference_mode raises RuntimeError ) _run_multiple_stages(trainer, model) @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12") def test_invalid_parameters_in_optimizer(): - trainer = Trainer(strategy="fsdp", accelerator="cuda", devices=1) + trainer = Trainer( + strategy="fsdp", + accelerator="cuda", + devices=1, + fast_dev_run=1, 
+ ) + error_context = ( + nullcontext() + if _TORCH_GREATER_EQUAL_2_0 + else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters") + ) class EmptyParametersModel(BoringModel): def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=1e-2) model = EmptyParametersModel() - with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"): + with error_context: trainer.fit(model) class NoFlatParametersModel(BoringModel): @@ -349,7 +355,7 @@ def configure_optimizers(self): return torch.optim.Adam(layer.parameters(), lr=1e-2) model = NoFlatParametersModel() - with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"): + with error_context: trainer.fit(model) From ed82d0020dcde978f45e9e31b51b9ca815e37c29 Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Fri, 14 Apr 2023 14:13:54 +0300 Subject: [PATCH 21/30] Always offload checkpoint to CPU --- src/lightning/pytorch/strategies/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index f0729aaca935b..75e7044839ec2 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -158,7 +158,7 @@ def lightning_module_state_dict(self) -> Dict[str, Any]: with FullyShardedDataParallel.state_dict_type( module=self.model, state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=FullStateDictConfig(offload_to_cpu=self.cpu_offload.offload_params, rank0_only=True), + state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), ): state_dict = self.model.state_dict() return _clean_up_state_dict(state_dict) From bbd684b5fcb6d9927f521639531752aa3ad565b5 Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Fri, 14 Apr 2023 14:14:15 +0300 Subject: [PATCH 22/30] Validate state_dict only on zero rank --- tests/tests_pytorch/strategies/test_fsdp.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index d7d6695afae67..aa1afed7e7a77 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -90,7 +90,7 @@ def __init__(self, wrap_min_params: int): super().__init__() self.save_hyperparameters() self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) - self.wrap_min_params = wrap_min_params + self.should_be_wrapped = [(32 * 32 + 32) > wrap_min_params, None, (32 * 2 + 2) > wrap_min_params] def configure_optimizers(self): parameters = self.parameters() if _TORCH_GREATER_EQUAL_2_0 else self.trainer.model.parameters() @@ -114,8 +114,7 @@ def _assert_layer_fsdp_instance(self) -> None: precision = torch.float16 if self.trainer.precision == "16-mixed" else torch.bfloat16 for layer_num in [0, 2]: - num_params = sum(p.numel() for p in self.layer[layer_num].parameters()) - if num_params < self.wrap_min_params: + if not self.should_be_wrapped[layer_num]: # this layer is not wrapped assert not isinstance(self.layer[layer_num], FullyShardedDataParallel) continue @@ -268,6 +267,10 @@ def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): # CheckpointConnector use this to extract state dict extracted_state_dict = trainer.strategy.lightning_module_state_dict() + if trainer.global_rank != 0: + assert len(extracted_state_dict) == 0 + return + # State dict should contain same 
number of keys assert len(correct_state_dict) == len(extracted_state_dict) # OrderedDict should return the same keys in the same order From 8e4804ef5c50cf486e96dca1373366b0dcf32f45 Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Fri, 14 Apr 2023 17:18:15 +0300 Subject: [PATCH 23/30] Offload to CPU only if trainer uses offload --- src/lightning/pytorch/strategies/fsdp.py | 2 +- tests/tests_pytorch/strategies/test_fsdp.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 75e7044839ec2..f0729aaca935b 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -158,7 +158,7 @@ def lightning_module_state_dict(self) -> Dict[str, Any]: with FullyShardedDataParallel.state_dict_type( module=self.model, state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + state_dict_config=FullStateDictConfig(offload_to_cpu=self.cpu_offload.offload_params, rank0_only=True), ): state_dict = self.model.state_dict() return _clean_up_state_dict(state_dict) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index aa1afed7e7a77..b62f449985d62 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -86,7 +86,7 @@ def _assert_layer_fsdp_instance(self) -> None: class TestFSDPModelAutoWrapped(BoringModel): - def __init__(self, wrap_min_params: int): + def __init__(self, wrap_min_params: int = 2): super().__init__() self.save_hyperparameters() self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2)) @@ -283,21 +283,21 @@ def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): [ pytest.param(TestFSDPModel(), "fsdp", None, id="manually_wrapped"), pytest.param( - TestFSDPModelAutoWrapped(wrap_min_params=2), + TestFSDPModelAutoWrapped(), FSDPStrategy, {"auto_wrap_policy": custom_auto_wrap_policy}, marks=RunIf(max_torch="2.0.0"), id="autowrap_1x", ), pytest.param( - TestFSDPModelAutoWrapped(wrap_min_params=2), + TestFSDPModelAutoWrapped(), FSDPStrategy, {"auto_wrap_policy": custom_auto_wrap_policy}, marks=RunIf(min_torch="2.0.0"), id="autowrap_2x", ), pytest.param( - TestFSDPModelAutoWrapped(wrap_min_params=2), + TestFSDPModelAutoWrapped(), FSDPStrategy, {"auto_wrap_policy": custom_fsdp_policy, "use_orig_params": True}, marks=RunIf(min_torch="2.0.0"), From 0c1568a3ef579c909b38083125f5226b94666659 Mon Sep 17 00:00:00 2001 From: Egor Spirin Date: Mon, 17 Apr 2023 17:03:43 +0300 Subject: [PATCH 24/30] Always offload checkpoint to CPU --- src/lightning/pytorch/strategies/fsdp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index f0729aaca935b..75e7044839ec2 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -158,7 +158,7 @@ def lightning_module_state_dict(self) -> Dict[str, Any]: with FullyShardedDataParallel.state_dict_type( module=self.model, state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=FullStateDictConfig(offload_to_cpu=self.cpu_offload.offload_params, rank0_only=True), + state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), ): state_dict = self.model.state_dict() return _clean_up_state_dict(state_dict) From 7a1b40f5772f3716915dcd2e3d4be9b8b4e21f93 Mon 
Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 17 Apr 2023 17:08:23 -0400 Subject: [PATCH 25/30] update tests --- src/lightning/pytorch/strategies/fsdp.py | 2 +- tests/tests_pytorch/strategies/test_fsdp.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 75e7044839ec2..b3f3fe2f574ed 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -158,7 +158,7 @@ def lightning_module_state_dict(self) -> Dict[str, Any]: with FullyShardedDataParallel.state_dict_type( module=self.model, state_dict_type=StateDictType.FULL_STATE_DICT, - state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + state_dict_config=FullStateDictConfig(offload_to_cpu=(self.world_size > 1), rank0_only=True), ): state_dict = self.model.state_dict() return _clean_up_state_dict(state_dict) diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 99213dbef0a91..58e4c1a4123f7 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -250,9 +250,9 @@ def custom_auto_wrap_policy( @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12") -@pytest.mark.parametrize("wrap_min_params", (2, 1024, 1048576)) -def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): - """Test to ensure that state dict is extracted correctly when using FSDP strategy. +@pytest.mark.parametrize("wrap_min_params", (2, 1024, 100000000)) +def test_fsdp_strategy_full_state_dict(tmpdir, wrap_min_params): + """Test to ensure that the full state dict is extracted when using FSDP strategy. Based on `wrap_min_params`, the model will be fully wrapped, half wrapped, and not wrapped at all. 
""" @@ -264,17 +264,17 @@ def test_fsdp_strategy_state_dict(tmpdir, wrap_min_params): default_root_dir=tmpdir, accelerator="gpu", devices=2, strategy=strategy, precision="16-mixed", max_epochs=1 ) trainer.fit(model) - # CheckpointConnector use this to extract state dict - extracted_state_dict = trainer.strategy.lightning_module_state_dict() + + full_state_dict = trainer.strategy.lightning_module_state_dict() if trainer.global_rank != 0: - assert len(extracted_state_dict) == 0 + assert len(full_state_dict) == 0 return # State dict should contain same number of keys - assert len(correct_state_dict) == len(extracted_state_dict) + assert len(correct_state_dict) == len(full_state_dict) # OrderedDict should return the same keys in the same order - assert all(_ex == _co for _ex, _co in zip(list(extracted_state_dict.keys()), list(correct_state_dict.keys()))) + assert all(_ex == _co for _ex, _co in zip(full_state_dict.keys(), correct_state_dict.keys())) @RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True, min_torch="1.12") From 6b107bcbb4d5fc41096b6153634eed2c13a92be9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 17 Apr 2023 17:15:12 -0400 Subject: [PATCH 26/30] move function to bottom --- src/lightning/pytorch/strategies/fsdp.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index b3f3fe2f574ed..98e657e956b7f 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -74,12 +74,6 @@ log = logging.getLogger(__name__) -def _clean_up_state_dict(state_dict: Dict[str, Any], prefix: str = "_forward_module.") -> Dict[str, Any]: - prefix_len = len(prefix) - clean_state_dict = {k[prefix_len:]: v for k, v in state_dict.items()} - return clean_state_dict - - class FSDPStrategy(ParallelStrategy): r"""Strategy for Fully Sharded Data Parallel provided by torch.distributed. 
@@ -161,7 +155,7 @@ def lightning_module_state_dict(self) -> Dict[str, Any]: state_dict_config=FullStateDictConfig(offload_to_cpu=(self.world_size > 1), rank0_only=True), ): state_dict = self.model.state_dict() - return _clean_up_state_dict(state_dict) + return _strip_prefix_from_state_dict(state_dict, prefix="_forward_module.") @property def root_device(self) -> torch.device: @@ -414,3 +408,9 @@ def register_strategies(cls, strategy_registry: Dict) -> None: cpu_offload=True, ) cls._registered_strategies.append("fsdp_cpu_offload") + + + +def _strip_prefix_from_state_dict(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]: + prefix_len = len(prefix) + return {k[prefix_len:]: v for k, v in state_dict.items()} From 638a374b2ac506431b4314300801d9be48e477d4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Apr 2023 21:16:34 +0000 Subject: [PATCH 27/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/strategies/fsdp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 98e657e956b7f..02f05be36b467 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -410,7 +410,6 @@ def register_strategies(cls, strategy_registry: Dict) -> None: cls._registered_strategies.append("fsdp_cpu_offload") - def _strip_prefix_from_state_dict(state_dict: Dict[str, Any], prefix: str) -> Dict[str, Any]: prefix_len = len(prefix) return {k[prefix_len:]: v for k, v in state_dict.items()} From 62441e9e8dc654f25a78fb316abbee5f5cf28e39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 17 Apr 2023 17:19:43 -0400 Subject: [PATCH 28/30] update documentation --- src/lightning/pytorch/strategies/fsdp.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 98e657e956b7f..88a55f0253020 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -146,7 +146,8 @@ def __init__( self.kwargs.setdefault("use_orig_params", True) def lightning_module_state_dict(self) -> Dict[str, Any]: - """Returns model state.""" + """Gathers the full state dict by unsharding all the parameters. To avoid OOM, the returned + parameters will only be returned on rank 0 and on CPU. All other ranks get an empty dict.""" assert self.model is not None with FullyShardedDataParallel.state_dict_type( From 99c0f39d3c5e89b21716daf300f80df14dce1cb0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Apr 2023 21:20:55 +0000 Subject: [PATCH 29/30] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/strategies/fsdp.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 560c6c63d12a3..0de1f5852d3af 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -146,8 +146,11 @@ def __init__( self.kwargs.setdefault("use_orig_params", True) def lightning_module_state_dict(self) -> Dict[str, Any]: - """Gathers the full state dict by unsharding all the parameters. 
To avoid OOM, the returned - parameters will only be returned on rank 0 and on CPU. All other ranks get an empty dict.""" + """Gathers the full state dict by unsharding all the parameters. + + To avoid OOM, the returned parameters will only be returned on rank 0 and on CPU. All other ranks get an empty + dict. + """ assert self.model is not None with FullyShardedDataParallel.state_dict_type( From 9d1675864ed2d3432c36befdeb8145c88b7b80b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Mon, 17 Apr 2023 17:22:14 -0400 Subject: [PATCH 30/30] add changelog --- src/lightning/pytorch/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 46501a1b459cc..c65ef27bb93e1 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added support for multiple optimizer parameter groups when using the FSDP strategy ([#17309](https://github.com/Lightning-AI/lightning/pull/17309)) +- Enabled saving the full model state dict when using the `FSDPStrategy` ([#16558](https://github.com/Lightning-AI/lightning/pull/16558)) + ### Changed