From 550e8473bbdadc3145d5c17f8ea8af5e3fb97e40 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 27 Aug 2021 03:29:42 +0200
Subject: [PATCH 1/4] Disable `{save,check}_on_train_epoch_end` with `check_val_every_n_epoch`

---
 pytorch_lightning/callbacks/early_stopping.py |  2 +-
 .../callbacks/model_checkpoint.py             |  2 +-
 tests/callbacks/test_early_stopping.py        | 20 +++++++++++++------
 3 files changed, 16 insertions(+), 8 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 77683ad2819f3..9bc6434d203a7 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -135,7 +135,7 @@ def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.Lightn
         if self._check_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch, we try to check after
             # validation instead of on train epoch end
-            self._check_on_train_epoch_end = trainer.val_check_interval == 1.0
+            self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
 
     def _validate_condition_metric(self, logs):
         monitor_val = logs.get(self.monitor)
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index e7daa4ee53cde..647b5509f9e2b 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -269,7 +269,7 @@ def on_init_end(self, trainer: "pl.Trainer") -> None:
         if self._save_on_train_epoch_end is None:
             # if the user runs validation multiple times per training epoch, we try to save checkpoint after
             # validation instead of on train epoch end
-            self._save_on_train_epoch_end = trainer.val_check_interval == 1.0
+            self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
 
     def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """When pretrain routine starts we build the ckpt dir on the fly."""
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index ad343cdf329f5..ccc2ca24bf669 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -424,22 +424,27 @@ def test_multiple_early_stopping_callbacks(
     trainer.fit(model)
 
 
-def test_check_on_train_epoch_end_with_val_check_interval(tmpdir):
+@pytest.mark.parametrize(
+    "case",
+    {
+        "val_check_interval": {"val_check_interval": 0.3, "limit_train_batches": 10, "max_epochs": 10},
+        "check_val_every_n_epoch": {"check_val_every_n_epoch": 2, "max_epochs": 5},
+    }.items(),
+)
+def test_check_on_train_epoch_end_smart_handling(tmpdir, case):
     class TestModel(BoringModel):
         def validation_step(self, batch, batch_idx):
             self.log("foo", 1)
             return super().validation_step(batch, batch_idx)
 
+    case, kwargs = case
     model = TestModel()
-    val_check_interval, limit_train_batches = 0.3, 10
     trainer = Trainer(
         default_root_dir=tmpdir,
-        val_check_interval=val_check_interval,
-        max_epochs=1,
-        limit_train_batches=limit_train_batches,
         limit_val_batches=1,
         callbacks=EarlyStopping(monitor="foo"),
         progress_bar_refresh_rate=0,
+        **kwargs,
     )
 
     side_effect = [(False, "A"), (True, "B")]
@@ -449,4 +454,7 @@ def validation_step(self, batch, batch_idx):
     trainer.fit(model)
 
     assert es_mock.call_count == len(side_effect)
-    assert trainer.global_step == len(side_effect) * int(limit_train_batches * val_check_interval)
+    if case == "val_check_interval":
+        assert trainer.global_step == len(side_effect) * int(trainer.limit_train_batches * trainer.val_check_interval)
+    else:
+        assert trainer.current_epoch == len(side_effect) * trainer.check_val_every_n_epoch - 1
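The condition this patch introduces is the crux of the whole series: the callbacks should only default to acting on train epoch end when validation runs exactly once per training epoch. A minimal sketch of that decision, using standalone names for illustration rather than the actual callback internals:

```python
# Sketch of the defaulting logic added above (standalone function for
# illustration; the real check lives inside the callbacks' hooks).
def runs_on_train_epoch_end(val_check_interval: float, check_val_every_n_epoch: int) -> bool:
    # val_check_interval < 1.0: validation happens mid-epoch, possibly several
    # times, so checking/saving should follow each validation run instead.
    # check_val_every_n_epoch > 1: some epochs have no validation at all, so a
    # train-epoch-end check would read a stale monitored value.
    return val_check_interval == 1.0 and check_val_every_n_epoch == 1


assert runs_on_train_epoch_end(1.0, 1)      # validate once per epoch
assert not runs_on_train_epoch_end(0.3, 1)  # validate several times per epoch
assert not runs_on_train_epoch_end(1.0, 2)  # validate every other epoch (this fix)
```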
From 32684f2cf74c10177acbeb2ccd066cdd48be06ad Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 27 Aug 2021 03:38:06 +0200
Subject: [PATCH 2/4] Update signature, comment, and CHANGELOG

---
 CHANGELOG.md                                    | 3 +++
 pytorch_lightning/callbacks/early_stopping.py   | 6 +++---
 pytorch_lightning/callbacks/model_checkpoint.py | 4 ++--
 3 files changed, 8 insertions(+), 5 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index bba7ed346980c..abae0d6d96c25 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -228,6 +228,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     [#8627](https://github.com/PyTorchLightning/pytorch-lightning/pull/8627))
 
 
+- Fixed `EarlyStopping` running on train epoch end when `check_val_every_n_epoch>1` is set ([#9156](https://github.com/PyTorchLightning/pytorch-lightning/pull/9156))
+
+
 - Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))
 
 
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 9bc6434d203a7..4623b6077dbb9 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -131,10 +131,10 @@ def state_key(self) -> str:
         return self._generate_state_key(monitor=self.monitor, mode=self.mode)
 
-    def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+    def on_init_end(self, trainer: "pl.Trainer") -> None:
         if self._check_on_train_epoch_end is None:
-            # if the user runs validation multiple times per training epoch, we try to check after
-            # validation instead of on train epoch end
+            # if the user runs validation multiple times per training epoch or multiple training epochs without
+            # validation, then we run after validation instead of on train epoch end
             self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
 
     def _validate_condition_metric(self, logs):
         monitor_val = logs.get(self.monitor)
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 647b5509f9e2b..16e8a9de0d826 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -267,8 +267,8 @@ def state_key(self) -> str:
 
     def on_init_end(self, trainer: "pl.Trainer") -> None:
         if self._save_on_train_epoch_end is None:
-            # if the user runs validation multiple times per training epoch, we try to save checkpoint after
-            # validation instead of on train epoch end
+            # if the user runs validation multiple times per training epoch or multiple training epochs without
+            # validation, then we run after validation instead of on train epoch end
             self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
 
     def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
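The stopping points asserted by the test in the first patch follow from plain arithmetic; here is a quick standalone check of both parametrized cases, assuming the mocked early-stopping check stops training on its second call:

```python
# Both assertions in test_check_on_train_epoch_end_smart_handling, worked out
# by hand. The mock's side_effect stops training on the second evaluation.
n_evaluations = 2  # (False, "A"), then (True, "B")

# val_check_interval case: 10 train batches with validation every 30% of the
# epoch means a validation run every int(10 * 0.3) = 3 optimizer steps.
limit_train_batches, val_check_interval = 10, 0.3
assert n_evaluations * int(limit_train_batches * val_check_interval) == 6  # global_step

# check_val_every_n_epoch case: validation runs after the 2nd and 4th epochs
# (1-based), so the 0-based current_epoch at stopping is 2 * 2 - 1.
check_val_every_n_epoch = 2
assert n_evaluations * check_val_every_n_epoch - 1 == 3  # current_epoch
```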
From 707f1adab19043a59fb9c5a72af1d98bc4e0345e Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 3 Sep 2021 15:55:22 +0200
Subject: [PATCH 3/4] Add test for 9163

---
 tests/checkpointing/test_model_checkpoint.py | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index f49fa16598fd2..d8276fc469cd4 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -1248,3 +1248,21 @@ def test_trainer_checkpoint_callback_bool(tmpdir):
     mc = ModelCheckpoint(dirpath=tmpdir)
     with pytest.raises(MisconfigurationException, match="Invalid type provided for checkpoint_callback"):
         Trainer(checkpoint_callback=mc)
+
+
+def test_check_val_every_n_epochs_top_k_integration(tmpdir):
+    model = BoringModel()
+    mc = ModelCheckpoint(dirpath=tmpdir, monitor="epoch", save_top_k=5, filename="{epoch}")
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        num_sanity_val_steps=0,
+        max_epochs=5,
+        check_val_every_n_epoch=2,
+        callbacks=mc,
+        weights_summary=None,
+        logger=False,
+    )
+    trainer.fit(model)
+    assert set(os.listdir(tmpdir)) == {"epoch=1.ckpt", "epoch=3.ckpt"}

From 23dbc5a109aa3248000185bf15366e8b6572f9ee Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Fri, 3 Sep 2021 15:58:01 +0200
Subject: [PATCH 4/4] topk -1

---
 tests/checkpointing/test_model_checkpoint.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index d8276fc469cd4..0e19202d33d9b 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -1252,7 +1252,7 @@ def test_trainer_checkpoint_callback_bool(tmpdir):
 
 def test_check_val_every_n_epochs_top_k_integration(tmpdir):
     model = BoringModel()
-    mc = ModelCheckpoint(dirpath=tmpdir, monitor="epoch", save_top_k=5, filename="{epoch}")
+    mc = ModelCheckpoint(dirpath=tmpdir, monitor="epoch", save_top_k=-1, filename="{epoch}")
     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_train_batches=1,
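End to end, the series changes the default users see when validation is infrequent: with `check_val_every_n_epoch > 1`, `EarlyStopping` and `ModelCheckpoint` now run after each validation pass instead of at every train epoch end, where the monitored value would be stale on epochs that skip validation. A hedged usage sketch, assuming a Lightning version that includes these patches:

```python
# With check_val_every_n_epoch=2, the callbacks below now default to acting
# only after the validation runs (epochs 1, 3, 5, ... when 0-indexed).
# Users could already force this behavior explicitly via the callbacks'
# check_on_train_epoch_end / save_on_train_epoch_end constructor arguments.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

trainer = Trainer(
    max_epochs=20,
    check_val_every_n_epoch=2,  # validation only on every other epoch
    callbacks=[EarlyStopping(monitor="val_loss", patience=3)],
)
```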