1 change: 1 addition & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -37,6 +37,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Renamed the `pl.utilities.exceptions.GracefulExitException` to `SIGTERMException` ([#16501](https://github.com/Lightning-AI/lightning/pull/16501))

+ - The `Callback.on_train_epoch_end` hook now runs after the `LightningModule.on_train_epoch_end` hook for instances of `EarlyStopping` and `Checkpoint` callbacks ([#16567](https://github.com/Lightning-AI/lightning/pull/16567))

- The `LightningModule.{un}toggle_optimizer` methods no longer accept a `optimizer_idx` argument to select the relevant optimizer. Instead, the optimizer object can be passed in directly ([#16560](https://github.com/Lightning-AI/lightning/pull/16560))

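For illustration only (this snippet is not part of the diff): a minimal sketch of the pattern the change enables, assuming a pytorch_lightning build that includes this PR; `LitModel` and `train_epoch_metric` are invented names.

```python
# Hypothetical sketch: a key logged in LightningModule.on_train_epoch_end is now visible to
# EarlyStopping / Checkpoint callbacks, because their on_train_epoch_end hooks run afterwards.
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.demos.boring_classes import BoringModel


class LitModel(BoringModel):
    def on_train_epoch_end(self):
        # Before this change, EarlyStopping's hook ran first and could not see this key.
        self.log("train_epoch_metric", torch.rand(()).item())


trainer = pl.Trainer(
    max_epochs=3,
    callbacks=[EarlyStopping(monitor="train_epoch_metric", check_on_train_epoch_end=True)],
)
trainer.fit(LitModel())
```
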
6 changes: 5 additions & 1 deletion src/pytorch_lightning/loops/fit_loop.py
@@ -302,8 +302,12 @@ def on_advance_end(self) -> None:
self.epoch_progress.increment_processed()

# call train epoch end hooks
- self.trainer._call_callback_hooks("on_train_epoch_end")
+ # we always call callback hooks first, but here we need to make an exception for the callbacks that
+ # monitor a metric, otherwise they wouldn't be able to monitor a key logged in
+ # `LightningModule.on_train_epoch_end`
+ self.trainer._call_callback_hooks("on_train_epoch_end", monitoring_callbacks=False)
self.trainer._call_lightning_module_hook("on_train_epoch_end")
+ self.trainer._call_callback_hooks("on_train_epoch_end", monitoring_callbacks=True)

self.trainer._logger_connector.on_epoch_end()

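The hunk above changes the dispatch order at the end of every training epoch: callbacks that do not monitor metrics run first, then the LightningModule hook, then the monitoring callbacks. A self-contained toy sketch of that order (all names below are invented; none of this is Lightning code):

```python
# Toy re-enactment of the split "on_train_epoch_end" dispatch used in FitLoop.on_advance_end.
class ProgressCallback:  # stands in for any non-monitoring callback
    def on_train_epoch_end(self, logged):
        print("non-monitoring callback sees:", logged)


class MonitorCallback:  # stands in for EarlyStopping / Checkpoint
    def on_train_epoch_end(self, logged):
        print("monitoring callback sees:", logged)


class Module:  # stands in for the LightningModule
    def on_train_epoch_end(self, logged):
        logged["train_epoch_metric"] = 0.42  # key logged by the module hook


def epoch_end(callbacks, module):
    logged = {}
    # 1) callbacks that do not monitor metrics (monitoring_callbacks=False)
    for cb in callbacks:
        if not isinstance(cb, MonitorCallback):
            cb.on_train_epoch_end(logged)
    # 2) the LightningModule hook, which may log new keys
    module.on_train_epoch_end(logged)
    # 3) monitoring callbacks, which can now read those keys (monitoring_callbacks=True)
    for cb in callbacks:
        if isinstance(cb, MonitorCallback):
            cb.on_train_epoch_end(logged)


epoch_end([ProgressCallback(), MonitorCallback()], Module())
```
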
10 changes: 9 additions & 1 deletion src/pytorch_lightning/trainer/trainer.py
@@ -1173,6 +1173,7 @@ def _call_callback_hooks(
self,
hook_name: str,
*args: Any,
+ monitoring_callbacks: Optional[bool] = None,
**kwargs: Any,
) -> None:
log.debug(f"{self.__class__.__name__}: calling callback hook: {hook_name}")
@@ -1182,7 +1183,14 @@
prev_fx_name = pl_module._current_fx_name
pl_module._current_fx_name = hook_name

- for callback in self.callbacks:
+ callbacks = self.callbacks
+ if monitoring_callbacks is True:
+     # the list of "monitoring callbacks" is hard-coded to these two. we could add an API to define this
+     callbacks = [cb for cb in callbacks if isinstance(cb, (EarlyStopping, Checkpoint))]
+ elif monitoring_callbacks is False:
+     callbacks = [cb for cb in callbacks if not isinstance(cb, (EarlyStopping, Checkpoint))]
+
+ for callback in callbacks:
fn = getattr(callback, hook_name)
if callable(fn):
with self.profiler.profile(f"[Callback]{callback.state_key}.{hook_name}"):
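A hedged sketch of the selection logic above in isolation; the helper name `_select_callbacks` is invented, and the real code filters inline inside `Trainer._call_callback_hooks` as shown:

```python
# Illustrative partition of callbacks into "monitoring" and "non-monitoring" groups.
from typing import List, Optional

from pytorch_lightning.callbacks import Callback, Checkpoint, EarlyStopping


def _select_callbacks(callbacks: List[Callback], monitoring_callbacks: Optional[bool]) -> List[Callback]:
    # The monitoring callbacks are hard-coded to EarlyStopping and Checkpoint subclasses
    # (e.g. ModelCheckpoint); None means no filtering, matching the parameter's default.
    monitoring_types = (EarlyStopping, Checkpoint)
    if monitoring_callbacks is True:
        return [cb for cb in callbacks if isinstance(cb, monitoring_types)]
    if monitoring_callbacks is False:
        return [cb for cb in callbacks if not isinstance(cb, monitoring_types)]
    return list(callbacks)
```
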
14 changes: 7 additions & 7 deletions tests/tests_pytorch/callbacks/test_early_stopping.py
@@ -134,7 +134,7 @@ def test_early_stopping_patience(tmpdir, loss_values: list, patience: int, expec
class ModelOverrideValidationReturn(BoringModel):
validation_return_values = torch.tensor(loss_values)

- def validation_epoch_end(self, outputs):
+ def on_validation_epoch_end(self):
loss = self.validation_return_values[self.current_epoch]
self.log("test_val_loss", loss)

@@ -164,7 +164,7 @@ def test_early_stopping_patience_train(
class ModelOverrideTrainReturn(BoringModel):
train_return_values = torch.tensor(loss_values)

- def training_epoch_end(self, outputs):
+ def on_train_epoch_end(self):
loss = self.train_return_values[self.current_epoch]
self.log("train_loss", loss)

@@ -187,7 +187,7 @@ def training_epoch_end(self, outputs):
assert trainer.current_epoch - 1 == expected_stop_epoch


- def test_pickling(tmpdir):
+ def test_pickling():
early_stopping = EarlyStopping(monitor="foo")

early_stopping_pickled = pickle.dumps(early_stopping)
@@ -226,7 +226,7 @@ def test_early_stopping_no_val_step(tmpdir):
)
def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_threshold, losses, expected_epoch):
class CurrentModel(BoringModel):
- def validation_epoch_end(self, outputs):
+ def on_validation_epoch_end(self):
val_loss = losses[self.current_epoch]
self.log("abc", val_loss)

@@ -252,7 +252,7 @@ def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value):
expected_stop_epoch = 2

class CurrentModel(BoringModel):
- def validation_epoch_end(self, outputs):
+ def on_validation_epoch_end(self):
val_loss = losses[self.current_epoch]
self.log("val_loss", val_loss)

@@ -352,12 +352,12 @@ def _epoch_end(self) -> None:
self.log("abc", torch.tensor(loss))
self.log("cba", torch.tensor(0))

- def training_epoch_end(self, outputs):
+ def on_train_epoch_end(self):
if not self.early_stop_on_train:
return
self._epoch_end()

- def validation_epoch_end(self, outputs):
+ def on_validation_epoch_end(self):
if self.early_stop_on_train:
return
self._epoch_end()
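The test updates above track the replacement of the `*_epoch_end(self, outputs)` hooks with `on_*_epoch_end(self)`, which receive no collected outputs. A hedged sketch of that migration in a user module (the class name and placeholder loss are illustrative):

```python
# Illustrative migration: aggregate per-step values yourself instead of receiving `outputs`.
import torch
import pytorch_lightning as pl


class MigratedModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self._val_losses = []  # replaces the `outputs` list the old hook received

    def validation_step(self, batch, batch_idx):
        loss = torch.tensor(0.0)  # placeholder loss for the sketch
        self._val_losses.append(loss)
        return loss

    # was: def validation_epoch_end(self, outputs): ...
    def on_validation_epoch_end(self):
        self.log("val_loss", torch.stack(self._val_losses).mean())
        self._val_losses.clear()
```
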
10 changes: 6 additions & 4 deletions tests/tests_pytorch/models/test_hooks.py
@@ -548,11 +548,11 @@ def training_step(self, batch, batch_idx):
dict(name="on_validation_model_train"),
dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
- # `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_train_epoch_end`
+ dict(name="on_train_epoch_end"), # before ModelCheckpoint because it's a "monitoring callback"
+ # `ModelCheckpoint.save_checkpoint` is called here
dict(name="Callback.state_dict"),
dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
dict(name="on_save_checkpoint", args=(saved_ckpt,)),
dict(name="on_train_epoch_end"),
dict(name="Callback.on_train_end", args=(trainer, model)),
dict(name="on_train_end"),
dict(name="Callback.on_fit_end", args=(trainer, model)),
@@ -627,10 +627,11 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
*model._train_batch(trainer, model, 2, current_epoch=1, current_batch=0),
dict(name="training_epoch_end", args=([dict(loss=ANY)] * 2,)),
dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
dict(name="on_train_epoch_end"), # before ModelCheckpoint because it's a "monitoring callback"
# `ModelCheckpoint.save_checkpoint` is called here
dict(name="Callback.state_dict"),
dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
dict(name="on_save_checkpoint", args=(saved_ckpt,)),
dict(name="on_train_epoch_end"),
dict(name="Callback.on_train_end", args=(trainer, model)),
dict(name="on_train_end"),
dict(name="Callback.on_fit_end", args=(trainer, model)),
@@ -706,10 +707,11 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir):
*model._train_batch(trainer, model, steps_after_reload, current_batch=1),
dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
dict(name="on_train_epoch_end"), # before ModelCheckpoint because it's a "monitoring callback"
# `ModelCheckpoint.save_checkpoint` is called here
dict(name="Callback.state_dict"),
dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
dict(name="on_save_checkpoint", args=(saved_ckpt,)),
dict(name="on_train_epoch_end"),
dict(name="Callback.on_train_end", args=(trainer, model)),
dict(name="on_train_end"),
dict(name="Callback.on_fit_end", args=(trainer, model)),
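In the expected-call lists above, the module-level `on_train_epoch_end` entry now precedes `Callback.state_dict` / `Callback.on_save_checkpoint` because `ModelCheckpoint`, being a monitoring callback, saves only after the module hook has run. A minimal sketch of how such an order could be recorded (the real tests use richer hook-recording helpers):

```python
# Hedged sketch: record the order of a couple of hooks with a simple tracking callback.
from pytorch_lightning.callbacks import Callback


class CallOrderTracker(Callback):
    def __init__(self):
        self.calls = []

    def on_train_epoch_end(self, trainer, pl_module):
        self.calls.append("Callback.on_train_epoch_end")

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        self.calls.append("Callback.on_save_checkpoint")
```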