Commit a16a75b

Merge cc73c5c into facfda8 (2 parents: facfda8 + cc73c5c)

File tree: 3 files changed (+29 / -1 lines)

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -107,7 +107,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260))
 
 
-- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))
+- Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))
 
 
 ## [1.2.2] - 2021-03-02

pytorch_lightning/trainer/optimizers.py

Lines changed: 3 additions & 0 deletions
@@ -27,7 +27,10 @@
 
 class TrainerOptimizersMixin(ABC):
 
+    _lightning_optimizers: Optional[List[LightningOptimizer]]
+
     def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
+        self._lightning_optimizers = None
         optim_conf = model.configure_optimizers()
         if optim_conf is None:
             rank_zero_warn(
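
The change is one line, but the failure mode it guards against is a general one: the mixin lazily caches `LightningOptimizer` wrappers, and a second `fit()` call would otherwise keep serving wrappers built around the previous run's optimizers. The sketch below is a minimal, self-contained illustration of that caching pattern and of the reset that fixes it; `Wrapper`, `Engine`, and their methods are illustrative stand-ins, not PyTorch Lightning's actual internals.

from typing import List, Optional


class Wrapper:
    """Stand-in for LightningOptimizer: wraps a raw optimizer object."""

    def __init__(self, optimizer: object) -> None:
        self.optimizer = optimizer


class Engine:
    """Stand-in for the Trainer mixin: lazily caches wrappers."""

    _wrappers: Optional[List[Wrapper]]

    def __init__(self) -> None:
        self._wrappers = None
        self.optimizers: List[object] = []

    def init_optimizers(self, optimizers: List[object]) -> None:
        # The fix: clear the cache whenever optimizers are (re)initialized,
        # so the next access rebuilds wrappers around the new optimizers.
        self._wrappers = None
        self.optimizers = optimizers

    @property
    def wrappers(self) -> List[Wrapper]:
        if self._wrappers is None:  # lazily (re)build the cache
            self._wrappers = [Wrapper(opt) for opt in self.optimizers]
        return self._wrappers


engine = Engine()
engine.init_optimizers(["opt_a"])
assert engine.wrappers[0].optimizer is engine.optimizers[0]

# A second initialization, as in a second fit() call. Without the reset in
# init_optimizers, the cached wrapper would still point at "opt_a".
engine.init_optimizers(["opt_b"])
assert engine.wrappers[0].optimizer is engine.optimizers[0]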

tests/trainer/test_trainer.py

Lines changed: 25 additions & 0 deletions
@@ -1803,3 +1803,28 @@ def backward(self, *args, **kwargs):
         "training_step",
         "backward",
     ]
+
+
+def test_init_optimizers_resets_lightning_optimizers(tmpdir):
+    """Test that the Trainer resets the `lightning_optimizers` list every time new optimizers get initialized."""
+
+    def compare_optimizers():
+        assert trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0]
+
+    model = BoringModel()
+    model.lr = 0.2
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        auto_lr_find=True,
+    )
+
+    trainer.tune(model)
+    compare_optimizers()
+
+    trainer.fit(model)
+    compare_optimizers()
+
+    trainer.max_epochs = 2  # simulate multiple fit calls
+    trainer.fit(model)
+    compare_optimizers()
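
The test exercises three distinct re-initialization points: `trainer.tune(model)` runs the learning-rate finder (enabled via `auto_lr_find=True`), which initializes optimizers once; `trainer.fit(model)` initializes them again; and bumping `max_epochs` before a second `fit(model)` call simulates repeated fitting. Because `compare_optimizers` asserts identity (`is`) rather than equality, a stale cached `LightningOptimizer` left over from any earlier phase would fail the check.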
