
Commit 2e8a466

awaelchli authored and carmocca committed
Fix trainer not resetting lightning_optimizers (#6372)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 0783c0b commit 2e8a466
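
For context, the bug: the `Trainer` caches a list of `LightningOptimizer` wrappers around its raw optimizers, but never cleared that cache, so a second `Trainer.fit()` call kept wrappers pointing at the optimizers from the first run. The toy classes below are a minimal sketch of that stale-cache pattern and of the one-line reset that fixes it; they are hypothetical stand-ins, not the actual Trainer code.

```python
from typing import List, Optional


class Wrapper:
    """Stand-in for LightningOptimizer: wraps one raw optimizer."""

    def __init__(self, optimizer: object) -> None:
        self.optimizer = optimizer


class TrainerLike:
    """Stand-in for the Trainer's optimizer handling."""

    def __init__(self) -> None:
        self.optimizers: List[object] = []
        self._wrappers: Optional[List[Wrapper]] = None

    def init_optimizers(self) -> None:
        # The fix: drop wrappers cached by a previous fit() call. Without
        # this line, the lazily built list below would survive re-init and
        # keep pointing at the old optimizers.
        self._wrappers = None
        self.optimizers = [object()]  # stand-in for configure_optimizers()

    @property
    def wrappers(self) -> List[Wrapper]:
        # Built lazily, so a stale cache would shadow the new optimizers.
        if self._wrappers is None:
            self._wrappers = [Wrapper(opt) for opt in self.optimizers]
        return self._wrappers


trainer = TrainerLike()
trainer.init_optimizers()   # first fit()
assert trainer.wrappers[0].optimizer is trainer.optimizers[0]

trainer.init_optimizers()   # second fit(): the cache must be rebuilt
assert trainer.wrappers[0].optimizer is trainer.optimizers[0]
```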

File tree

3 files changed: +32 -1 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

```diff
@@ -45,6 +45,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Ensure that clip gradients is only called if the value is greater than 0 ([#6330](https://github.com/PyTorchLightning/pytorch-lightning/pull/6330)
 
 
+- Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))
+
+
 ## [1.2.2] - 2021-03-02
 
 ### Added
```

pytorch_lightning/trainer/optimizers.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from abc import ABC
-from typing import List, Optional, Tuple, Dict, Any
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 from torch import optim
@@ -27,7 +27,10 @@
 
 class TrainerOptimizersMixin(ABC):
 
+    _lightning_optimizers: Optional[List[LightningOptimizer]]
+
     def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
+        self._lightning_optimizers = None
         optim_conf = model.configure_optimizers()
         if optim_conf is None:
             rank_zero_warn(
```
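
The two added lines work together: the class-level annotation declares `_lightning_optimizers` as an attribute of the mixin, and the assignment at the top of `init_optimizers` clears any wrapper list left over from a previous `fit()` call, so the wrappers are rebuilt around the new optimizers. A hypothetical check (not part of the commit) makes the resulting invariant explicit:

```python
# Hypothetical helper, not in the commit: after optimizers are (re)initialized,
# every cached LightningOptimizer wrapper should point at the trainer's
# current raw optimizer at the same index.
def assert_wrappers_in_sync(trainer) -> None:
    for idx, raw_optimizer in enumerate(trainer.optimizers):
        wrapped = trainer.lightning_optimizers[idx]
        assert wrapped.optimizer is raw_optimizer, f"stale wrapper at index {idx}"
```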

tests/trainer/test_trainer.py

Lines changed: 25 additions & 0 deletions

```diff
@@ -1825,3 +1825,28 @@ def backward(self, *args, **kwargs):
         "training_step",
         "backward",
     ]
+
+
+def test_init_optimizers_resets_lightning_optimizers(tmpdir):
+    """ Test that the Trainer resets the `lightning_optimizers` list every time new optimizers get initialized. """
+
+    def compare_optimizers():
+        assert trainer.lightning_optimizers[0].optimizer is trainer.optimizers[0]
+
+    model = BoringModel()
+    model.lr = 0.2
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        auto_lr_find=True,
+    )
+
+    trainer.tune(model)
+    compare_optimizers()
+
+    trainer.fit(model)
+    compare_optimizers()
+
+    trainer.max_epochs = 2  # simulate multiple fit calls
+    trainer.fit(model)
+    compare_optimizers()
```
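
To exercise repeated optimizer initialization, the test combines `trainer.tune()` (with `auto_lr_find=True`, so the learning-rate finder runs its own optimizer setup; `model.lr = 0.2` gives the finder an attribute to read and update) with two subsequent `trainer.fit()` calls, checking after each step that the cached wrapper still points at the current optimizer. It can be run in isolation with `pytest tests/trainer/test_trainer.py -k test_init_optimizers_resets_lightning_optimizers`.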
