Commit 5cb313e

Merge branch 'master' into tpu-iterable-datasets
2 parents d852d22 + 0c02c44 commit 5cb313e

File tree

12 files changed (+175, -60 lines)

.github/workflows/_legacy-checkpoints.yml

Lines changed: 1 addition & 1 deletion
@@ -139,7 +139,7 @@ jobs:
       run: echo ${{ needs.create-legacy-ckpts.outputs.pl-version }} >> ${{ env.legacy_dir }}/back-compatible-versions.txt

   - name: Create Pull Request
-    uses: peter-evans/create-pull-request@v4
+    uses: peter-evans/create-pull-request@v5
     with:
       title: Adding test for legacy checkpoint created with ${{ needs.create-legacy-ckpts.outputs.pl-version }}
       delete-branch: true

docs/source-pytorch/advanced/model_parallel.rst

Lines changed: 2 additions & 2 deletions
@@ -88,9 +88,9 @@ simplest way to do it is auto wrapping, which can serve as a drop-in replacement
 have to ``wrap`` layers manually as in the case of manual wrapping.

 .. note::
-    While initializing the optimizers inside ``configure_optimizers`` hook, make sure to use ``self.trainer.model.parameters()``, else
+    For users of PyTorch < 2.0: While initializing the optimizers inside ``configure_optimizers`` hook, make sure to use ``self.trainer.model.parameters()``, else
     PyTorch will raise an error. This is required because when you use auto-wrap, the model layers are sharded and your
-    ``lightning_module.parameters()`` will return a generator with no params. This inconvenience will be addressed in the future.
+    ``lightning_module.parameters()`` will return a generator with no params.


 .. code-block:: python
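
Illustration (not part of the diff): a minimal, hedged sketch of the pre-2.0 pattern the note describes, where `configure_optimizers` must reference the FSDP-wrapped parameters through `self.trainer.model.parameters()`. The class and layer names below are hypothetical.

import torch
from lightning.pytorch import LightningModule


class AutoWrappedModel(LightningModule):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def configure_optimizers(self):
        # PyTorch < 2.0 with FSDP auto-wrap: the flattened parameters live on the
        # wrapped model, so reference them via the Trainer's model.
        return torch.optim.Adam(self.trainer.model.parameters(), lr=1e-3)

On PyTorch >= 2.0, where this commit defaults to `use_orig_params=True`, `self.parameters()` can be used directly instead.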

docs/source-pytorch/tutorials.rst

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
+:orphan:
+
+PyTorch Lightning Tutorials
+===========================
+
+.. tutoriallist::

src/lightning/fabric/CHANGELOG.md

Lines changed: 2 additions & 1 deletion
@@ -9,7 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

--
+- Added support for joint setup of model and optimizer with FSDP ([#17305](https://github.com/Lightning-AI/lightning/pull/17305))
+- Added support for handling multiple parameter groups in optimizers set up with FSDP ([#17305](https://github.com/Lightning-AI/lightning/pull/17305))


 ### Changed

src/lightning/fabric/fabric.py

Lines changed: 2 additions & 1 deletion
@@ -28,6 +28,7 @@
 from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler

 from lightning.fabric.loggers import Logger
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_0

 from lightning.fabric.plugins import Precision  # avoid circular imports: # isort: split
 from lightning.fabric.accelerators.accelerator import Accelerator

@@ -798,7 +799,7 @@ def _validate_setup(self, module: nn.Module, optimizers: Sequence[Optimizer]) ->
         if any(isinstance(opt, _FabricOptimizer) for opt in optimizers):
             raise ValueError("An optimizer should be passed only once to the `setup` method.")

-        if isinstance(self._strategy, FSDPStrategy):
+        if isinstance(self._strategy, FSDPStrategy) and not _TORCH_GREATER_EQUAL_2_0:
             raise RuntimeError(
                 f"The `{type(self).__name__}` requires the model and optimizer(s) to be set up separately."
                 " Create and set up the model first through `model = self.setup_model(model)`. Then create the"

src/lightning/fabric/strategies/fsdp.py

Lines changed: 38 additions & 11 deletions
@@ -36,7 +36,11 @@
 )
 from lightning.fabric.utilities.distributed import group as _group
 from lightning.fabric.utilities.distributed import ReduceOp
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
+from lightning.fabric.utilities.imports import (
+    _TORCH_GREATER_EQUAL_1_12,
+    _TORCH_GREATER_EQUAL_1_13,
+    _TORCH_GREATER_EQUAL_2_0,
+)
 from lightning.fabric.utilities.rank_zero import rank_zero_only, rank_zero_warn
 from lightning.fabric.utilities.seed import reset_seed

@@ -101,7 +105,11 @@ def __init__(
         self._process_group_backend: Optional[str] = process_group_backend
         self._timeout: Optional[timedelta] = timeout
         self._backward_sync_control = _FSDPBackwardSyncControl()
-        self._ddp_kwargs = kwargs
+        self._fsdp_kwargs = kwargs
+
+        if _TORCH_GREATER_EQUAL_2_0:
+            # Enables joint setup of model and optimizer, multiple optimizer param groups, and `torch.compile()`
+            self._fsdp_kwargs.setdefault("use_orig_params", True)

         if activation_checkpointing and not _TORCH_GREATER_EQUAL_1_13:
             raise ValueError("Activation checkpointing requires torch >= 1.13.0. HINT: `pip install -U torch`")

@@ -157,28 +165,44 @@ def setup_environment(self) -> None:
     def setup_module_and_optimizers(
         self, module: Module, optimizers: List[Optimizer]
     ) -> Tuple[Module, List[Optimizer]]:
-        raise NotImplementedError(
-            f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)."
-            " Please do it in this order: Create the model, call `setup_module`, create the optimizer,"
-            " call `setup_optimizer`."
-        )
+        """Wraps the model into a
+        :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module
+        and sets `use_orig_params=True` to keep the reference to the original parameters in the
+        optimizer.
+        """
+        if not _TORCH_GREATER_EQUAL_2_0:
+            raise NotImplementedError(
+                f"The `{type(self).__name__}` does not support the joint setup of module and optimizer(s)."
+                " Please do it in this order: Create the model, call `setup_module`, create the optimizer,"
+                " call `setup_optimizer`."
+            )
+        use_orig_params = self._fsdp_kwargs.get("use_orig_params")
+        if use_orig_params is False:
+            raise ValueError(
+                f"You set `{type(self).__name__}(use_orig_params=False)` but this is not supported when"
+                " setting the model and optimizer up jointly. Either set it to `True` or set the objects"
+                " up in this order: Create the model, call `setup_module`, create the optimizer,"
+                " call `setup_optimizer`."
+            )
+        module = self.setup_module(module)
+        return module, optimizers

     def setup_module(self, module: Module) -> "FullyShardedDataParallel":
         """Wraps the model into a
         :class:`~torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel` module."""
         from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

-        if "auto_wrap_policy" in self._ddp_kwargs and any(
+        if "auto_wrap_policy" in self._fsdp_kwargs and any(
             isinstance(mod, FullyShardedDataParallel) for mod in module.modules()
         ):
             # If model is already wrapped, we need to avoid sending the `auto_wrap_policy`
-            del self._ddp_kwargs["auto_wrap_policy"]
+            del self._fsdp_kwargs["auto_wrap_policy"]
         wrapped_module = FullyShardedDataParallel(
             module=module,
             cpu_offload=self.cpu_offload,
             mixed_precision=self.mixed_precision_config,
             device_id=self.root_device.index,
-            **self._ddp_kwargs,
+            **self._fsdp_kwargs,
         )

         # activation checkpointing needs to be set up after wrapping the model

@@ -194,6 +218,9 @@ def setup_optimizer(self, optimizer: Optimizer) -> Optimizer:
         that the optimizer was created after the model was wrapped with :meth:`setup_module` with a reference to the
         flattened parameters.
         """
+        if _TORCH_GREATER_EQUAL_2_0:
+            return optimizer
+
         from torch.distributed.fsdp import FlatParameter

         num_groups = len(optimizer.param_groups)

@@ -224,7 +251,7 @@ def module_sharded_context(self) -> Generator:
             cpu_offload=self.cpu_offload,
             mixed_precision=self.mixed_precision_config,
             device_id=self.root_device.index,
-            **self._ddp_kwargs,
+            **self._fsdp_kwargs,
         ):
             yield
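
Illustration (not part of the diff): a hedged sketch of the multi-parameter-group setup that the `use_orig_params=True` default is meant to enable with the Fabric FSDP strategy on PyTorch >= 2.0. Learning rates, layer sizes, and device count are arbitrary placeholders.

import torch
from lightning.fabric import Fabric
from lightning.fabric.strategies import FSDPStrategy

fabric = Fabric(strategy=FSDPStrategy(), devices=2)  # assumes PyTorch >= 2.0 and 2 GPUs
fabric.launch()

model = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Linear(32, 2))

# With use_orig_params=True the original parameters remain addressable,
# so an optimizer with multiple param groups can be set up jointly with the model.
optimizer = torch.optim.Adam(
    [
        {"params": model[0].parameters(), "lr": 1e-3},
        {"params": model[1].parameters(), "lr": 1e-4},
    ]
)
model, optimizer = fabric.setup(model, optimizer)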

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 1 deletion
@@ -9,11 +9,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

--
+- Added support for multiple optimizer parameter groups when using the FSDP strategy ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))


 ### Changed

+- Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
+
 - Generalized `Optimizer` validation to accommodate both FSDP 1.x and 2.x ([#16733](https://github.com/Lightning-AI/lightning/pull/16733))

src/lightning/pytorch/strategies/fsdp.py

Lines changed: 13 additions & 1 deletion
@@ -32,7 +32,11 @@
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
+from lightning.fabric.utilities.imports import (
+    _TORCH_GREATER_EQUAL_1_12,
+    _TORCH_GREATER_EQUAL_1_13,
+    _TORCH_GREATER_EQUAL_2_0,
+)
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import ProcessGroup, ReduceOp

@@ -130,6 +134,10 @@ def __init__(
             [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
         )
         self.kwargs = kwargs
+        if _TORCH_GREATER_EQUAL_2_0:
+            # Avoids the need for user to reference params in `configure_optimizers` via
+            # `self.trainer.model.parameters()` and enables support for multiple parameter groups.
+            self.kwargs.setdefault("use_orig_params", True)

     @property
     def root_device(self) -> torch.device:

@@ -249,6 +257,9 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.setup_precision_plugin()

     def setup_optimizers(self, trainer: "pl.Trainer") -> None:
+        if self.kwargs.get("use_orig_params"):
+            return super().setup_optimizers(trainer)
+
         invalid_params_error = False
         try:
             super().setup_optimizers(trainer)

@@ -258,6 +269,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
             invalid_params_error = True

         if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
+            # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
             raise ValueError(
                 "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
                 " optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the"

tests/tests_fabric/helpers/models.py

Lines changed: 2 additions & 8 deletions
@@ -8,7 +8,6 @@
 from torch.utils.data import DataLoader, Dataset, IterableDataset

 from lightning.fabric import Fabric
-from lightning.fabric.strategies.fsdp import FSDPStrategy


 class RandomDataset(Dataset):

@@ -56,13 +55,8 @@ def after_optimizer_step(self, model: Module, optimizer: Optimizer) -> None:

     def run(self) -> None:
         model = self.get_model()
-        if isinstance(self.strategy, FSDPStrategy):
-            model = self.setup_module(model)
-            optimizer = self.get_optimizer(model)
-            optimizer = self.setup_optimizers(optimizer)
-        else:
-            optimizer = self.get_optimizer(model)
-            model, optimizer = self.setup(model, optimizer)
+        optimizer = self.get_optimizer(model)
+        model, optimizer = self.setup(model, optimizer)

         dataloader = self.get_dataloader()
         dataloader = self.setup_dataloaders(dataloader)

tests/tests_fabric/strategies/test_fsdp.py

Lines changed: 31 additions & 7 deletions
@@ -58,18 +58,42 @@ def test_fsdp_cpu_offload():


 @RunIf(min_torch="1.12")
-def test_fsdp_setup_optimizer_validation():
+@pytest.mark.parametrize("torch_ge_2_0", [False, True])
+def test_fsdp_setup_optimizer_validation(torch_ge_2_0):
     """Test that `setup_optimizer()` validates the param groups and reference to FSDP parameters."""
     module = nn.Linear(2, 2)
     strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")])

-    bad_optimizer = Adam([{"params": [module.weight]}, {"params": [module.bias], "lr": 1e-3}])
-    with pytest.raises(ValueError, match="does not support multiple param groups"):
-        strategy.setup_optimizer(bad_optimizer)
+    with mock.patch("lightning.fabric.strategies.fsdp._TORCH_GREATER_EQUAL_2_0", torch_ge_2_0):
+        bad_optimizer_1 = Adam([{"params": [module.weight]}, {"params": [module.bias], "lr": 1e-3}])
+        bad_optimizer_2 = Adam(module.parameters())

-    bad_optimizer = Adam(module.parameters())
-    with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameter"):
-        strategy.setup_optimizer(bad_optimizer)
+        if torch_ge_2_0:
+            strategy.setup_optimizer(bad_optimizer_1)
+            strategy.setup_optimizer(bad_optimizer_2)
+        else:
+            with pytest.raises(ValueError, match="does not support multiple param groups"):
+                strategy.setup_optimizer(bad_optimizer_1)
+            with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameter"):
+                strategy.setup_optimizer(bad_optimizer_2)
+
+
+@RunIf(min_torch="2.0.0")
+@mock.patch("lightning.fabric.strategies.fsdp.FSDPStrategy.setup_module")
+def test_fsdp_setup_use_orig_params(_):
+    module = nn.Linear(2, 2)
+    optimizer = Adam(module.parameters())
+
+    strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")], use_orig_params=False)
+    assert not strategy._fsdp_kwargs["use_orig_params"]
+
+    with pytest.raises(ValueError, match=r"`FSDPStrategy\(use_orig_params=False\)` but this is not supported"):
+        strategy.setup_module_and_optimizers(module, optimizer)
+
+    strategy = FSDPStrategy(parallel_devices=[torch.device("cpu")])
+    assert strategy._fsdp_kwargs["use_orig_params"]
+    strategy.setup_module_and_optimizers(module, optimizer)
+    assert strategy._fsdp_kwargs["use_orig_params"]


 @RunIf(min_torch="1.12")
