From c40c5dd7f3d44afd276f40c9544f0bdc00c3ea87 Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 5 Jul 2021 12:38:38 +0200
Subject: [PATCH 1/3] Change default `EarlyStopping.check_on_train_epoch_end` to true

---
 pytorch_lightning/callbacks/early_stopping.py | 10 +++++-----
 tests/callbacks/test_early_stopping.py        |  3 ++-
 2 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index b6bff43fd6317..1a7ab3da40c29 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -97,7 +97,7 @@ def __init__(
         check_finite: bool = True,
         stopping_threshold: Optional[float] = None,
         divergence_threshold: Optional[float] = None,
-        check_on_train_epoch_end: bool = False,
+        check_on_train_epoch_end: bool = True,
     ):
         super().__init__()
         self.min_delta = min_delta
@@ -149,7 +149,8 @@ def _validate_condition_metric(self, logs):
     def monitor_op(self) -> Callable:
         return self.mode_dict[self.mode]
 
-    def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
+    def on_save_checkpoint(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule',
+                           checkpoint: Dict[str, Any]) -> Dict[str, Any]:
         return {
             'wait_count': self.wait_count,
             'stopped_epoch': self.stopped_epoch,
@@ -167,7 +168,7 @@ def _should_skip_check(self, trainer) -> bool:
         from pytorch_lightning.trainer.states import TrainerFn
         return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking
 
-    def on_train_epoch_end(self, trainer, pl_module) -> None:
+    def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
         if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
             return
         self._run_early_stopping_check(trainer)
@@ -175,10 +176,9 @@ def on_train_epoch_end(self, trainer, pl_module) -> None:
     def on_validation_end(self, trainer, pl_module) -> None:
         if self._check_on_train_epoch_end or self._should_skip_check(trainer):
             return
-
         self._run_early_stopping_check(trainer)
 
-    def _run_early_stopping_check(self, trainer) -> None:
+    def _run_early_stopping_check(self, trainer: 'pl.Trainer') -> None:
         """
         Checks whether the early stopping condition is met
         and if so tells the trainer to stop the training.
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index d7a6f15459912..56252208a6f13 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -69,12 +69,13 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     )
     trainer.fit(model, datamodule=dm)
 
+    assert len(early_stop_callback.saved_states) == 4
+
     checkpoint_filepath = checkpoint_callback.kth_best_model_path
     # ensure state is persisted properly
     checkpoint = torch.load(checkpoint_filepath)
     # the checkpoint saves "epoch + 1"
     early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
-    assert 4 == len(early_stop_callback.saved_states)
     assert checkpoint["callbacks"][type(early_stop_callback)] == early_stop_callback_state
 
     # ensure state is reloaded properly (assertion in the callback)

From 643f6cf0e9a287d03910c016e81fcec3b230698e Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 5 Jul 2021 12:40:32 +0200
Subject: [PATCH 2/3] Magic comma

---
 pytorch_lightning/callbacks/early_stopping.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 1a7ab3da40c29..98675866b093b 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -149,8 +149,12 @@ def _validate_condition_metric(self, logs):
     def monitor_op(self) -> Callable:
         return self.mode_dict[self.mode]
 
-    def on_save_checkpoint(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule',
-                           checkpoint: Dict[str, Any]) -> Dict[str, Any]:
+    def on_save_checkpoint(
+        self,
+        trainer: 'pl.Trainer',
+        pl_module: 'pl.LightningModule',
+        checkpoint: Dict[str, Any],
+    ) -> Dict[str, Any]:
         return {
             'wait_count': self.wait_count,
             'stopped_epoch': self.stopped_epoch,

From 1d2170dd0e75560082c600d71a203723ebf3612d Mon Sep 17 00:00:00 2001
From: Carlos Mocholi
Date: Mon, 5 Jul 2021 12:46:06 +0200
Subject: [PATCH 3/3] Update CHANGELOG

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 26c863be63d83..8a2bf76dd0df7 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -154,6 +154,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Validation is now always run inside the training epoch scope ([#7357](https://github.com/PyTorchLightning/pytorch-lightning/pull/7357))
 
 
+- `EarlyStopping` now runs at the end of the training epoch by default ([#8286](https://github.com/PyTorchLightning/pytorch-lightning/pull/8286))
+
+
 - Refactored Loops
   * Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7437))
   * Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
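
For context, a minimal sketch of how the changed default affects callers, assuming a `LightningModule` subclass named `LitModel` that logs a `"val_loss"` metric during validation; the model name, metric name, and hyperparameters below are illustrative and not part of the patch series.

```python
# Minimal sketch, assuming `LitModel` logs "val_loss"; names here are illustrative.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# With the new default (check_on_train_epoch_end=True), the stopping condition is
# evaluated in `on_train_epoch_end`, so the monitored metric must be available by
# the end of the training epoch.
early_stopping = EarlyStopping(monitor="val_loss", patience=3)

# To keep the previous behaviour and run the check in `on_validation_end` instead,
# pass the flag explicitly:
early_stopping = EarlyStopping(monitor="val_loss", patience=3, check_on_train_epoch_end=False)

trainer = Trainer(callbacks=[early_stopping], max_epochs=20)
# trainer.fit(LitModel())  # `LitModel` is assumed to be defined elsewhere
```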