2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -124,6 +124,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed LR scheduler steps after saving checkpoint with iteration-based checkpointing

- Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled ([#7685](https://github.com/PyTorchLightning/pytorch-lightning/pull/7685))

- Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))
5 changes: 3 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -494,6 +494,9 @@ def run_training_epoch(self):
            if batch_output.signal == -1:
                break

            # update LR schedulers
            self.update_lr_schedulers('step')
Contributor:

what if you are using a ReduceLROnPlateau scheduler linked to a metric logged after this point? I believe this would fail
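
For concreteness, a minimal sketch of that scenario (hypothetical user code, not part of this PR; PlateauModel and the 'train_loss' metric name are invented, and BoringModel is the test helper imported elsewhere in this diff):

import torch

from tests.helpers import BoringModel


class PlateauModel(BoringModel):

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
        # 'monitor' ties the scheduler step to a logged metric
        return {
            'optimizer': optimizer,
            'lr_scheduler': {'scheduler': scheduler, 'interval': 'step', 'monitor': 'train_loss'},
        }

    def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
        # the monitored metric is only logged here, i.e. after the proposed
        # update_lr_schedulers('step') call would already have run
        self.log('train_loss', 1.0)  # dummy value for illustration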

Contributor Author:

Dear @carmocca,
Exactly, thanks for pointing it out.
I'll work on it further.

Contributor Author:

Hi @carmocca, I realize that this could be a bit tricky. Because LightningModule on_train_batch_end callbacks may log metrics, the LR scheduler must be updated after those callbacks run. However, we must also make sure that the checkpoint callback's on_train_batch_end hook is called after the scheduler update.
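
To make the constraint concrete, here is a toy illustration in plain Python (not Lightning internals; all names below are invented for the example): the metric logging and the checkpoint saving both happen inside the same on_train_batch_end dispatch, and the scheduler update has to land between them.

logged_metrics = {}
scheduler_state = {'_step_count': 1}


def module_on_train_batch_end():
    # a LightningModule hook (or a non-checkpoint callback) logs the monitored metric
    logged_metrics['train_loss'] = 0.5


def update_lr_schedulers_step():
    # a monitor-based 'step' scheduler can only step once the metric exists
    assert 'train_loss' in logged_metrics, 'monitored metric not logged yet'
    scheduler_state['_step_count'] += 1


def checkpoint_on_train_batch_end():
    # the checkpoint callback should capture the scheduler *after* it has stepped
    print('saving scheduler state:', scheduler_state)


# required order: module hook -> scheduler update -> checkpoint callback,
# yet the first and the last both live inside on_train_batch_end
module_on_train_batch_end()
update_lr_schedulers_step()
checkpoint_on_train_batch_end()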

Contributor Author:

Hi @carmocca, in this case, doesn't the same issue also exist for the validation part?

        if not should_check_val or should_train_only:
            self.update_lr_schedulers('epoch')

        if should_train_only:
            self.check_checkpoint_callback(True)

That is, what if the metric is only logged in on_validation_end callbacks?

Contributor:

Long story short: in that snippet the epoch schedulers are only updated when running training only, and there was another call to update the epoch schedulers when running validation too. So there is no issue.

However, this was just updated in #7357. It has no conflicts with this PR, but you should rebase onto master regardless.

Contributor Author:

Hi @carmocca, it seems that this issue has been fixed by your updates!
https://colab.research.google.com/drive/1bBkhGiKJoavp1O4Oi4kLWVnxohl5HR6M?usp=sharing

Contributor:

Oh! Sorry I didn't check myself.

Lovely when you unknowingly fix something :D


            # hook
            # TODO: add outputs to batches
            self.on_train_batch_end(
@@ -523,8 +526,6 @@ def run_training_epoch(self):
            # -----------------------------------------
            self.save_loggers_on_train_batch_end()

            # update LR schedulers
            self.update_lr_schedulers('step')
            self.trainer.checkpoint_connector.has_trained = True

            self.total_batch_idx += 1
42 changes: 41 additions & 1 deletion tests/checkpointing/test_trainer_checkpoint.py
@@ -14,12 +14,15 @@
import os
from copy import deepcopy

import pytest
import torch
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from tests.helpers import BoringModel
from pytorch_lightning.utilities.cloud_io import load as pl_load
from tests.helpers import BoringModel, RandomDataset


def test_finetuning_with_resume_from_checkpoint(tmpdir):
@@ -84,3 +87,40 @@ def validation_step(self, batch, batch_idx):
            assert best_model_path.endswith(f"epoch=0{idx}.ckpt")
        else:
            assert f"epoch={idx + 1}" in best_model_path


@pytest.mark.parametrize(['max_epochs', 'data_length'], [(1, 64), (2, 64), (3, 32)])
def test_lr_schedulers_step_count(tmpdir, max_epochs, data_length):
    """
    Validate that the checkpoint is always saved after the LR scheduler has been updated during training.
    """

    class TestModel(BoringModel):

        def configure_optimizers(self):
            optimizer = torch.optim.SGD(self.parameters(), lr=0.001)
            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
            lr_scheduler_dict = {'scheduler': lr_scheduler, 'interval': 'step'}
            return [optimizer], [lr_scheduler_dict]

        def train_dataloader(self):
            return DataLoader(RandomDataset(32, data_length))

    train_step_checkpoint_callback = ModelCheckpoint(dirpath=f"{tmpdir}/every_train_step", every_n_train_steps=1)
    val_epoch_checkpoint_callback = ModelCheckpoint(dirpath=f"{tmpdir}/every_val_epoch", every_n_val_epochs=1)
Comment on lines +109 to +110
Contributor:

Do we really need to save/load a checkpoint?

Can't you just test the state in on_train_batch_end?
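
A minimal sketch of what that could look like (an assumption on my part, not code from this PR): subclass ModelCheckpoint and assert on the live scheduler state from on_train_batch_end, reusing the imports already at the top of this test file. It assumes trainer.lr_schedulers as exposed by Lightning 1.x and the default accumulate_grad_batches=1, so the 'step'-interval scheduler steps once per batch.

class AssertingCheckpoint(ModelCheckpoint):
    # count the batches this callback has seen, to compare against the scheduler
    seen_batches = 0

    def on_train_batch_end(self, trainer, pl_module, *args, **kwargs):
        self.seen_batches += 1
        scheduler = trainer.lr_schedulers[0]['scheduler']
        # with the new ordering, the scheduler has already been stepped once per
        # batch by the time the checkpoint callback runs
        assert scheduler.last_epoch == self.seen_batches
        super().on_train_batch_end(trainer, pl_module, *args, **kwargs)


# used in place of the two ModelCheckpoint instances above, e.g.:
# trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=[AssertingCheckpoint(dirpath=tmpdir)])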


    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=max_epochs,
        callbacks=[train_step_checkpoint_callback, val_epoch_checkpoint_callback]
    )
    trainer.fit(model)
    step_idx = data_length * max_epochs - 1
Contributor:

this could be just

step = trainer.global_step - 1
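
Applied to the test above, the suggestion would read roughly like this (a hypothetical rewrite reusing the test's existing names, not code from this PR):

step = trainer.global_step - 1
ckpt_path = f"{tmpdir}/every_train_step/epoch={max_epochs - 1}-step={step}.ckpt"
train_step_lr_scheduler = pl_load(ckpt_path)['lr_schedulers'][0]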

Contributor Author:

You are right!

    train_step_lr_scheduler = pl_load(
        f"{tmpdir}/every_train_step/epoch={max_epochs-1}-step={step_idx}.ckpt"
    )['lr_schedulers'][0]
    val_epoch_lr_scheduler = pl_load(
        f"{tmpdir}/every_val_epoch/epoch={max_epochs-1}-step={step_idx}.ckpt"
    )['lr_schedulers'][0]

    assert train_step_lr_scheduler['last_epoch'] == val_epoch_lr_scheduler['last_epoch'] == step_idx + 1
    assert train_step_lr_scheduler['_step_count'] == val_epoch_lr_scheduler['_step_count'] == step_idx + 2