diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index ec69593ce2f09..e5c9b1f829ce1 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -37,6 +37,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Renamed the `pl.utilities.exceptions.GracefulExitException` to `SIGTERMException` ([#16501](https://github.com/Lightning-AI/lightning/pull/16501))
 
+- The `Callback.on_train_epoch_end` hook now runs after the `LightningModule.on_train_epoch_end` hook for instances of `EarlyStopping` and `Checkpoint` callbacks ([#16567](https://github.com/Lightning-AI/lightning/pull/16567))
+
 - The `LightningModule.{un}toggle_optimizer` methods no longer accept a `optimizer_idx` argument to select the relevant optimizer. Instead, the optimizer object can be passed in directly ([#16560](https://github.com/Lightning-AI/lightning/pull/16560))
diff --git a/src/pytorch_lightning/loops/fit_loop.py b/src/pytorch_lightning/loops/fit_loop.py
index 695871e0c9f77..572a61b824e58 100644
--- a/src/pytorch_lightning/loops/fit_loop.py
+++ b/src/pytorch_lightning/loops/fit_loop.py
@@ -302,8 +302,12 @@ def on_advance_end(self) -> None:
         self.epoch_progress.increment_processed()
 
         # call train epoch end hooks
-        self.trainer._call_callback_hooks("on_train_epoch_end")
+        # we always call callback hooks first, but here we need to make an exception for the callbacks that
+        # monitor a metric, otherwise they wouldn't be able to monitor a key logged in
+        # `LightningModule.on_train_epoch_end`
+        self.trainer._call_callback_hooks("on_train_epoch_end", monitoring_callbacks=False)
         self.trainer._call_lightning_module_hook("on_train_epoch_end")
+        self.trainer._call_callback_hooks("on_train_epoch_end", monitoring_callbacks=True)
 
         self.trainer._logger_connector.on_epoch_end()
diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py
index ba1720910afe3..233e645325a9e 100644
--- a/src/pytorch_lightning/trainer/trainer.py
+++ b/src/pytorch_lightning/trainer/trainer.py
@@ -1173,6 +1173,7 @@ def _call_callback_hooks(
         self,
         hook_name: str,
         *args: Any,
+        monitoring_callbacks: Optional[bool] = None,
         **kwargs: Any,
     ) -> None:
         log.debug(f"{self.__class__.__name__}: calling callback hook: {hook_name}")
@@ -1182,7 +1183,14 @@ def _call_callback_hooks(
             prev_fx_name = pl_module._current_fx_name
             pl_module._current_fx_name = hook_name
 
-        for callback in self.callbacks:
+        callbacks = self.callbacks
+        if monitoring_callbacks is True:
+            # the list of "monitoring callbacks" is hard-coded to these two. we could add an API to define this
+            callbacks = [cb for cb in callbacks if isinstance(cb, (EarlyStopping, Checkpoint))]
+        elif monitoring_callbacks is False:
+            callbacks = [cb for cb in callbacks if not isinstance(cb, (EarlyStopping, Checkpoint))]
+
+        for callback in callbacks:
             fn = getattr(callback, hook_name)
             if callable(fn):
                 with self.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
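The fit-loop and trainer changes above split the `on_train_epoch_end` callback pass in two, so that `EarlyStopping` and `Checkpoint` instances run after the `LightningModule` hook. A minimal usage sketch of what this enables (not part of the diff; assumes a Lightning version that includes this change, and the metric name `train_metric` is made up):

```python
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.demos.boring_classes import BoringModel


class Model(BoringModel):
    def on_train_epoch_end(self):
        # logged here, yet still visible to EarlyStopping: its hook now runs
        # after this one instead of before it
        self.log("train_metric", torch.tensor(1.0))


trainer = Trainer(
    max_epochs=3,
    limit_train_batches=2,
    limit_val_batches=0,
    callbacks=[EarlyStopping(monitor="train_metric", check_on_train_epoch_end=True)],
)
trainer.fit(Model())
```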
diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py
index 7663a53212427..0c1f7819d958a 100644
--- a/tests/tests_pytorch/callbacks/test_early_stopping.py
+++ b/tests/tests_pytorch/callbacks/test_early_stopping.py
@@ -134,7 +134,7 @@ def test_early_stopping_patience(tmpdir, loss_values: list, patience: int, expec
     class ModelOverrideValidationReturn(BoringModel):
         validation_return_values = torch.tensor(loss_values)
 
-        def validation_epoch_end(self, outputs):
+        def on_validation_epoch_end(self):
             loss = self.validation_return_values[self.current_epoch]
             self.log("test_val_loss", loss)
 
@@ -164,7 +164,7 @@ def test_early_stopping_patience_train(
     class ModelOverrideTrainReturn(BoringModel):
         train_return_values = torch.tensor(loss_values)
 
-        def training_epoch_end(self, outputs):
+        def on_train_epoch_end(self):
             loss = self.train_return_values[self.current_epoch]
             self.log("train_loss", loss)
 
@@ -187,7 +187,7 @@ def training_epoch_end(self, outputs):
     assert trainer.current_epoch - 1 == expected_stop_epoch
 
 
-def test_pickling(tmpdir):
+def test_pickling():
     early_stopping = EarlyStopping(monitor="foo")
 
     early_stopping_pickled = pickle.dumps(early_stopping)
@@ -226,7 +226,7 @@ def test_early_stopping_no_val_step(tmpdir):
 )
 def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_threshold, losses, expected_epoch):
     class CurrentModel(BoringModel):
-        def validation_epoch_end(self, outputs):
+        def on_validation_epoch_end(self):
             val_loss = losses[self.current_epoch]
             self.log("abc", val_loss)
 
@@ -252,7 +252,7 @@ def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value):
     expected_stop_epoch = 2
 
     class CurrentModel(BoringModel):
-        def validation_epoch_end(self, outputs):
+        def on_validation_epoch_end(self):
             val_loss = losses[self.current_epoch]
             self.log("val_loss", val_loss)
 
@@ -352,12 +352,12 @@ def _epoch_end(self) -> None:
         self.log("abc", torch.tensor(loss))
         self.log("cba", torch.tensor(0))
 
-    def training_epoch_end(self, outputs):
+    def on_train_epoch_end(self):
         if not self.early_stop_on_train:
             return
         self._epoch_end()
 
-    def validation_epoch_end(self, outputs):
+    def on_validation_epoch_end(self):
         if self.early_stop_on_train:
             return
         self._epoch_end()
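The `test_early_stopping.py` updates above also migrate from the removed `training_epoch_end`/`validation_epoch_end` hooks, which received the collected step `outputs`, to `on_train_epoch_end`/`on_validation_epoch_end`, which take no arguments. A sketch of the general migration pattern (illustrative; `MigratedModel` and the manual buffering are not from this diff):

```python
import torch
from pytorch_lightning.demos.boring_classes import BoringModel


class MigratedModel(BoringModel):
    def __init__(self):
        super().__init__()
        self._val_losses = []

    def validation_step(self, batch, batch_idx):
        # the epoch-end hook no longer receives `outputs`, so buffer them manually
        loss = self(batch).sum()
        self._val_losses.append(loss)

    def on_validation_epoch_end(self):  # was: validation_epoch_end(self, outputs)
        self.log("val_loss", torch.stack(self._val_losses).mean())
        self._val_losses.clear()
```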
dict(name="on_train_epoch_end"), dict(name="Callback.on_train_end", args=(trainer, model)), dict(name="on_train_end"), dict(name="Callback.on_fit_end", args=(trainer, model)), @@ -627,10 +627,11 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir): *model._train_batch(trainer, model, 2, current_epoch=1, current_batch=0), dict(name="training_epoch_end", args=([dict(loss=ANY)] * 2,)), dict(name="Callback.on_train_epoch_end", args=(trainer, model)), + dict(name="on_train_epoch_end"), # before ModelCheckpoint because it's a "monitoring callback" + # `ModelCheckpoint.save_checkpoint` is called here dict(name="Callback.state_dict"), dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)), dict(name="on_save_checkpoint", args=(saved_ckpt,)), - dict(name="on_train_epoch_end"), dict(name="Callback.on_train_end", args=(trainer, model)), dict(name="on_train_end"), dict(name="Callback.on_fit_end", args=(trainer, model)), @@ -706,10 +707,11 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir): *model._train_batch(trainer, model, steps_after_reload, current_batch=1), dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)), dict(name="Callback.on_train_epoch_end", args=(trainer, model)), + dict(name="on_train_epoch_end"), # before ModelCheckpoint because it's a "monitoring callback" + # `ModelCheckpoint.save_checkpoint` is called here dict(name="Callback.state_dict"), dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)), dict(name="on_save_checkpoint", args=(saved_ckpt,)), - dict(name="on_train_epoch_end"), dict(name="Callback.on_train_end", args=(trainer, model)), dict(name="on_train_end"), dict(name="Callback.on_fit_end", args=(trainer, model)),