
Commit 0c02c44

Simplified setup of optimizers in FSDP (#17309)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 51697a8 commit 0c02c44

File tree

4 files changed: +48 / -8 lines changed

docs/source-pytorch/advanced/model_parallel.rst

Lines changed: 2 additions & 2 deletions
@@ -88,9 +88,9 @@ simplest way to do it is auto wrapping, which can serve as a drop-in replacement
    have to ``wrap`` layers manually as in the case of manual wrapping.
 
 .. note::
-    While initializing the optimizers inside ``configure_optimizers`` hook, make sure to use ``self.trainer.model.parameters()``, else
+    For users of PyTorch < 2.0: While initializing the optimizers inside ``configure_optimizers`` hook, make sure to use ``self.trainer.model.parameters()``, else
     PyTorch will raise an error. This is required because when you use auto-wrap, the model layers are sharded and your
-    ``lightning_module.parameters()`` will return a generator with no params. This inconvenience will be addressed in the future.
+    ``lightning_module.parameters()`` will return a generator with no params.
 
 
 .. code-block:: python
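
For illustration only, a minimal sketch of the two patterns the updated note distinguishes; the module, layer sizes, and optimizer settings below are hypothetical and not part of this commit:

```python
import torch
from lightning.pytorch import LightningModule


class SketchModel(LightningModule):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def configure_optimizers(self):
        # PyTorch >= 2.0 (Lightning now defaults to `use_orig_params=True`):
        # the module's own parameters stay valid after FSDP auto-wrapping.
        return torch.optim.SGD(self.parameters(), lr=0.1)

        # PyTorch < 2.0: reference the wrapped model held by the trainer instead,
        # because `self.parameters()` would yield no params after sharding:
        # return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)
```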

src/lightning/pytorch/CHANGELOG.md

Lines changed: 3 additions & 1 deletion
@@ -9,11 +9,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
--
+- Added support for multiple optimizer parameter groups when using the FSDP strategy ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
 
 
 ### Changed
 
+- Removed the limitation to call `self.trainer.model.parameters()` in `LightningModule.configure_optimizers()` ([#17309](https://github.com/Lightning-AI/lightning/pull/17309))
+
 - Generalized `Optimizer` validation to accommodate both FSDP 1.x and 2.x ([#16733](https://github.com/Lightning-AI/lightning/pull/16733))
 
 
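
As a rough illustration of the parameter-group support added above; the model layout and learning rates are made-up assumptions, not code from this commit:

```python
import torch
from lightning.pytorch import LightningModule


class TwoGroupModel(LightningModule):  # hypothetical model, for illustration only
    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(32, 32)
        self.head = torch.nn.Linear(32, 2)

    def configure_optimizers(self):
        # Multiple parameter groups now work under the FSDP strategy
        # (PyTorch >= 2.0 with `use_orig_params=True`, the new default).
        return torch.optim.Adam(
            [
                {"params": self.backbone.parameters(), "lr": 1e-3},
                {"params": self.head.parameters(), "lr": 1e-2},
            ]
        )
```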

src/lightning/pytorch/strategies/fsdp.py

Lines changed: 13 additions & 1 deletion
@@ -32,7 +32,11 @@
     _sync_ddp_if_available,
 )
 from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_1_12, _TORCH_GREATER_EQUAL_1_13
+from lightning.fabric.utilities.imports import (
+    _TORCH_GREATER_EQUAL_1_12,
+    _TORCH_GREATER_EQUAL_1_13,
+    _TORCH_GREATER_EQUAL_2_0,
+)
 from lightning.fabric.utilities.optimizer import _optimizers_to_device
 from lightning.fabric.utilities.seed import reset_seed
 from lightning.fabric.utilities.types import ProcessGroup, ReduceOp
@@ -130,6 +134,10 @@ def __init__(
             [activation_checkpointing] if not isinstance(activation_checkpointing, list) else activation_checkpointing
         )
         self.kwargs = kwargs
+        if _TORCH_GREATER_EQUAL_2_0:
+            # Avoids the need for user to reference params in `configure_optimizers` via
+            # `self.trainer.model.parameters()` and enables support for multiple parameter groups.
+            self.kwargs.setdefault("use_orig_params", True)
 
     @property
     def root_device(self) -> torch.device:
@@ -249,6 +257,9 @@ def setup(self, trainer: "pl.Trainer") -> None:
         self.setup_precision_plugin()
 
     def setup_optimizers(self, trainer: "pl.Trainer") -> None:
+        if self.kwargs.get("use_orig_params"):
+            return super().setup_optimizers(trainer)
+
         invalid_params_error = False
         try:
             super().setup_optimizers(trainer)
@@ -258,6 +269,7 @@ def setup_optimizers(self, trainer: "pl.Trainer") -> None:
             invalid_params_error = True
 
         if invalid_params_error or any(not _optimizer_has_flat_params(optimizer) for optimizer in self.optimizers):
+            # We avoid this limitation in PyTorch >= 2.0 by setting `use_orig_params=True`
             raise ValueError(
                 "The optimizer does not seem to reference any FSDP parameters. HINT: Make sure to create the"
                 " optimizer after setting up the model by referencing `self.trainer.model.parameters()` in the"

tests/tests_pytorch/strategies/test_fsdp.py

Lines changed: 30 additions & 4 deletions
@@ -1,4 +1,5 @@
 import os
+from contextlib import nullcontext
 from functools import partial
 from typing import Any, Callable, Dict, Optional
 from unittest import mock
@@ -90,7 +91,8 @@ def __init__(self):
         self.layer = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.ReLU(), torch.nn.Linear(32, 2))
 
     def configure_optimizers(self):
-        return torch.optim.SGD(self.trainer.model.parameters(), lr=0.1)
+        parameters = self.parameters() if _TORCH_GREATER_EQUAL_2_0 else self.trainer.model.parameters()
+        return torch.optim.SGD(parameters, lr=0.1)
 
     def on_train_batch_end(self, *_) -> None:
         self._assert_layer_fsdp_instance()
@@ -297,14 +299,24 @@ def test_fsdp_checkpoint_multi_gpus(tmpdir, model, strategy, strategy_cfg):
 
 @RunIf(min_cuda_gpus=1, skip_windows=True, standalone=True, min_torch="1.12")
 def test_invalid_parameters_in_optimizer():
-    trainer = Trainer(strategy="fsdp", accelerator="cuda", devices=1)
+    trainer = Trainer(
+        strategy="fsdp",
+        accelerator="cuda",
+        devices=1,
+        fast_dev_run=1,
+    )
+    error_context = (
+        nullcontext()
+        if _TORCH_GREATER_EQUAL_2_0
+        else pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters")
+    )
 
     class EmptyParametersModel(BoringModel):
         def configure_optimizers(self):
             return torch.optim.Adam(self.parameters(), lr=1e-2)
 
     model = EmptyParametersModel()
-    with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
+    with error_context:
         trainer.fit(model)
 
     class NoFlatParametersModel(BoringModel):
@@ -313,7 +325,7 @@ def configure_optimizers(self):
             return torch.optim.Adam(layer.parameters(), lr=1e-2)
 
     model = NoFlatParametersModel()
-    with pytest.raises(ValueError, match="The optimizer does not seem to reference any FSDP parameters"):
+    with error_context:
         trainer.fit(model)
 
 
@@ -370,3 +382,17 @@ def test_fsdp_strategy_cpu_offload():
     config = CPUOffload()
     strategy = FSDPStrategy(cpu_offload=config)
     assert strategy.cpu_offload == config
+
+
+@RunIf(min_torch="1.12")
+def test_fsdp_use_orig_params():
+    """Test that Lightning enables `use_orig_params` in PyTorch >= 2.0."""
+    with mock.patch("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_0", False):
+        strategy = FSDPStrategy()
+        assert "use_orig_params" not in strategy.kwargs
+
+    with mock.patch("lightning.pytorch.strategies.fsdp._TORCH_GREATER_EQUAL_2_0", True):
+        strategy = FSDPStrategy()
+        assert strategy.kwargs["use_orig_params"]
+        strategy = FSDPStrategy(use_orig_params=False)
+        assert not strategy.kwargs["use_orig_params"]
