From e295c9c983d07c2cb5ec17c5441008587aa9a9d1 Mon Sep 17 00:00:00 2001
From: "Alvin(Xinyao) Sun"
Date: Tue, 25 May 2021 14:19:39 -0600
Subject: [PATCH 1/4] fix: calling update_lr_schedulers after on_train_batch_end hook

fix issue: #7637
---
 pytorch_lightning/trainer/training_loop.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 62138790138ee..dcb474c85d025 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -494,6 +494,9 @@ def run_training_epoch(self):
             if batch_output.signal == -1:
                 break
 
+            # update LR schedulers
+            self.update_lr_schedulers('step')
+
             # hook
             # TODO: add outputs to batches
             self.on_train_batch_end(
@@ -523,8 +526,6 @@ def run_training_epoch(self):
             # -----------------------------------------
             self.save_loggers_on_train_batch_end()
 
-            # update LR schedulers
-            self.update_lr_schedulers('step')
 
             self.trainer.checkpoint_connector.has_trained = True
             self.total_batch_idx += 1

From 84950b90f08f04b831ff00a2d9fdb226f33e9d77 Mon Sep 17 00:00:00 2001
From: "Alvin(Xinyao) Sun"
Date: Tue, 25 May 2021 14:29:29 -0600
Subject: [PATCH 2/4] feat: add test

---
 .../checkpointing/test_trainer_checkpoint.py | 43 ++++++++++++++++++-
 1 file changed, 42 insertions(+), 1 deletion(-)

diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py
index c0b396e70e477..9f7d4c83e1219 100644
--- a/tests/checkpointing/test_trainer_checkpoint.py
+++ b/tests/checkpointing/test_trainer_checkpoint.py
@@ -13,13 +13,17 @@
 # limitations under the License.
 import os
 from copy import deepcopy
+from pathlib import Path
 
+import pytest
 import torch
+from torch.utils.data import DataLoader
 
 import pytorch_lightning as pl
 from pytorch_lightning import seed_everything, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from tests.helpers import BoringModel
+from pytorch_lightning.utilities.cloud_io import load as pl_load
+from tests.helpers import BoringModel, RandomDataset
 
 
 def test_finetuning_with_resume_from_checkpoint(tmpdir):
@@ -84,3 +88,40 @@ def validation_step(self, batch, batch_idx):
             assert best_model_path.endswith(f"epoch=0{idx}.ckpt")
         else:
             assert f"epoch={idx + 1}" in best_model_path
+
+
+@pytest.mark.parametrize(['max_epochs', 'data_length'], [(1, 64), (2, 64), (3, 32)])
+def test_lr_schedulers_step_count(tmpdir, max_epochs, data_length):
+    """
+    This test validates that the checkpoint is always saved after the lr_scheduler has been updated during training.
+    """
+
+    class TestModel(BoringModel):
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.SGD(self.parameters(), lr=0.001)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            lr_scheduler_dict = {'scheduler': lr_scheduler, 'interval': 'step'}
+            return [optimizer], [lr_scheduler_dict]
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(32, data_length))
+
+    train_step_checkpoint_callback = ModelCheckpoint(dirpath=f"{tmpdir}/every_train_step", every_n_train_steps=1)
+    val_epoch_checkpoint_callback = ModelCheckpoint(dirpath=f"{tmpdir}/every_val_epoch", every_n_val_epochs=1)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=max_epochs,
+        callbacks=[train_step_checkpoint_callback, val_epoch_checkpoint_callback]
+    )
+    trainer.fit(model)
+    step_idx = data_length * max_epochs - 1
pl_load(f"{tmpdir}/every_train_step/epoch={max_epochs-1}-step={step_idx}.ckpt" + )['lr_schedulers'][0] + val_epoch_lr_scheduler = pl_load(f"{tmpdir}/every_val_epoch/epoch={max_epochs-1}-step={step_idx}.ckpt" + )['lr_schedulers'][0] + # + assert train_step_lr_scheduler['last_epoch'] == val_epoch_lr_scheduler['last_epoch'] == step_idx + 1 + assert train_step_lr_scheduler['_step_count'] == val_epoch_lr_scheduler['_step_count'] == step_idx + 2 From f3eec8731b73b13444b84bf30d61ad45acbb6822 Mon Sep 17 00:00:00 2001 From: "Alvin(Xinyao) Sun" Date: Tue, 25 May 2021 14:29:46 -0600 Subject: [PATCH 3/4] chore: update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8324701679091..b51d151ff96ca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -124,6 +124,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed LR scheduler steps after saving checkpoint with iteration-based checkpointing + - Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled ([#7685](https://github.com/PyTorchLightning/pytorch-lightning/pull/7685)) - Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566)) From bd1076e9f73eb64a77eb348b2753d46702694069 Mon Sep 17 00:00:00 2001 From: "Alvin(Xinyao) Sun" Date: Tue, 25 May 2021 15:12:05 -0600 Subject: [PATCH 4/4] chore: rm unused import --- tests/checkpointing/test_trainer_checkpoint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py index 9f7d4c83e1219..d7110831030e6 100644 --- a/tests/checkpointing/test_trainer_checkpoint.py +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -13,7 +13,6 @@ # limitations under the License. import os from copy import deepcopy -from pathlib import Path import pytest import torch