CHANGELOG.md (3 additions, 0 deletions)
@@ -154,6 +154,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Validation is now always run inside the training epoch scope ([#7357](https://github.com/PyTorchLightning/pytorch-lightning/pull/7357))
 
 
+- `EarlyStopping` now runs at the end of the training epoch by default ([#8286](https://github.com/PyTorchLightning/pytorch-lightning/pull/8286))
+
+
 - Refactored Loops
     * Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7437))
     * Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
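Since the default for `check_on_train_epoch_end` flips from `False` to `True` in this PR, callers who want the previous behaviour (running the check only when validation ends) need to opt out explicitly. A minimal sketch, assuming a standard `Trainer` setup; the model, datamodule, and hyperparameters are placeholders:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Opt out of the new default: keep the early-stopping check in
# on_validation_end instead of on_train_epoch_end.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=3,
    check_on_train_epoch_end=False,
)

trainer = Trainer(callbacks=[early_stopping], max_epochs=10)
# trainer.fit(model, datamodule=dm)  # model/dm are placeholders defined elsewhere
```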
pytorch_lightning/callbacks/early_stopping.py (9 additions, 5 deletions)
@@ -97,7 +97,7 @@ def __init__(
         check_finite: bool = True,
         stopping_threshold: Optional[float] = None,
         divergence_threshold: Optional[float] = None,
-        check_on_train_epoch_end: bool = False,
+        check_on_train_epoch_end: bool = True,
     ):
         super().__init__()
         self.min_delta = min_delta
@@ -149,7 +149,12 @@ def _validate_condition_metric(self, logs):
     def monitor_op(self) -> Callable:
         return self.mode_dict[self.mode]
 
-    def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
+    def on_save_checkpoint(
+        self,
+        trainer: 'pl.Trainer',
+        pl_module: 'pl.LightningModule',
+        checkpoint: Dict[str, Any],
+    ) -> Dict[str, Any]:
         return {
             'wait_count': self.wait_count,
             'stopped_epoch': self.stopped_epoch,
@@ -167,18 +172,17 @@ def _should_skip_check(self, trainer) -> bool:
         from pytorch_lightning.trainer.states import TrainerFn
         return trainer.state.fn != TrainerFn.FITTING or trainer.sanity_checking
 
-    def on_train_epoch_end(self, trainer, pl_module) -> None:
+    def on_train_epoch_end(self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule') -> None:
         if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
             return
         self._run_early_stopping_check(trainer)
 
     def on_validation_end(self, trainer, pl_module) -> None:
         if self._check_on_train_epoch_end or self._should_skip_check(trainer):
             return
-
         self._run_early_stopping_check(trainer)
 
-    def _run_early_stopping_check(self, trainer) -> None:
+    def _run_early_stopping_check(self, trainer: 'pl.Trainer') -> None:
         """
         Checks whether the early stopping condition is met
         and if so tells the trainer to stop the training.
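Taken together, the two hooks above are mutually exclusive on the `_check_on_train_epoch_end` flag: exactly one of them ever runs the check. A condensed, self-contained sketch of that routing, with stubbed trainer and check logic (not the real implementation):

```python
class EarlyStoppingSketch:
    """Stripped-down illustration of the hook routing in this diff."""

    def __init__(self, check_on_train_epoch_end: bool = True) -> None:
        # New default: the check fires when the training epoch ends.
        self._check_on_train_epoch_end = check_on_train_epoch_end

    def _should_skip_check(self, trainer) -> bool:
        # Stub for the real test against TrainerFn.FITTING / sanity_checking.
        return getattr(trainer, 'sanity_checking', False)

    def _run_early_stopping_check(self, trainer) -> None:
        print('early stopping check ran')

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        # Default path (check_on_train_epoch_end=True).
        if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
            return
        self._run_early_stopping_check(trainer)

    def on_validation_end(self, trainer, pl_module) -> None:
        # Opt-in path (check_on_train_epoch_end=False); the inverted guard
        # ensures at most one of the two hooks performs the check.
        if self._check_on_train_epoch_end or self._should_skip_check(trainer):
            return
        self._run_early_stopping_check(trainer)
```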
tests/callbacks/test_early_stopping.py (2 additions, 1 deletion)
@@ -69,12 +69,13 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
     )
     trainer.fit(model, datamodule=dm)
 
+    assert len(early_stop_callback.saved_states) == 4
+
     checkpoint_filepath = checkpoint_callback.kth_best_model_path
     # ensure state is persisted properly
     checkpoint = torch.load(checkpoint_filepath)
     # the checkpoint saves "epoch + 1"
     early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
-    assert 4 == len(early_stop_callback.saved_states)
     assert checkpoint["callbacks"][type(early_stop_callback)] == early_stop_callback_state
 
     # ensure state is reloaded properly (assertion in the callback)
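The reordered assertion relies on the trainer persisting whatever `on_save_checkpoint` returns, keyed by the callback's type, which is exactly what `checkpoint["callbacks"][type(early_stop_callback)]` reads back. A minimal sketch of inspecting such a persisted state outside the test; the checkpoint path is hypothetical:

```python
import torch

from pytorch_lightning.callbacks import EarlyStopping

# 'early_stop.ckpt' is a placeholder for a checkpoint written by a Trainer
# with an EarlyStopping callback attached.
checkpoint = torch.load('early_stop.ckpt')

# Callback states are stored keyed by callback type, as the test asserts;
# the dict carries the fields returned by on_save_checkpoint above.
state = checkpoint['callbacks'][EarlyStopping]
print(state['wait_count'], state['stopped_epoch'])
```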