Skip to content

Commit d1efae2

Browse files
simran2905, awaelchli, carmocca, Borda
authored
Fix checkpointed state for lr_schedulers with step interval (#7877)
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 2303f9c commit d1efae2

File tree

7 files changed

+123
-17
lines changed

7 files changed

+123
-17
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -260,6 +260,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
260260
### Fixed
261261

262262

263+
- Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877))
264+
265+
263266
- Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled ([#7685](https://github.com/PyTorchLightning/pytorch-lightning/pull/7685))
264267

265268

pytorch_lightning/loops/fit_loop.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -219,7 +219,7 @@ def on_advance_end(self) -> None:
219219
if self.training_loop.batches_seen == 0:
220220
return
221221

222-
self.training_loop.update_lr_schedulers('epoch')
222+
self.training_loop.update_lr_schedulers('epoch', update_plateau_schedulers=True)
223223

224224
did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.skip
225225
if did_train_only:

pytorch_lightning/loops/training_epoch_loop.py

Lines changed: 12 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,12 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
115115
if batch_output.signal == -1:
116116
raise StopIteration
117117

118+
# update non-plateau LR schedulers
119+
# update epoch-interval ones only when we are at the end of training epoch
120+
self.update_lr_schedulers('step', update_plateau_schedulers=False)
121+
if self._num_training_batches_reached(is_last):
122+
self.update_lr_schedulers('epoch', update_plateau_schedulers=False)
123+
118124
batch_end_outputs = [opt_idx_out for opt_idx_out in batch_output.training_step_output if len(opt_idx_out)]
119125
processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True)
120126

@@ -153,8 +159,8 @@ def on_advance_end(self):
153159
# -----------------------------------------
154160
self.save_loggers_on_train_batch_end()
155161

156-
# update LR schedulers
157-
self.update_lr_schedulers('step')
162+
# update plateau LR scheduler after metrics are logged
163+
self.update_lr_schedulers('step', update_plateau_schedulers=True)
158164
self.trainer.checkpoint_connector.has_trained = True
159165

160166
self.total_batch_idx += 1
@@ -351,15 +357,13 @@ def _prepare_outputs(
351357
processed_outputs = processed_outputs[0]
352358
return processed_outputs
353359

354-
def update_lr_schedulers(self, interval: str) -> None:
360+
def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -> None:
355361
"""updates the lr schedulers based on the given interval"""
356-
if interval == "step":
357-
finished_accumulation = self.batch_loop._accumulated_batches_reached()
358-
finished_epoch = self._num_training_batches_reached()
359-
if not finished_accumulation and not finished_epoch:
360-
return
362+
if interval == "step" and self.batch_loop.should_accumulate():
363+
return
361364
self.trainer.optimizer_connector.update_learning_rates(
362365
interval=interval,
366+
update_plateau_schedulers=update_plateau_schedulers,
363367
opt_indices=[opt_idx for opt_idx, _ in self.batch_loop.get_active_optimizers(self.total_batch_idx)],
364368
)
365369

pytorch_lightning/trainer/connectors/optimizer_connector.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -29,11 +29,17 @@ def on_trainer_init(self) -> None:
2929
self.trainer.optimizers = []
3030
self.trainer.optimizer_frequencies = []
3131

32-
def update_learning_rates(self, interval: str, opt_indices: Optional[List[int]] = None) -> None:
32+
def update_learning_rates(
33+
self, interval: str, update_plateau_schedulers: bool, opt_indices: Optional[List[int]] = None
34+
) -> None:
3335
"""Update learning rates.
3436
3537
Args:
3638
interval: either 'epoch' or 'step'.
39+
update_plateau_schedulers: control whether ``ReduceLROnPlateau`` or non-plateau schedulers get updated.
40+
This is used so non-plateau schedulers can be updated before running validation. Checkpoints are
41+
commonly saved during validation, however, on-plateau schedulers might monitor a validation metric
42+
so they have to be updated separately.
3743
opt_indices: indices of the optimizers to update.
3844
"""
3945
if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
@@ -46,6 +52,9 @@ def update_learning_rates(self, interval: str, opt_indices: Optional[List[int]]
4652
if isinstance(lr_scheduler['opt_idx'], int) and lr_scheduler['opt_idx'] not in opt_indices:
4753
continue
4854

55+
if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]:
56+
continue
57+
4958
current_idx = self.trainer.train_loop.batch_idx if interval == 'step' else self.trainer.current_epoch
5059
current_idx += 1 # account for both batch and epoch starts from 0
5160
# Take step if call to update_learning_rates matches the interval key and

tests/callbacks/test_finetuning_callback.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,8 @@
2727

2828
class TestBackboneFinetuningCallback(BackboneFinetuning):
2929

30-
def on_train_epoch_end(self, trainer, pl_module):
30+
def on_train_epoch_start(self, trainer, pl_module):
31+
super().on_train_epoch_start(trainer, pl_module)
3132
epoch = trainer.current_epoch
3233
if self.unfreeze_backbone_at_epoch <= epoch:
3334
optimizer = trainer.optimizers[0]

tests/checkpointing/test_model_checkpoint.py

Lines changed: 10 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -162,10 +162,9 @@ def on_validation_epoch_end(self):
162162
if not reduce_lr_on_plateau:
163163
actual_step_count = chk['lr_schedulers'][0]['_step_count']
164164
actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
165-
# if validation_step_none, the checkpoint gets saved after the learning rate update
166-
# so we need to increase the count by one
167-
assert actual_step_count == epoch + 1 + validation_step_none
168-
assert actual_lr == lr * gamma**(epoch + validation_step_none)
165+
# checkpoint is saved after updating lr_scheduler states
166+
assert actual_step_count == epoch + 2 # step_count starts at 1
167+
assert actual_lr == lr * gamma**(epoch + 1)
169168

170169
assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
171170
assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
@@ -262,6 +261,11 @@ def _make_assertions(epoch, ix, version=''):
262261
global_ix = ix + per_epoch_val_checks * epoch
263262
duplicated = bool(version)
264263

264+
# checkpoint saved at the end of training epoch will have updated lr_scheduler states
265+
epoch_end_checkpoint = duplicated
266+
if epoch_aligned:
267+
epoch_end_checkpoint = ix == (per_epoch_val_checks - 1)
268+
265269
score = model.scores[global_ix]
266270
expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
267271
expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{version}.ckpt'
@@ -281,8 +285,8 @@ def _make_assertions(epoch, ix, version=''):
281285
if not reduce_lr_on_plateau:
282286
actual_step_count = chk['lr_schedulers'][0]['_step_count']
283287
actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
284-
assert actual_step_count == epoch + 1 + duplicated
285-
assert actual_lr == lr * gamma**(epoch + duplicated)
288+
assert actual_step_count == epoch + 1 + epoch_end_checkpoint
289+
assert actual_lr == lr * gamma**(epoch + epoch_end_checkpoint)
286290

287291
return score
288292

tests/trainer/optimization/test_optimizers.py

Lines changed: 85 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
from torch import optim
1919

2020
from pytorch_lightning import Callback, Trainer
21+
from pytorch_lightning.callbacks import ModelCheckpoint
2122
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2223
from tests.base import EvalModelTemplate
2324
from tests.helpers.boring_model import BoringModel
@@ -620,3 +621,87 @@ def test_lr_scheduler_epoch_step_frequency(mocked_sched, check_val_every_n_epoch
620621
)
621622
trainer.fit(model)
622623
assert mocked_sched.call_count == expected_steps
624+
625+
626+
@pytest.mark.parametrize('every_n_train_steps, epoch_interval', [(None, True), (2, False), (2, True)])
627+
def test_lr_scheduler_state_updated_before_saving(tmpdir, every_n_train_steps, epoch_interval):
628+
batches = 2
629+
max_epochs = 1
630+
lr, gamma = 1, 10
631+
trainer = Trainer(
632+
default_root_dir=tmpdir,
633+
progress_bar_refresh_rate=0,
634+
logger=False,
635+
max_epochs=max_epochs,
636+
limit_train_batches=batches,
637+
limit_val_batches=1,
638+
callbacks=[ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=every_n_train_steps)]
639+
)
640+
641+
class TestModel(BoringModel):
642+
643+
def configure_optimizers(self):
644+
optimizer = torch.optim.SGD(self.parameters(), lr=lr)
645+
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
646+
lr_dict = {'scheduler': lr_scheduler}
647+
if not epoch_interval:
648+
lr_dict['interval'] = 'step'
649+
return [optimizer], [lr_dict]
650+
651+
def on_save_checkpoint(self, checkpoint):
652+
lr_dict = checkpoint['lr_schedulers'][0]
653+
# 2 batches ran. since the lr_dict interval is `step`, the step count should be 2
654+
assert self.trainer.global_step + 1 == batches # the global step hasn't been increased yet
655+
compare_to = max_epochs if epoch_interval else batches
656+
assert lr_dict['_step_count'] - 1 == compare_to # step count starts at 1
657+
assert lr_dict['_last_lr'] == [lr * gamma**compare_to]
658+
self.on_save_checkpoint_called = True
659+
660+
model = TestModel()
661+
trainer.fit(model)
662+
assert model.on_save_checkpoint_called
663+
664+
665+
def test_plateau_scheduler_lr_step_interval_updated_after_saving(tmpdir):
666+
batches = 4
667+
trainer = Trainer(
668+
default_root_dir=tmpdir,
669+
progress_bar_refresh_rate=0,
670+
logger=False,
671+
max_epochs=1,
672+
limit_train_batches=batches,
673+
limit_val_batches=1,
674+
callbacks=[ModelCheckpoint(dirpath=tmpdir)]
675+
)
676+
677+
class TestModel(BoringModel):
678+
679+
def training_step(self, batch, batch_idx, optimizer_idx):
680+
self.log("foo", batch_idx)
681+
return super().training_step(batch, batch_idx)
682+
683+
def configure_optimizers(self):
684+
optimizer_1 = torch.optim.Adam(self.parameters())
685+
optimizer_2 = torch.optim.Adam(self.parameters())
686+
687+
lr_scheduler1 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_1)
688+
lr_dict_1 = {'scheduler': lr_scheduler1, 'interval': 'step', 'monitor': 'foo'}
689+
690+
lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer_2, step_size=1)
691+
lr_dict_2 = {'scheduler': lr_scheduler2, 'interval': 'step'}
692+
return [optimizer_1, optimizer_2], [lr_dict_1, lr_dict_2]
693+
694+
def on_save_checkpoint(self, checkpoint):
695+
lr_dict_1 = checkpoint['lr_schedulers'][0]
696+
# since plateau schedulers are updated after saving checkpoint, last_epoch should be 3
697+
assert lr_dict_1['last_epoch'] == batches - 1 # last epoch starts at 0
698+
699+
lr_dict_2 = checkpoint['lr_schedulers'][1]
700+
assert lr_dict_2['_step_count'] - 1 == batches # step count starts at 1
701+
702+
self.on_save_checkpoint_called = True
703+
704+
model = TestModel()
705+
model.training_epoch_end = None
706+
trainer.fit(model)
707+
assert model.on_save_checkpoint_called

0 commit comments

Comments
 (0)