From 586d2d44482e2e13e559b205a216a665a23ee91b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 30 Jan 2023 19:35:43 +0100 Subject: [PATCH 1/3] Run `on_train_epoch_end` after the LM for callbacks that monitor --- src/pytorch_lightning/loops/fit_loop.py | 6 +++++- src/pytorch_lightning/trainer/trainer.py | 9 ++++++++- tests/tests_pytorch/models/test_hooks.py | 10 ++++++---- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py index 695871e0c9f77..572a61b824e58 100644 --- a/src/pytorch_lightning/loops/fit_loop.py +++ b/src/pytorch_lightning/loops/fit_loop.py @@ -302,8 +302,12 @@ def on_advance_end(self) -> None: self.epoch_progress.increment_processed() # call train epoch end hooks - self.trainer._call_callback_hooks("on_train_epoch_end") + # we always call callback hooks first, but here we need to make an exception for the callbacks that + # monitor a metric, otherwise they wouldn't be able to monitor a key logged in + # `LightningModule.on_train_epoch_end` + self.trainer._call_callback_hooks("on_train_epoch_end", monitoring_callbacks=False) self.trainer._call_lightning_module_hook("on_train_epoch_end") + self.trainer._call_callback_hooks("on_train_epoch_end", monitoring_callbacks=True) self.trainer._logger_connector.on_epoch_end() diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index ba1720910afe3..5d94a39c9b761 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -1173,6 +1173,7 @@ def _call_callback_hooks( self, hook_name: str, *args: Any, + monitoring_callbacks: Optional[bool] = None, **kwargs: Any, ) -> None: log.debug(f"{self.__class__.__name__}: calling callback hook: {hook_name}") @@ -1182,7 +1183,13 @@ def _call_callback_hooks( prev_fx_name = pl_module._current_fx_name pl_module._current_fx_name = hook_name - for callback in self.callbacks: + callbacks = self.callbacks + if monitoring_callbacks is True: + callbacks = [cb for cb in callbacks if isinstance(cb, (EarlyStopping, Checkpoint))] + elif monitoring_callbacks is False: + callbacks = [cb for cb in callbacks if not isinstance(cb, (EarlyStopping, Checkpoint))] + + for callback in callbacks: fn = getattr(callback, hook_name) if callable(fn): with self.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"): diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index 16b2a04a95817..74f7fb15c6116 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -548,11 +548,11 @@ def training_step(self, batch, batch_idx): dict(name="on_validation_model_train"), dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)), dict(name="Callback.on_train_epoch_end", args=(trainer, model)), - # `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_train_epoch_end` + dict(name="on_train_epoch_end"), # before ModelCheckpoint because it's a "monitoring callback" + # `ModelCheckpoint.save_checkpoint` is called here dict(name="Callback.state_dict"), dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)), dict(name="on_save_checkpoint", args=(saved_ckpt,)), - dict(name="on_train_epoch_end"), dict(name="Callback.on_train_end", args=(trainer, model)), dict(name="on_train_end"), dict(name="Callback.on_fit_end", args=(trainer, model)), @@ -627,10 +627,11 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir): *model._train_batch(trainer, model, 2, current_epoch=1, current_batch=0), dict(name="training_epoch_end", args=([dict(loss=ANY)] * 2,)), dict(name="Callback.on_train_epoch_end", args=(trainer, model)), + dict(name="on_train_epoch_end"), # before ModelCheckpoint because it's a "monitoring callback" + # `ModelCheckpoint.save_checkpoint` is called here dict(name="Callback.state_dict"), dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)), dict(name="on_save_checkpoint", args=(saved_ckpt,)), - dict(name="on_train_epoch_end"), dict(name="Callback.on_train_end", args=(trainer, model)), dict(name="on_train_end"), dict(name="Callback.on_fit_end", args=(trainer, model)), @@ -706,10 +707,11 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir): *model._train_batch(trainer, model, steps_after_reload, current_batch=1), dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)), dict(name="Callback.on_train_epoch_end", args=(trainer, model)), + dict(name="on_train_epoch_end"), # before ModelCheckpoint because it's a "monitoring callback" + # `ModelCheckpoint.save_checkpoint` is called here dict(name="Callback.state_dict"), dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)), dict(name="on_save_checkpoint", args=(saved_ckpt,)), - dict(name="on_train_epoch_end"), dict(name="Callback.on_train_end", args=(trainer, model)), dict(name="on_train_end"), dict(name="Callback.on_fit_end", args=(trainer, model)), From 10a4af119e9c3fa54b8fec9062df7d633b06e95b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 30 Jan 2023 19:57:01 +0100 Subject: [PATCH 2/3] Update early stopping tests --- .../tests_pytorch/callbacks/test_early_stopping.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index 7663a53212427..0c1f7819d958a 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -134,7 +134,7 @@ def test_early_stopping_patience(tmpdir, loss_values: list, patience: int, expec class ModelOverrideValidationReturn(BoringModel): validation_return_values = torch.tensor(loss_values) - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): loss = self.validation_return_values[self.current_epoch] self.log("test_val_loss", loss) @@ -164,7 +164,7 @@ def test_early_stopping_patience_train( class ModelOverrideTrainReturn(BoringModel): train_return_values = torch.tensor(loss_values) - def training_epoch_end(self, outputs): + def on_train_epoch_end(self): loss = self.train_return_values[self.current_epoch] self.log("train_loss", loss) @@ -187,7 +187,7 @@ def training_epoch_end(self, outputs): assert trainer.current_epoch - 1 == expected_stop_epoch -def test_pickling(tmpdir): +def test_pickling(): early_stopping = EarlyStopping(monitor="foo") early_stopping_pickled = pickle.dumps(early_stopping) @@ -226,7 +226,7 @@ def test_early_stopping_no_val_step(tmpdir): ) def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_threshold, losses, expected_epoch): class CurrentModel(BoringModel): - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): val_loss = losses[self.current_epoch] self.log("abc", val_loss) @@ -252,7 +252,7 @@ def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value): expected_stop_epoch = 2 class CurrentModel(BoringModel): - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): val_loss = losses[self.current_epoch] self.log("val_loss", val_loss) @@ -352,12 +352,12 @@ def _epoch_end(self) -> None: self.log("abc", torch.tensor(loss)) self.log("cba", torch.tensor(0)) - def training_epoch_end(self, outputs): + def on_train_epoch_end(self): if not self.early_stop_on_train: return self._epoch_end() - def validation_epoch_end(self, outputs): + def on_validation_epoch_end(self): if self.early_stop_on_train: return self._epoch_end() From dc604d7f582779ad9ad4c1e8a95fc42eb426801e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 30 Jan 2023 20:00:12 +0100 Subject: [PATCH 3/3] CHANGELOG --- src/pytorch_lightning/CHANGELOG.md | 1 + src/pytorch_lightning/trainer/trainer.py | 1 + 2 files changed, 2 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index ec69593ce2f09..e5c9b1f829ce1 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -37,6 +37,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Renamed the `pl.utilities.exceptions.GracefulExitException` to `SIGTERMException` ([#16501](https://github.com/Lightning-AI/lightning/pull/16501)) +- The `Callback.on_train_epoch_end` hook now runs after the `LightningModule.on_train_epoch_end` hook for instances of `EarlyStopping` and `Checkpoint` callbacks ([#16567](https://github.com/Lightning-AI/lightning/pull/16567)) - The `LightningModule.{un}toggle_optimizer` methods no longer accept a `optimizer_idx` argument to select the relevant optimizer. Instead, the optimizer object can be passed in directly ([#16560](https://github.com/Lightning-AI/lightning/pull/16560)) diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 5d94a39c9b761..233e645325a9e 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -1185,6 +1185,7 @@ def _call_callback_hooks( callbacks = self.callbacks if monitoring_callbacks is True: + # the list of "monitoring callbacks" is hard-coded to these two. we could add an API to define this callbacks = [cb for cb in callbacks if isinstance(cb, (EarlyStopping, Checkpoint))] elif monitoring_callbacks is False: callbacks = [cb for cb in callbacks if not isinstance(cb, (EarlyStopping, Checkpoint))]