Skip to content

Commit 440db6a

Browse files
committed
bugfix
1 parent 4f391bc commit 440db6a

File tree

2 files changed

+34
-1
lines changed

2 files changed

+34
-1
lines changed

pytorch_lightning/trainer/optimizers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@
2727

2828
class TrainerOptimizersMixin(ABC):
2929

30+
_lightning_optimizers: Optional[List[LightningOptimizer]]
31+
3032
def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
33+
self._lightning_optimizers = None
3134
optim_conf = model.configure_optimizers()
3235
if optim_conf is None:
3336
rank_zero_warn(

tests/trainer/test_trainer.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from copy import deepcopy
2020
from distutils.version import LooseVersion
2121
from pathlib import Path
22-
from unittest.mock import ANY, call, patch
22+
from unittest.mock import ANY, call, patch, Mock
2323

2424
import cloudpickle
2525
import pytest
@@ -1785,3 +1785,33 @@ def backward(self, *args, **kwargs):
17851785
"training_step",
17861786
"backward",
17871787
]
1788+
1789+
1790+
def test_init_optimizers_resets_lightning_optimizers(tmpdir):
1791+
""" Test that the Trainer resets the `lightning_optimizers` list everytime new optimizers get initialized. """
1792+
1793+
def compare_optimizers():
1794+
assert trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0]
1795+
1796+
class OptimizerSpy(Callback):
1797+
def on_fit_start(self, *args, **kwargs):
1798+
compare_optimizers()
1799+
1800+
model = BoringModel()
1801+
model.lr = 0.2
1802+
trainer = Trainer(
1803+
default_root_dir=tmpdir,
1804+
max_epochs=1,
1805+
auto_lr_find=True,
1806+
callbacks=[OptimizerSpy()]
1807+
)
1808+
1809+
trainer.tune(model)
1810+
compare_optimizers()
1811+
1812+
trainer.fit(model)
1813+
compare_optimizers()
1814+
1815+
trainer.max_epochs = 2 # simulate multiple fit calls
1816+
trainer.fit(model)
1817+
compare_optimizers()

0 commit comments

Comments
 (0)