diff --git a/CHANGELOG.md b/CHANGELOG.md
index 114f1af38f82c..93ba9cbb3f77d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -366,6 +366,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 * Remove hardcoding of local rank in accelerator connector ([#6878](https://github.com/PyTorchLightning/pytorch-lightning/pull/6878))
 
 
+- Fixed incorrect number of calls to LR scheduler when `check_val_every_n_epoch > 1` ([#7032](https://github.com/PyTorchLightning/pytorch-lightning/pull/7032))
+
+
 ## [1.2.7] - 2021-04-06
 
 ### Fixed
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index f7da8e929e865..cb02cd9df521e 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -472,7 +472,6 @@ def run_training_epoch(self):
         train_dataloader = self.trainer.data_connector.get_profiled_train_dataloader(train_dataloader)
         dataloader_idx = 0
 
-        val_loop_called = False
         batch_idx = None
         is_last_batch = None
 
@@ -514,7 +513,6 @@ def run_training_epoch(self):
                 self.trainer.validating = True
                 self.trainer._run_evaluation()
                 self.trainer.training = True
-                val_loop_called = True
 
             # -----------------------------------------
             # SAVE LOGGERS (ie: Tensorboard, etc...)
@@ -563,7 +561,7 @@ def run_training_epoch(self):
         should_train_only = self.trainer.disable_validation or should_skip_eval
 
         # update epoch level lr_schedulers if no val loop outside train loop is triggered
-        if (val_loop_called and not should_check_val) or should_train_only:
+        if not should_check_val or should_train_only:
             self.trainer.optimizer_connector.update_learning_rates(interval='epoch')
 
         if should_train_only:
diff --git a/tests/trainer/optimization/test_optimizers.py b/tests/trainer/optimization/test_optimizers.py
index f5b2229f8a99e..a81e0eecf5c61 100644
--- a/tests/trainer/optimization/test_optimizers.py
+++ b/tests/trainer/optimization/test_optimizers.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from unittest import mock
+
 import pytest
 import torch
 from torch import optim
@@ -577,21 +579,21 @@ def configure_optimizers(self):
     trainer.fit(model)
 
 
-class TestModel(BoringModel):
+@RunIf(min_gpus=2, special=True)
+def test_optimizer_state_on_device(tmpdir):
+    """ Test that optimizers that create state initially at instantiation still end up with the state on the GPU. """
 
-    def configure_optimizers(self):
-        # Adagrad creates state tensors immediately, model is not yet on GPU.
-        return optim.Adagrad(self.parameters())
+    class TestModel(BoringModel):
 
-    def on_train_start(self, *args, **kwargs):
-        opt = self.optimizers()
-        _, state = next(iter(opt.state.items()))
-        assert state["sum"].device == torch.device("cuda", self.local_rank) == self.device
+        def configure_optimizers(self):
+            # Adagrad creates state tensors immediately, model is not yet on GPU.
+            return optim.Adagrad(self.parameters())
+
+        def on_train_start(self, *args, **kwargs):
+            opt = self.optimizers()
+            _, state = next(iter(opt.state.items()))
+            assert state["sum"].device == torch.device("cuda", self.local_rank) == self.device
 
-
-@RunIf(min_gpus=2, special=True)
-def test_optimizer_state_on_device(tmpdir):
-    """ Test that optimizers that create state initially at instantiation still end up with the state on the GPU. """
     model = TestModel()
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -600,3 +602,21 @@ def test_optimizer_state_on_device(tmpdir):
         fast_dev_run=True,
     )
     trainer.fit(model)
+
+
+@pytest.mark.parametrize("check_val_every_n_epoch", [1, 2])
+@mock.patch("torch.optim.lr_scheduler.StepLR.step")
+def test_lr_scheduler_epoch_step_frequency(mocked_sched, check_val_every_n_epoch, tmpdir):
+    epochs = 4
+    expected_steps = epochs + 1  # every LRScheduler gets called once at init
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        check_val_every_n_epoch=check_val_every_n_epoch,
+        max_epochs=epochs,
+    )
+    trainer.fit(model)
+    assert mocked_sched.call_count == expected_steps