[WIP] Fix/lr schedulers update calling order #7708
```diff
@@ -14,12 +14,15 @@
 import os
 from copy import deepcopy

 import pytest
 import torch
 from torch.utils.data import DataLoader

 import pytorch_lightning as pl
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.cloud_io import load as pl_load
+from tests.helpers import BoringModel, RandomDataset


 def test_finetuning_with_resume_from_checkpoint(tmpdir):
@@ -84,3 +87,40 @@ def validation_step(self, batch, batch_idx):
         assert best_model_path.endswith(f"epoch=0{idx}.ckpt")
     else:
         assert f"epoch={idx + 1}" in best_model_path
+
+
+@pytest.mark.parametrize(['max_epochs', 'data_length'], [(1, 64), (2, 64), (3, 32)])
+def test_lr_schedulers_step_count(tmpdir, max_epochs, data_length):
+    """
+    This test validates that the checkpoint is always saved after the lr_scheduler has been updated during training.
+    """
+
+    class TestModel(BoringModel):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.parameters(), lr=0.001)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            lr_scheduler_dict = {'scheduler': lr_scheduler, 'interval': 'step'}
+            return [optimizer], [lr_scheduler_dict]
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(32, data_length))
+
+    train_step_checkpoint_callback = ModelCheckpoint(dirpath=f"{tmpdir}/every_train_step", every_n_train_steps=1)
+    val_epoch_checkpoint_callback = ModelCheckpoint(dirpath=f"{tmpdir}/every_val_epoch", every_n_val_epochs=1)
```
Comment on lines +109 to +110

**Contributor:** Do we really need to save/load a checkpoint? Can't you just test the state?
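A minimal sketch of what the suggested direct check might look like, assuming the 1.3-era `Trainer.lr_schedulers` attribute (a list of scheduler-config dicts). Note it inspects the live scheduler after `fit()` rather than a checkpointed copy, so on its own it would not verify the save ordering this PR targets:

```python
# Hypothetical direct check of the scheduler state after fit(), without
# saving and re-loading a checkpoint (assumes trainer.lr_schedulers, PL ~1.3).
model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=max_epochs)
trainer.fit(model)

state = trainer.lr_schedulers[0]['scheduler'].state_dict()
total_steps = data_length * max_epochs  # == step_idx + 1
assert state['last_epoch'] == total_steps
assert state['_step_count'] == total_steps + 1
```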
```diff
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=max_epochs,
+        callbacks=[train_step_checkpoint_callback, val_epoch_checkpoint_callback]
+    )
+    trainer.fit(model)
+    step_idx = data_length * max_epochs - 1
```
**Contributor:** this could be just …

**Author:** You are right!
```diff
+    train_step_lr_scheduler = pl_load(
+        f"{tmpdir}/every_train_step/epoch={max_epochs-1}-step={step_idx}.ckpt"
+    )['lr_schedulers'][0]
+    val_epoch_lr_scheduler = pl_load(
+        f"{tmpdir}/every_val_epoch/epoch={max_epochs-1}-step={step_idx}.ckpt"
+    )['lr_schedulers'][0]
+
+    assert train_step_lr_scheduler['last_epoch'] == val_epoch_lr_scheduler['last_epoch'] == step_idx + 1
+    assert train_step_lr_scheduler['_step_count'] == val_epoch_lr_scheduler['_step_count'] == step_idx + 2
```
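For reference, the asserted values follow from plain PyTorch scheduler bookkeeping: `_LRScheduler.__init__` performs one implicit step, so after `N` explicit steps `last_epoch == N` and `_step_count == N + 1`. With `N = data_length * max_epochs = step_idx + 1` training steps, that yields the `step_idx + 1` and `step_idx + 2` above. A standalone check, with illustrative values:

```python
import torch

# Plain-PyTorch illustration of StepLR's counters, independent of Lightning.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

n_steps = 64  # data_length * max_epochs for the (1, 64) case, i.e. step_idx + 1
for _ in range(n_steps):
    optimizer.step()
    scheduler.step()

state = scheduler.state_dict()
assert state['last_epoch'] == n_steps        # step_idx + 1
assert state['_step_count'] == n_steps + 1   # step_idx + 2
```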
**Contributor:** What if you are using a `ReduceLROnPlateau` scheduler linked to a metric logged after this point? I believe this would fail.

**Author:** Dear @carmocca, exactly, thanks for pointing it out. I will keep working on it.
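A hypothetical configuration illustrating the concern, reusing the repo's `BoringModel` test helper; the monitored metric is only logged in `on_train_batch_end`, so stepping the scheduler before that hook runs would fail with a missing-metric error:

```python
import torch
from tests.helpers import BoringModel

class PlateauModel(BoringModel):  # illustrative only
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.001)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        # 'monitor' names a metric that does not exist yet at scheduler-update time
        return [optimizer], [{'scheduler': scheduler, 'monitor': 'my_metric', 'interval': 'step'}]

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        # the monitored metric only becomes available here, after the batch
        self.log('my_metric', 0.0)
```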
**Author:** Hi @carmocca, I realize this could be a bit tricky. Because `LightningModule` `on_train_batch_end` callbacks may log metrics, the LR scheduler must be updated after those callbacks run. However, we must also make sure that a checkpoint callback's `on_train_batch_end` hook is called after the scheduler update.
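A toy sketch of the per-batch ordering this implies; the function names are illustrative, not the real Trainer internals:

```python
# Ordering constraint for one training batch: metric-logging hooks run before
# step-interval schedulers update, and checkpoint hooks run after, so a saved
# checkpoint always contains the already-stepped scheduler state.
def run_training_batch(step_optimizer, module_batch_end_hooks,
                       update_step_schedulers, checkpoint_batch_end_hooks):
    step_optimizer()                          # 1. apply the optimizer update
    for hook in module_batch_end_hooks:
        hook()                                # 2. user hooks may log monitored metrics
    update_step_schedulers()                  # 3. schedulers step with fresh metrics
    for hook in checkpoint_batch_end_hooks:
        hook()                                # 4. checkpoints capture stepped state
```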
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @carmocca , in this case, the same issue may also exist for the validation part?
if the metric is added in
on_validation_endcallbacks?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Long story short, in that snippet you only update the epoch schedulers if running only training
and we had another call to update epoch schedulers if running validation too.
So no issue
However, this was just updated in #7357. It has no conflicts with this PR but you should rebase master regardless
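A schematic of the two update paths described, with illustrative names rather than the actual Trainer code:

```python
# Epoch-interval schedulers are stepped directly when an epoch runs training
# only; when validation runs too, a separate call steps them afterwards, so a
# monitored metric logged during validation exists before the schedulers step.
def finish_epoch(runs_validation, run_validation, update_epoch_schedulers):
    if runs_validation:
        run_validation()         # validation logs the monitored metric first
    update_epoch_schedulers()    # epoch schedulers are stepped in either case
```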
**Author:** Hi @carmocca, it seems this issue has been fixed by your updates! https://colab.research.google.com/drive/1bBkhGiKJoavp1O4Oi4kLWVnxohl5HR6M?usp=sharing
**Contributor:** Oh! Sorry, I didn't check myself. Lovely when you unknowingly fix something :D