Commit 58c9fa7

Allow training type plugin to delay optimizer creation (FSDP 2/n) (#6331)
* Allow training_type_plugin to delay optimizer configure
* Add missing references to trainer, add a CPU accelerator based test
1 parent 853523e commit 58c9fa7

4 files changed: +52 -6 lines changed
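For context, the user-facing opt-in is a single boolean property on the training type plugin. A minimal sketch of how a plugin could request delayed optimizer creation (mirroring the test added at the bottom of this commit; DelayedOptimizerPlugin is an illustrative name, not part of the library):

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import SingleDevicePlugin


class DelayedOptimizerPlugin(SingleDevicePlugin):
    """Illustrative plugin that asks the accelerator to defer optimizer creation."""

    @property
    def setup_optimizers_in_pre_dispatch(self) -> bool:
        # optimizers will now be created in Accelerator.pre_dispatch instead of Accelerator.setup
        return True


trainer = Trainer(fast_dev_run=True, plugins=DelayedOptimizerPlugin(device=torch.device("cpu")))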

pytorch_lightning/accelerators/accelerator.py

Lines changed: 6 additions & 3 deletions
@@ -85,7 +85,8 @@ def setup(self, trainer: 'Trainer', model: LightningModule) -> None:
             model: the LightningModule
         """
         self.setup_training_type_plugin(self.training_type_plugin, model)
-        self.setup_optimizers(trainer)
+        if not self.training_type_plugin.setup_optimizers_in_pre_dispatch:
+            self.setup_optimizers(trainer)
         self.setup_precision_plugin(self.precision_plugin)

     def start_training(self, trainer: 'Trainer') -> None:
@@ -97,12 +98,14 @@ def start_evaluating(self, trainer: 'Trainer') -> None:
     def start_predicting(self, trainer: 'Trainer') -> None:
         self.training_type_plugin.start_predicting(trainer)

-    def pre_dispatch(self) -> None:
+    def pre_dispatch(self, trainer: 'Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.pre_dispatch()
+        if self.training_type_plugin.setup_optimizers_in_pre_dispatch:
+            self.setup_optimizers(trainer)
         self.precision_plugin.pre_dispatch()

-    def post_dispatch(self) -> None:
+    def post_dispatch(self, trainer: 'Trainer') -> None:
         """Hook to do something before the training/evaluation/prediction starts."""
         self.training_type_plugin.post_dispatch()
         self.precision_plugin.post_dispatch()

pytorch_lightning/plugins/training_type/training_type_plugin.py

Lines changed: 10 additions & 0 deletions
@@ -182,3 +182,13 @@ def init_optimizers(self, trainer: "Trainer", model: LightningModule):

     def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
         optimizer.step(closure=lambda_closure, **kwargs)
+
+    @property
+    def setup_optimizers_in_pre_dispatch(self) -> bool:
+        """
+        Override to delay setting optimizers and schedulers till after dispatch.
+        This is useful when the `TrainingTypePlugin` requires operating on the wrapped accelerator model.
+        However this may break certain precision plugins such as APEX which require optimizers to be set.
+        Returns: If True, delay setup optimizers till pre_dispatch, else call within setup.
+        """
+        return False
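Putting this property together with the accelerator change above, the ordering a plugin can rely on when it opts in is roughly the following (a hedged summary of the diffs in this commit, not verbatim library code):

# When setup_optimizers_in_pre_dispatch returns True (simplified):
#
#   Accelerator.setup(trainer, model)
#       setup_training_type_plugin(...)       # connect plugin and model
#       # setup_optimizers(trainer) is skipped here
#       setup_precision_plugin(...)
#
#   Accelerator.pre_dispatch(trainer)
#       training_type_plugin.pre_dispatch()   # the plugin may wrap/prepare the model here
#       setup_optimizers(trainer)             # optimizers are built against the prepared model
#       precision_plugin.pre_dispatch()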

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 2 deletions
@@ -495,7 +495,7 @@ def fit(
         return self.accelerator.results or 1

     def pre_dispatch(self):
-        self.accelerator.pre_dispatch()
+        self.accelerator.pre_dispatch(self)

         # log hyper-parameters
         if self.logger is not None:
@@ -505,7 +505,7 @@ def pre_dispatch(self):
             self.logger.save()

     def post_dispatch(self):
-        self.accelerator.post_dispatch()
+        self.accelerator.post_dispatch(self)
         self.accelerator.teardown()

     def dispatch(self):

tests/accelerators/test_cpu.py

Lines changed: 34 additions & 1 deletion
@@ -2,11 +2,12 @@

 import pytest
 import torch
-
+from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import CPUAccelerator
 from pytorch_lightning.plugins import SingleDevicePlugin
 from pytorch_lightning.plugins.precision import MixedPrecisionPlugin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from tests.helpers.boring_model import BoringModel


 def test_unsupported_precision_plugins():
@@ -18,3 +19,35 @@ def test_unsupported_precision_plugins():
     )
     with pytest.raises(MisconfigurationException, match=r"amp \+ cpu is not supported."):
         accelerator.setup(trainer=trainer, model=model)
+
+
+@pytest.mark.parametrize("delay_dispatch", [True, False])
+def test_plugin_setup_optimizers_in_pre_dispatch(tmpdir, delay_dispatch):
+    """
+    Test when using a custom training type plugin that delays setup optimizers,
+    we do not call setup optimizers till ``pre_dispatch``.
+    """
+
+    class TestModel(BoringModel):
+        def on_fit_start(self):
+            if delay_dispatch:
+                # Ensure we haven't setup optimizers if we've delayed dispatch
+                assert len(self.trainer.optimizers) == 0
+            else:
+                assert len(self.trainer.optimizers) > 0
+
+        def on_fit_end(self):
+            assert len(self.trainer.optimizers) > 0
+
+    class CustomPlugin(SingleDevicePlugin):
+        @property
+        def setup_optimizers_in_pre_dispatch(self) -> bool:
+            return delay_dispatch
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        plugins=CustomPlugin(device=torch.device("cpu"))
+    )
+    trainer.fit(model)
