diff --git a/CHANGELOG.md b/CHANGELOG.md
index bba7ed346980c..abae0d6d96c25 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -228,6 +228,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
   [#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627))
 
 
+- Fixed `EarlyStopping` running on train epoch end when `check_val_every_n_epoch>1` is set ([#9156](https://github.com/PyTorchLightning/pytorch-lightning/pull/9156))
+
+
 - Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))
 
 
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 77683ad2819f3..4623b6077dbb9 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -131,11 +131,11 @@ def __init__(
     def state_key(self) -> str:
         return self._generate_state_key(monitor=self.monitor, mode=self.mode)
 
-    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+    def on_init_end(self, trainer: "pl.Trainer") -> None:
         if self._check_on_train_epoch_end is None:
-            # if the user runs validation multiple times per training epoch, we try to check after
-            # validation instead of on train epoch end
-            self._check_on_train_epoch_end = trainer.val_check_interval == 1.0
+            # if the user runs validation multiple times per training epoch or multiple training epochs without
+            # validation, then we run after validation instead of on train epoch end
+            self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
 
     def _validate_condition_metric(self, logs):
         monitor_val = logs.get(self.monitor)
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index e7daa4ee53cde..16e8a9de0d826 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -267,9 +267,9 @@ def state_key(self) -> str:
 
     def on_init_end(self, trainer: "pl.Trainer") -> None:
         if self._save_on_train_epoch_end is None:
-            # if the user runs validation multiple times per training epoch, we try to save checkpoint after
-            # validation instead of on train epoch end
-            self._save_on_train_epoch_end = trainer.val_check_interval == 1.0
+            # if the user runs validation multiple times per training epoch or multiple training epochs without
+            # validation, then we run after validation instead of on train epoch end
+            self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
 
     def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """When pretrain routine starts we build the ckpt dir on the fly."""
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index ad343cdf329f5..ccc2ca24bf669 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -424,22 +424,27 @@ def test_multiple_early_stopping_callbacks(
     trainer.fit(model)
 
 
-def test_check_on_train_epoch_end_with_val_check_interval(tmpdir):
+@pytest.mark.parametrize(
+    "case",
+    {
+        "val_check_interval": {"val_check_interval": 0.3, "limit_train_batches": 10, "max_epochs": 10},
+        "check_val_every_n_epoch": {"check_val_every_n_epoch": 2, "max_epochs": 5},
+    }.items(),
+)
+def test_check_on_train_epoch_end_smart_handling(tmpdir, case):
     class TestModel(BoringModel):
         def validation_step(self, batch, batch_idx):
             self.log("foo", 1)
             return super().validation_step(batch, batch_idx)
 
+    case, kwargs = case
     model = TestModel()
-    val_check_interval, limit_train_batches = 0.3, 10
     trainer = Trainer(
         default_root_dir=tmpdir,
-        val_check_interval=val_check_interval,
-        max_epochs=1,
-        limit_train_batches=limit_train_batches,
         limit_val_batches=1,
         callbacks=EarlyStopping(monitor="foo"),
         progress_bar_refresh_rate=0,
+        **kwargs,
     )
 
     side_effect = [(False, "A"), (True, "B")]
@@ -449,4 +454,7 @@ def validation_step(self, batch, batch_idx):
         trainer.fit(model)
 
     assert es_mock.call_count == len(side_effect)
-    assert trainer.global_step == len(side_effect) * int(limit_train_batches * val_check_interval)
+    if case == "val_check_interval":
+        assert trainer.global_step == len(side_effect) * int(trainer.limit_train_batches * trainer.val_check_interval)
+    else:
+        assert trainer.current_epoch == len(side_effect) * trainer.check_val_every_n_epoch - 1
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index f49fa16598fd2..0e19202d33d9b 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -1248,3 +1248,21 @@ def test_trainer_checkpoint_callback_bool(tmpdir):
     mc = ModelCheckpoint(dirpath=tmpdir)
     with pytest.raises(MisconfigurationException, match="Invalid type provided for checkpoint_callback"):
         Trainer(checkpoint_callback=mc)
+
+
+def test_check_val_every_n_epochs_top_k_integration(tmpdir):
+    model = BoringModel()
+    mc = ModelCheckpoint(dirpath=tmpdir, monitor="epoch", save_top_k=-1, filename="{epoch}")
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        num_sanity_val_steps=0,
+        max_epochs=5,
+        check_val_every_n_epoch=2,
+        callbacks=mc,
+        weights_summary=None,
+        logger=False,
+    )
+    trainer.fit(model)
+    assert set(os.listdir(tmpdir)) == {"epoch=1.ckpt", "epoch=3.ckpt"}
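
Not part of the diff: a minimal standalone sketch (the helper name is hypothetical) of the default that both `EarlyStopping` and `ModelCheckpoint` now compute in `on_init_end` above. The callback only stays on train epoch end when validation runs exactly once per training epoch; otherwise it acts after each validation run.

def _runs_on_train_epoch_end(val_check_interval: float, check_val_every_n_epoch: int) -> bool:
    # Mirrors the condition added in the diff: stay on train epoch end only when
    # validation happens exactly once per training epoch.
    return val_check_interval == 1.0 and check_val_every_n_epoch == 1


assert _runs_on_train_epoch_end(1.0, 1)        # validation once per epoch -> check/save on train epoch end
assert not _runs_on_train_epoch_end(0.3, 1)    # several validation runs per epoch -> check/save after validation
assert not _runs_on_train_epoch_end(1.0, 2)    # validation every other epoch -> check/save after validation (the case this PR fixes)

Read through that lens, the new `ModelCheckpoint` test follows: with `max_epochs=5` and `check_val_every_n_epoch=2`, validation (and therefore the top-k save on the monitored `epoch` value) runs only at the end of epochs 1 and 3 (0-indexed), which is why exactly `epoch=1.ckpt` and `epoch=3.ckpt` are expected on disk.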