Merged
39 commits
fe27312
Fix lr_scheduler state saved at checkpoint
simran2905 Jun 8, 2021
384d9b1
Update docs
simran2905 Jun 8, 2021
99065a9
Update test
simran2905 Jun 8, 2021
c6978fd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
3fc3948
Fix lr_scheduler state saved at checkpoint
simran2905 Jun 8, 2021
b5a478a
Update docs
simran2905 Jun 8, 2021
a61ddae
Update test
simran2905 Jun 8, 2021
c639cd5
Add test for reduce_on_plateau
simran2905 Jun 8, 2021
2179f90
add test
simran2905 Jun 8, 2021
8d0a486
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
bb6ed45
Fix lr_scheduler state saved at checkpoint
simran2905 Jun 8, 2021
db4fd8a
Update docs
simran2905 Jun 8, 2021
31c22c4
Update test
simran2905 Jun 8, 2021
f0ec391
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
94c1306
Fix lr_scheduler state saved at checkpoint
simran2905 Jun 8, 2021
2593e7f
Add test for reduce_on_plateau
simran2905 Jun 8, 2021
290c8ee
fix format
simran2905 Jun 8, 2021
f7475fa
fix format
simran2905 Jun 8, 2021
eee22e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
b0b740e
fix bad merge
simran2905 Jun 8, 2021
33fca46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2021
a0a3752
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
simran2905 Jun 15, 2021
3a11491
Update Changelog
simran2905 Jun 15, 2021
f811b24
Add epoch-interval support
simran2905 Jun 15, 2021
0bfaadf
Update tests
simran2905 Jun 15, 2021
75aef70
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 15, 2021
187f516
Merge branch 'master' into lr_scheduler_checkpoint
awaelchli Jun 15, 2021
e0ab2d2
update new loop with changes
awaelchli Jun 15, 2021
4df1c18
undo changes from training_loop.py
awaelchli Jun 15, 2021
a0f5190
Nits and improvements
carmocca Jun 15, 2021
72a57c0
Merge branch 'master' into lr_scheduler_checkpoint
carmocca Jun 15, 2021
3598cac
make update_plateau_schedulers required
simran2905 Jun 16, 2021
8fd5ab7
Update test
simran2905 Jun 16, 2021
1e1e277
Required argument
carmocca Jun 17, 2021
665a9d6
Merge branch 'master' into lr_scheduler_checkpoint
carmocca Jun 17, 2021
216a457
Update tests/checkpointing/test_model_checkpoint.py
carmocca Jun 18, 2021
36b4747
Merge branch 'master' into lr_scheduler_checkpoint
carmocca Jun 18, 2021
2f09fe6
Merge branch 'master' into lr_scheduler_checkpoint
carmocca Jun 19, 2021
bbbfce7
Merge branch 'master' into lr_scheduler_checkpoint
carmocca Jun 21, 2021
CHANGELOG.md: 3 additions, 0 deletions

@@ -260,6 +260,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Fixed


+- Fixed `lr_scheduler` checkpointed state by calling `update_lr_schedulers` before saving checkpoints ([#7877](https://github.com/PyTorchLightning/pytorch-lightning/pull/7877))
+
+
 - Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled ([#7685](https://github.com/PyTorchLightning/pytorch-lightning/pull/7685))

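The entry above describes the user-visible effect: the scheduler state stored in a checkpoint now reflects the update performed for the step or epoch being saved. As a quick way to see that state, here is a small, hypothetical inspection snippet (the checkpoint filename is made up; the 'lr_schedulers', '_step_count' and '_last_lr' keys are the ones the tests in this PR assert on):

import torch

# Hypothetical filename; point this at any checkpoint written by ModelCheckpoint.
checkpoint = torch.load("epoch=0-step=1.ckpt", map_location="cpu")

# Lightning stores one state dict per configured scheduler under 'lr_schedulers'.
scheduler_state = checkpoint["lr_schedulers"][0]

# For a StepLR these entries now reflect the update performed before saving.
print(scheduler_state["_step_count"], scheduler_state["_last_lr"])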
pytorch_lightning/loops/fit_loop.py: 1 addition, 1 deletion

@@ -219,7 +219,7 @@ def on_advance_end(self) -> None:
         if self.training_loop.batches_seen == 0:
             return

-        self.training_loop.update_lr_schedulers('epoch')
+        self.training_loop.update_lr_schedulers('epoch', update_plateau_schedulers=True)

         did_train_only = self.trainer.disable_validation or self.trainer.evaluation_loop.skip
         if did_train_only:
pytorch_lightning/loops/training_epoch_loop.py: 12 additions, 8 deletions

@@ -115,6 +115,12 @@ def advance(self, dataloader_iter: Iterator, **kwargs: Any) -> None:
         if batch_output.signal == -1:
             raise StopIteration

+        # update non-plateau LR schedulers
+        # update epoch-interval ones only when we are at the end of training epoch
+        self.update_lr_schedulers('step', update_plateau_schedulers=False)
+        if self._num_training_batches_reached(is_last):
+            self.update_lr_schedulers('epoch', update_plateau_schedulers=False)

         batch_end_outputs = [opt_idx_out for opt_idx_out in batch_output.training_step_output if len(opt_idx_out)]
         processed_batch_end_outputs = self._prepare_outputs(batch_end_outputs, batch_mode=True)
@@ -153,8 +159,8 @@ def on_advance_end(self):
         # -----------------------------------------
         self.save_loggers_on_train_batch_end()

-        # update LR schedulers
-        self.update_lr_schedulers('step')
+        # update plateau LR scheduler after metrics are logged
+        self.update_lr_schedulers('step', update_plateau_schedulers=True)
         self.trainer.checkpoint_connector.has_trained = True

         self.total_batch_idx += 1
@@ -351,15 +357,13 @@ def _prepare_outputs(
             processed_outputs = processed_outputs[0]
         return processed_outputs

-    def update_lr_schedulers(self, interval: str) -> None:
+    def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -> None:
         """updates the lr schedulers based on the given interval"""
-        if interval == "step":
-            finished_accumulation = self.batch_loop._accumulated_batches_reached()
-            finished_epoch = self._num_training_batches_reached()
-            if not finished_accumulation and not finished_epoch:
-                return
+        if interval == "step" and self.batch_loop.should_accumulate():
+            return
         self.trainer.optimizer_connector.update_learning_rates(
             interval=interval,
+            update_plateau_schedulers=update_plateau_schedulers,
             opt_indices=[opt_idx for opt_idx, _ in self.batch_loop.get_active_optimizers(self.total_batch_idx)],
         )

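The two new comments in `advance` and `on_advance_end` describe the ordering this PR establishes: non-plateau schedulers are stepped right after the batch runs, while plateau schedulers are stepped only once the metrics they monitor have been logged, which also places checkpoint saving between the two. The following is a deliberately simplified, runnable sketch of that ordering; every function in it is an invented stand-in that only prints, not a Lightning internal:

def run_training_step():
    print("training_step and optimizer.step()")

def step_non_plateau_schedulers(interval):
    print(f"non-plateau schedulers updated ({interval})")

def run_batch_end_callbacks():
    print("batch-end callbacks, e.g. a step-interval ModelCheckpoint, save here")

def log_batch_metrics():
    print("values passed to self.log(...) are recorded")

def step_plateau_schedulers(interval):
    print(f"plateau schedulers updated ({interval}) now that their metric exists")

def one_training_batch(is_last_batch: bool) -> None:
    run_training_step()
    step_non_plateau_schedulers('step')        # happens before any checkpoint is saved
    if is_last_batch:
        step_non_plateau_schedulers('epoch')
    run_batch_end_callbacks()
    log_batch_metrics()
    step_plateau_schedulers('step')

one_training_batch(is_last_batch=True)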
pytorch_lightning/trainer/connectors/optimizer_connector.py: 10 additions, 1 deletion

@@ -29,11 +29,17 @@ def on_trainer_init(self) -> None:
         self.trainer.optimizers = []
         self.trainer.optimizer_frequencies = []

-    def update_learning_rates(self, interval: str, opt_indices: Optional[List[int]] = None) -> None:
+    def update_learning_rates(
+        self, interval: str, update_plateau_schedulers: bool, opt_indices: Optional[List[int]] = None
+    ) -> None:
         """Update learning rates.

         Args:
             interval: either 'epoch' or 'step'.
+            update_plateau_schedulers: control whether ``ReduceLROnPlateau`` or non-plateau schedulers get updated.
+                This is used so non-plateau schedulers can be updated before running validation. Checkpoints are
+                commonly saved during validation, however, on-plateau schedulers might monitor a validation metric
+                so they have to be updated separately.
             opt_indices: indices of the optimizers to update.
         """
         if not self.trainer.lr_schedulers or not self.trainer.lightning_module.automatic_optimization:
@@ -46,6 +52,9 @@ def update_learning_rates(self, interval: str, opt_indices: Optional[List[int]]
             if isinstance(lr_scheduler['opt_idx'], int) and lr_scheduler['opt_idx'] not in opt_indices:
                 continue

+            if update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]:
+                continue
+
             current_idx = self.trainer.train_loop.batch_idx if interval == 'step' else self.trainer.current_epoch
             current_idx += 1  # account for both batch and epoch starts from 0
             # Take step if call to update_learning_rates matches the interval key and
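The added docstring carries the motivation, and the `update_plateau_schedulers ^ lr_scheduler["reduce_on_plateau"]` check above means each call now touches either only plateau or only non-plateau schedulers. For the scenario the docstring describes, a minimal sketch of a module whose plateau scheduler monitors a validation metric might look as follows; the module, the metric name `val_loss` and the layer sizes are invented for illustration and mirror the list-of-dicts format used in the tests below:

import torch
from torch import nn
from pytorch_lightning import LightningModule


class PlateauExample(LightningModule):
    """Hypothetical module; only the pieces relevant to the scheduler setup are shown."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def validation_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # the plateau scheduler below can only be stepped after this metric is logged
        self.log('val_loss', loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        return [optimizer], [{'scheduler': scheduler, 'interval': 'epoch', 'monitor': 'val_loss'}]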
tests/callbacks/test_finetuning_callback.py: 2 additions, 1 deletion

@@ -27,7 +27,8 @@

 class TestBackboneFinetuningCallback(BackboneFinetuning):

-    def on_train_epoch_end(self, trainer, pl_module):
+    def on_train_epoch_start(self, trainer, pl_module):
+        super().on_train_epoch_start(trainer, pl_module)
         epoch = trainer.current_epoch
         if self.unfreeze_backbone_at_epoch <= epoch:
             optimizer = trainer.optimizers[0]
tests/checkpointing/test_model_checkpoint.py: 10 additions, 6 deletions

@@ -162,10 +162,9 @@ def on_validation_epoch_end(self):
         if not reduce_lr_on_plateau:
             actual_step_count = chk['lr_schedulers'][0]['_step_count']
             actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
-            # if validation_step_none, the checkpoint gets saved after the learning rate update
-            # so we need to increase the count by one
-            assert actual_step_count == epoch + 1 + validation_step_none
-            assert actual_lr == lr * gamma**(epoch + validation_step_none)
+            # checkpoint is saved after updating lr_scheduler states
+            assert actual_step_count == epoch + 2  # step_count starts at 1
+            assert actual_lr == lr * gamma**(epoch + 1)

         assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
         assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
@@ -262,6 +261,11 @@ def _make_assertions(epoch, ix, version=''):
         global_ix = ix + per_epoch_val_checks * epoch
         duplicated = bool(version)

+        # checkpoint saved at the end of training epoch will have updated lr_scheduler states
+        epoch_end_checkpoint = duplicated
+        if epoch_aligned:
+            epoch_end_checkpoint = ix == (per_epoch_val_checks - 1)
+
         score = model.scores[global_ix]
         expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
         expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{version}.ckpt'
@@ -281,8 +285,8 @@ def _make_assertions(epoch, ix, version=''):
         if not reduce_lr_on_plateau:
             actual_step_count = chk['lr_schedulers'][0]['_step_count']
             actual_lr = chk['lr_schedulers'][0]['_last_lr'][0]
-            assert actual_step_count == epoch + 1 + duplicated
-            assert actual_lr == lr * gamma**(epoch + duplicated)
+            assert actual_step_count == epoch + 1 + epoch_end_checkpoint
+            assert actual_lr == lr * gamma**(epoch + epoch_end_checkpoint)

         return score

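The updated assertions encode the arithmetic behind the fix: once a checkpoint is saved after the scheduler update, a StepLR that has finished 0-indexed epoch `epoch` has been stepped `epoch + 1` times, its `_step_count` is `epoch + 2` because the counter starts at 1, and its last learning rate is `lr * gamma**(epoch + 1)`. A small plain-PyTorch check of that arithmetic, outside the Trainer, might look like this:

import torch

lr, gamma = 1.0, 10.0
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)

epoch = 0  # the 0-indexed epoch that has just finished
for _ in range(epoch + 1):  # one scheduler step per completed epoch
    optimizer.step()
    scheduler.step()

assert scheduler.state_dict()['_step_count'] == epoch + 2  # the counter starts at 1
assert scheduler.state_dict()['_last_lr'] == [lr * gamma ** (epoch + 1)]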
tests/trainer/optimization/test_optimizers.py: 85 additions, 0 deletions

@@ -18,6 +18,7 @@
 from torch import optim

 from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate
 from tests.helpers.boring_model import BoringModel
@@ -620,3 +621,87 @@ def test_lr_scheduler_epoch_step_frequency(mocked_sched, check_val_every_n_epoch
     )
     trainer.fit(model)
     assert mocked_sched.call_count == expected_steps
+
+
+@pytest.mark.parametrize('every_n_train_steps, epoch_interval', [(None, True), (2, False), (2, True)])
+def test_lr_scheduler_state_updated_before_saving(tmpdir, every_n_train_steps, epoch_interval):
+    batches = 2
+    max_epochs = 1
+    lr, gamma = 1, 10
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        logger=False,
+        max_epochs=max_epochs,
+        limit_train_batches=batches,
+        limit_val_batches=1,
+        callbacks=[ModelCheckpoint(dirpath=tmpdir, every_n_train_steps=every_n_train_steps)]
+    )
+
+    class TestModel(BoringModel):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.parameters(), lr=lr)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=gamma)
+            lr_dict = {'scheduler': lr_scheduler}
+            if not epoch_interval:
+                lr_dict['interval'] = 'step'
+            return [optimizer], [lr_dict]
+
+        def on_save_checkpoint(self, checkpoint):
+            lr_dict = checkpoint['lr_schedulers'][0]
+            # 2 batches ran. since the lr_dict interval is `step`, the step count should be 2
+            assert self.trainer.global_step + 1 == batches  # the global step hasn't been increased yet
+            compare_to = max_epochs if epoch_interval else batches
+            assert lr_dict['_step_count'] - 1 == compare_to  # step count starts at 1
+            assert lr_dict['_last_lr'] == [lr * gamma**compare_to]
+            self.on_save_checkpoint_called = True
+
+    model = TestModel()
+    trainer.fit(model)
+    assert model.on_save_checkpoint_called
+
+
+def test_plateau_scheduler_lr_step_interval_updated_after_saving(tmpdir):
+    batches = 4
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        progress_bar_refresh_rate=0,
+        logger=False,
+        max_epochs=1,
+        limit_train_batches=batches,
+        limit_val_batches=1,
+        callbacks=[ModelCheckpoint(dirpath=tmpdir)]
+    )
+
+    class TestModel(BoringModel):
+
+        def training_step(self, batch, batch_idx, optimizer_idx):
+            self.log("foo", batch_idx)
+            return super().training_step(batch, batch_idx)
+
+        def configure_optimizers(self):
+            optimizer_1 = torch.optim.Adam(self.parameters())
+            optimizer_2 = torch.optim.Adam(self.parameters())
+
+            lr_scheduler1 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_1)
+            lr_dict_1 = {'scheduler': lr_scheduler1, 'interval': 'step', 'monitor': 'foo'}
+
+            lr_scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer_2, step_size=1)
+            lr_dict_2 = {'scheduler': lr_scheduler2, 'interval': 'step'}
+            return [optimizer_1, optimizer_2], [lr_dict_1, lr_dict_2]
+
+        def on_save_checkpoint(self, checkpoint):
+            lr_dict_1 = checkpoint['lr_schedulers'][0]
+            # since plateau schedulers are updated after saving checkpoint, last_epoch should be 3
+            assert lr_dict_1['last_epoch'] == batches - 1  # last epoch starts at 0
+
+            lr_dict_2 = checkpoint['lr_schedulers'][1]
+            assert lr_dict_2['_step_count'] - 1 == batches  # step count starts at 1
+
+            self.on_save_checkpoint_called = True
+
+    model = TestModel()
+    model.training_epoch_end = None
+    trainer.fit(model)
+    assert model.on_save_checkpoint_called
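The plateau test relies on `ReduceLROnPlateau` counting its calls in `last_epoch`, starting from 0, so a value of `batches - 1` at save time means the scheduler had not yet been stepped for the final batch when the checkpoint was written. A standalone plain-PyTorch check of that counter, independent of the Trainer, could look like this:

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param])
plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)

for metric in [0.3, 0.2, 0.1]:  # one step per batch whose monitored value was logged
    optimizer.step()
    plateau.step(metric)

assert plateau.state_dict()['last_epoch'] == 3  # starts at 0, incremented on every step call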