4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added


- Added support for the `EarlyStopping` callback to run at the end of the training epoch ([#6944](https://github.com/PyTorchLightning/pytorch-lightning/pull/6944/))


- Added synchronization points before and after `setup` hooks are run ([#7202](https://github.com/PyTorchLightning/pytorch-lightning/pull/7202))


30 changes: 21 additions & 9 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -19,7 +19,7 @@

"""
import logging
from typing import Any, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Optional, Tuple

import numpy as np
import torch
@@ -39,8 +39,8 @@ class EarlyStopping(Callback):
monitor: quantity to be monitored.
min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute
change of less than `min_delta` will count as no improvement.
patience: number of validation checks with no improvement
after which training will be stopped. Under the default configuration, one validation check happens after
patience: number of checks with no improvement
after which training will be stopped. Under the default configuration, one check happens after
every training epoch. However, the frequency of validation can be modified by setting various parameters on
the ``Trainer``, for example ``check_val_every_n_epoch`` and ``val_check_interval``.

@@ -59,6 +59,8 @@ class EarlyStopping(Callback):
check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.
stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold.
divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.
check_on_train_epoch_end: whether to run early stopping at the end of the training epoch.
If this is ``False``, then the check runs at the end of the validation epoch.

Raises:
MisconfigurationException:
@@ -94,6 +96,7 @@ def __init__(
check_finite: bool = True,
stopping_threshold: Optional[float] = None,
divergence_threshold: Optional[float] = None,
check_on_train_epoch_end: bool = False,
):
super().__init__()
self.monitor = monitor
@@ -107,6 +110,7 @@ def __init__(
self.divergence_threshold = divergence_threshold
self.wait_count = 0
self.stopped_epoch = 0
self._check_on_train_epoch_end = check_on_train_epoch_end

if self.mode not in self.mode_dict:
raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
@@ -135,7 +139,7 @@ def _validate_condition_metric(self, logs):
return True

@property
def monitor_op(self):
def monitor_op(self) -> Callable:
return self.mode_dict[self.mode]

def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
@@ -146,20 +150,28 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) ->
'patience': self.patience
}

def on_load_checkpoint(self, callback_state: Dict[str, Any]):
def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
self.wait_count = callback_state['wait_count']
self.stopped_epoch = callback_state['stopped_epoch']
self.best_score = callback_state['best_score']
self.patience = callback_state['patience']

def on_validation_end(self, trainer, pl_module):
def _should_skip_check(self, trainer) -> bool:
from pytorch_lightning.trainer.states import TrainerState
if trainer.state != TrainerState.FITTING or trainer.sanity_checking:
return trainer.state != TrainerState.FITTING or trainer.sanity_checking

def on_train_epoch_end(self, trainer, pl_module, outputs) -> None:
if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
return
self._run_early_stopping_check(trainer)

def on_validation_end(self, trainer, pl_module) -> None:
if self._check_on_train_epoch_end or self._should_skip_check(trainer):
return

self._run_early_stopping_check(trainer)

def _run_early_stopping_check(self, trainer):
def _run_early_stopping_check(self, trainer) -> None:
"""
Checks whether the early stopping condition is met
and if so tells the trainer to stop the training.
@@ -170,7 +182,7 @@ def _run_early_stopping_check(self, trainer):
trainer.fast_dev_run # disable early_stopping with fast_dev_run
or not self._validate_condition_metric(logs) # short circuit if metric not present
):
return # short circuit if metric not present
return

current = logs.get(self.monitor)

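
A minimal usage sketch of the new flag (hedged: `MyModel` and the logged metric names are placeholders for illustration, not part of this PR):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Default (check_on_train_epoch_end=False): the check runs at the end of
# each validation epoch, via on_validation_end.
early_stop_val = EarlyStopping(monitor='val_loss', patience=3)

# New in this PR: run the check at the end of each training epoch instead,
# via on_train_epoch_end. The monitored key must be logged during training.
early_stop_train = EarlyStopping(monitor='train_loss', patience=3, check_on_train_epoch_end=True)

trainer = Trainer(max_epochs=20, callbacks=[early_stop_train])
trainer.fit(MyModel())  # MyModel: any LightningModule that logs 'train_loss'
```
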
85 changes: 54 additions & 31 deletions tests/callbacks/test_early_stopping.py
@@ -213,11 +213,13 @@ def test_early_stopping_no_val_step(tmpdir):
assert trainer.current_epoch < trainer.max_epochs - 1


@pytest.mark.parametrize("stopping_threshold,divergence_theshold,losses,expected_epoch", [
(None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
(2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
(None, 15.9, [9, 4, 2, 16, 32, 64], 3),
])
@pytest.mark.parametrize(
"stopping_threshold,divergence_theshold,losses,expected_epoch", [
(None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
(2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
(None, 15.9, [9, 4, 2, 16, 32, 64], 3),
]
)
def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_theshold, losses, expected_epoch):

class CurrentModel(BoringModel):
@@ -338,7 +340,7 @@ def validation_epoch_end(self, outputs):
limit_train_batches=limit_train_batches,
limit_val_batches=2,
min_steps=min_steps,
min_epochs=min_epochs
min_epochs=min_epochs,
)
trainer.fit(model)

@@ -359,8 +361,13 @@ def validation_epoch_end(self, outputs):
by_min_epochs = min_epochs * limit_train_batches

# Make sure the trainer stops for the max of all minimum requirements
assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), \
(trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs)
assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), (
trainer.global_step,
max(min_steps, by_early_stopping, by_min_epochs),
step_freeze,
min_steps,
min_epochs,
)

_logger.disabled = False

@@ -372,53 +379,69 @@ def test_early_stopping_mode_options():

class EarlyStoppingModel(BoringModel):

def __init__(self, expected_end_epoch):
def __init__(self, expected_end_epoch: int, early_stop_on_train: bool):
super().__init__()
self.expected_end_epoch = expected_end_epoch
self.early_stop_on_train = early_stop_on_train

def validation_epoch_end(self, outputs):
def _epoch_end(self) -> None:
losses = [8, 4, 2, 3, 4, 5, 8, 10]
val_loss = losses[self.current_epoch]
self.log('abc', torch.tensor(val_loss))

Contributor: Should we log with different keys on val/train for the test?

loss = losses[self.current_epoch]
self.log('abc', torch.tensor(loss))
self.log('cba', torch.tensor(0))

def training_epoch_end(self, outputs):
if not self.early_stop_on_train:
return
self._epoch_end()

def validation_epoch_end(self, outputs):
if self.early_stop_on_train:
return
self._epoch_end()

def on_train_end(self) -> None:
assert self.trainer.current_epoch == self.expected_end_epoch, 'Early Stopping Failed'


_ES_CHECK = dict(check_on_train_epoch_end=True)
_ES_CHECK_P3 = dict(patience=3, check_on_train_epoch_end=True)
_NO_WIN = dict(marks=RunIf(skip_windows=True))


@pytest.mark.parametrize(
"callbacks, expected_stop_epoch, accelerator, num_processes",
"callbacks, expected_stop_epoch, check_on_train_epoch_end, accelerator, num_processes",
[
([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, None, 1),
([EarlyStopping(monitor='cba', patience=3),
EarlyStopping(monitor='abc')], 3, None, 1),
pytest.param([EarlyStopping(monitor='abc'),
EarlyStopping(monitor='cba', patience=3)],
3,
'ddp_cpu',
2,
marks=RunIf(skip_windows=True)),
pytest.param([EarlyStopping(monitor='cba', patience=3),
EarlyStopping(monitor='abc')],
3,
'ddp_cpu',
2,
marks=RunIf(skip_windows=True)),
([EarlyStopping('abc'), EarlyStopping('cba', patience=3)], 3, False, None, 1),
([EarlyStopping('cba', patience=3), EarlyStopping('abc')], 3, False, None, 1),
pytest.param([EarlyStopping('abc'), EarlyStopping('cba', patience=3)], 3, False, 'ddp_cpu', 2, **_NO_WIN),
pytest.param([EarlyStopping('cba', patience=3), EarlyStopping('abc')], 3, False, 'ddp_cpu', 2, **_NO_WIN),
([EarlyStopping('abc', **_ES_CHECK), EarlyStopping('cba', **_ES_CHECK_P3)], 3, True, None, 1),
([EarlyStopping('cba', **_ES_CHECK_P3), EarlyStopping('abc', **_ES_CHECK)], 3, True, None, 1),
pytest.param([EarlyStopping('abc', **_ES_CHECK),
EarlyStopping('cba', **_ES_CHECK_P3)], 3, True, 'ddp_cpu', 2, **_NO_WIN),
pytest.param([EarlyStopping('cba', **_ES_CHECK_P3),
EarlyStopping('abc', **_ES_CHECK)], 3, True, 'ddp_cpu', 2, **_NO_WIN),

Comment on lines +415 to +424 (Collaborator): I mean this @carmocca from #6944 (comment)

],
)
def test_multiple_early_stopping_callbacks(
tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, accelerator: Optional[str], num_processes: int
tmpdir,
callbacks: List[EarlyStopping],
expected_stop_epoch: int,
check_on_train_epoch_end: bool,
accelerator: Optional[str],
num_processes: int,
):
"""Ensure when using multiple early stopping callbacks we stop if any signals we should stop."""

model = EarlyStoppingModel(expected_stop_epoch)
model = EarlyStoppingModel(expected_stop_epoch, check_on_train_epoch_end)

trainer = Trainer(
default_root_dir=tmpdir,
callbacks=callbacks,
overfit_batches=0.20,
max_epochs=20,
accelerator=accelerator,
num_processes=num_processes
num_processes=num_processes,
)
trainer.fit(model)
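
For readability, the `_ES_CHECK`/`_ES_CHECK_P3`/`_NO_WIN` helpers above are keyword shorthands; a spelled-out sketch of one parametrized case (illustrative only):

```python
# Equivalent to:
# ([EarlyStopping('abc', **_ES_CHECK), EarlyStopping('cba', **_ES_CHECK_P3)], 3, True, None, 1)
callbacks = [
    EarlyStopping(monitor='abc', check_on_train_epoch_end=True),
    EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True),
]
expected_stop_epoch, check_on_train_epoch_end, accelerator, num_processes = 3, True, None, 1
```
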
4 changes: 4 additions & 0 deletions tests/tuner/test_lr_finder.py
@@ -300,7 +300,9 @@ def __init__(self, learning_rate=0.1, batch_size=2):

def test_lr_candidates_between_min_and_max(tmpdir):
"""Test that learning rate candidates are between min_lr and max_lr."""

class TestModel(BoringModel):

def __init__(self, learning_rate=0.1):
super().__init__()
self.save_hyperparameters()
@@ -322,7 +324,9 @@ def __init__(self, learning_rate=0.1):

def test_lr_finder_ends_before_num_training(tmpdir):
"""Tests learning rate finder ends before `num_training` steps."""

class TestModel(BoringModel):

def __init__(self, learning_rate=0.1):
super().__init__()
self.save_hyperparameters()