
Commit 947d1cb

ananthsub, carmocca, and Borda authored
[1/2] Add support for early stopping during training epoch end (#6944)
Co-authored-by: Carlos Mocholi <[email protected]>
Co-authored-by: jirka <[email protected]>
1 parent ccd87ca commit 947d1cb

File tree

4 files changed (+83 additions, -40 deletions):
- CHANGELOG.md
- pytorch_lightning/callbacks/early_stopping.py
- tests/callbacks/test_early_stopping.py
- tests/tuner/test_lr_finder.py


CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -9,6 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+
+- Added support for the `EarlyStopping` callback to run at the end of the training epoch ([#6944](https://github.com/PyTorchLightning/pytorch-lightning/pull/6944/))
+
+
 - Added synchronization points before and after `setup` hooks are run ([#7202](https://github.com/PyTorchLightning/pytorch-lightning/pull/7202))
 
 
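For context, a minimal usage sketch of the option this entry introduces (not part of the commit itself; the monitored metric name `train_loss` is an illustrative assumption and must be logged by your own LightningModule):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    # Run the early-stopping check once per training epoch instead of after
    # validation; `train_loss` is a hypothetical metric the model logs itself.
    early_stop = EarlyStopping(monitor='train_loss', patience=3, check_on_train_epoch_end=True)
    trainer = Trainer(callbacks=[early_stop])
    # trainer.fit(model)  # any LightningModule that logs `train_loss`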

pytorch_lightning/callbacks/early_stopping.py

Lines changed: 21 additions & 9 deletions
@@ -19,7 +19,7 @@
 
 """
 import logging
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple
 
 import numpy as np
 import torch
@@ -39,8 +39,8 @@ class EarlyStopping(Callback):
         monitor: quantity to be monitored.
         min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute
             change of less than `min_delta`, will count as no improvement.
-        patience: number of validation checks with no improvement
-            after which training will be stopped. Under the default configuration, one validation check happens after
+        patience: number of checks with no improvement
+            after which training will be stopped. Under the default configuration, one check happens after
             every training epoch. However, the frequency of validation can be modified by setting various parameters on
             the ``Trainer``, for example ``check_val_every_n_epoch`` and ``val_check_interval``.
 
@@ -59,6 +59,8 @@ class EarlyStopping(Callback):
         check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.
         stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold.
         divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.
+        check_on_train_epoch_end: whether to run early stopping at the end of the training epoch.
+            If this is ``False``, then the check runs at the end of the validation epoch.
 
     Raises:
         MisconfigurationException:
@@ -94,6 +96,7 @@ def __init__(
         check_finite: bool = True,
         stopping_threshold: Optional[float] = None,
         divergence_threshold: Optional[float] = None,
+        check_on_train_epoch_end: bool = False,
     ):
         super().__init__()
         self.monitor = monitor
@@ -107,6 +110,7 @@ def __init__(
         self.divergence_threshold = divergence_threshold
         self.wait_count = 0
         self.stopped_epoch = 0
+        self._check_on_train_epoch_end = check_on_train_epoch_end
 
         if self.mode not in self.mode_dict:
             raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
@@ -135,7 +139,7 @@ def _validate_condition_metric(self, logs):
         return True
 
     @property
-    def monitor_op(self):
+    def monitor_op(self) -> Callable:
         return self.mode_dict[self.mode]
 
     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
@@ -146,20 +150,28 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) ->
             'patience': self.patience
         }
 
-    def on_load_checkpoint(self, callback_state: Dict[str, Any]):
+    def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
        self.wait_count = callback_state['wait_count']
        self.stopped_epoch = callback_state['stopped_epoch']
        self.best_score = callback_state['best_score']
        self.patience = callback_state['patience']
 
-    def on_validation_end(self, trainer, pl_module):
+    def _should_skip_check(self, trainer) -> bool:
         from pytorch_lightning.trainer.states import TrainerState
-        if trainer.state != TrainerState.FITTING or trainer.sanity_checking:
+        return trainer.state != TrainerState.FITTING or trainer.sanity_checking
+
+    def on_train_epoch_end(self, trainer, pl_module, outputs) -> None:
+        if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
+            return
+        self._run_early_stopping_check(trainer)
+
+    def on_validation_end(self, trainer, pl_module) -> None:
+        if self._check_on_train_epoch_end or self._should_skip_check(trainer):
             return
 
         self._run_early_stopping_check(trainer)
 
-    def _run_early_stopping_check(self, trainer):
+    def _run_early_stopping_check(self, trainer) -> None:
        """
        Checks whether the early stopping condition is met
        and if so tells the trainer to stop the training.
@@ -170,7 +182,7 @@ def _run_early_stopping_check(self, trainer):
             trainer.fast_dev_run  # disable early_stopping with fast_dev_run
             or not self._validate_condition_metric(logs)  # short circuit if metric not present
         ):
-            return  # short circuit if metric not present
+            return
 
         current = logs.get(self.monitor)
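The new hook is only useful if the monitored value exists on the training side. A minimal sketch of the model half of the contract, assuming a module whose `training_step` returns a dict with a 'loss' key, and the illustrative metric name `train_loss`:

    import torch
    from pytorch_lightning import LightningModule

    class LitModel(LightningModule):
        # ... training_step / configure_optimizers as usual ...

        def training_epoch_end(self, outputs):
            # `outputs` collects the dicts returned by `training_step`.
            # Logging here places `train_loss` in `trainer.callback_metrics`,
            # which is where `_run_early_stopping_check` reads the monitor from.
            epoch_loss = torch.stack([o['loss'] for o in outputs]).mean()
            self.log('train_loss', epoch_loss)

Paired with `EarlyStopping(monitor='train_loss', check_on_train_epoch_end=True)`, the check runs in `on_train_epoch_end`, while `on_validation_end` returns early so the same callback never fires twice per epoch.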

tests/callbacks/test_early_stopping.py

Lines changed: 54 additions & 31 deletions
@@ -213,11 +213,13 @@ def test_early_stopping_no_val_step(tmpdir):
     assert trainer.current_epoch < trainer.max_epochs - 1
 
 
-@pytest.mark.parametrize("stopping_threshold,divergence_theshold,losses,expected_epoch", [
-    (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
-    (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
-    (None, 15.9, [9, 4, 2, 16, 32, 64], 3),
-])
+@pytest.mark.parametrize(
+    "stopping_threshold,divergence_theshold,losses,expected_epoch", [
+        (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
+        (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
+        (None, 15.9, [9, 4, 2, 16, 32, 64], 3),
+    ]
+)
 def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_theshold, losses, expected_epoch):
 
     class CurrentModel(BoringModel):
@@ -338,7 +340,7 @@ def validation_epoch_end(self, outputs):
         limit_train_batches=limit_train_batches,
         limit_val_batches=2,
         min_steps=min_steps,
-        min_epochs=min_epochs
+        min_epochs=min_epochs,
     )
     trainer.fit(model)
 
@@ -359,8 +361,13 @@ def validation_epoch_end(self, outputs):
     by_min_epochs = min_epochs * limit_train_batches
 
     # Make sure the trainer stops for the max of all minimum requirements
-    assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), \
-        (trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs)
+    assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), (
+        trainer.global_step,
+        max(min_steps, by_early_stopping, by_min_epochs),
+        step_freeze,
+        min_steps,
+        min_epochs,
+    )
 
     _logger.disabled = False
 
@@ -372,53 +379,69 @@ def test_early_stopping_mode_options():
 
 class EarlyStoppingModel(BoringModel):
 
-    def __init__(self, expected_end_epoch):
+    def __init__(self, expected_end_epoch: int, early_stop_on_train: bool):
         super().__init__()
         self.expected_end_epoch = expected_end_epoch
+        self.early_stop_on_train = early_stop_on_train
 
-    def validation_epoch_end(self, outputs):
+    def _epoch_end(self) -> None:
         losses = [8, 4, 2, 3, 4, 5, 8, 10]
-        val_loss = losses[self.current_epoch]
-        self.log('abc', torch.tensor(val_loss))
+        loss = losses[self.current_epoch]
+        self.log('abc', torch.tensor(loss))
         self.log('cba', torch.tensor(0))
 
+    def training_epoch_end(self, outputs):
+        if not self.early_stop_on_train:
+            return
+        self._epoch_end()
+
+    def validation_epoch_end(self, outputs):
+        if self.early_stop_on_train:
+            return
+        self._epoch_end()
+
     def on_train_end(self) -> None:
         assert self.trainer.current_epoch == self.expected_end_epoch, 'Early Stopping Failed'
 
 
+_ES_CHECK = dict(check_on_train_epoch_end=True)
+_ES_CHECK_P3 = dict(patience=3, check_on_train_epoch_end=True)
+_NO_WIN = dict(marks=RunIf(skip_windows=True))
+
+
 @pytest.mark.parametrize(
-    "callbacks, expected_stop_epoch, accelerator, num_processes",
+    "callbacks, expected_stop_epoch, check_on_train_epoch_end, accelerator, num_processes",
     [
-        ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, None, 1),
-        ([EarlyStopping(monitor='cba', patience=3),
-          EarlyStopping(monitor='abc')], 3, None, 1),
-        pytest.param([EarlyStopping(monitor='abc'),
-                      EarlyStopping(monitor='cba', patience=3)],
-                     3,
-                     'ddp_cpu',
-                     2,
-                     marks=RunIf(skip_windows=True)),
-        pytest.param([EarlyStopping(monitor='cba', patience=3),
-                      EarlyStopping(monitor='abc')],
-                     3,
-                     'ddp_cpu',
-                     2,
-                     marks=RunIf(skip_windows=True)),
+        ([EarlyStopping('abc'), EarlyStopping('cba', patience=3)], 3, False, None, 1),
+        ([EarlyStopping('cba', patience=3), EarlyStopping('abc')], 3, False, None, 1),
+        pytest.param([EarlyStopping('abc'), EarlyStopping('cba', patience=3)], 3, False, 'ddp_cpu', 2, **_NO_WIN),
+        pytest.param([EarlyStopping('cba', patience=3), EarlyStopping('abc')], 3, False, 'ddp_cpu', 2, **_NO_WIN),
+        ([EarlyStopping('abc', **_ES_CHECK), EarlyStopping('cba', **_ES_CHECK_P3)], 3, True, None, 1),
+        ([EarlyStopping('cba', **_ES_CHECK_P3), EarlyStopping('abc', **_ES_CHECK)], 3, True, None, 1),
+        pytest.param([EarlyStopping('abc', **_ES_CHECK),
+                      EarlyStopping('cba', **_ES_CHECK_P3)], 3, True, 'ddp_cpu', 2, **_NO_WIN),
+        pytest.param([EarlyStopping('cba', **_ES_CHECK_P3),
+                      EarlyStopping('abc', **_ES_CHECK)], 3, True, 'ddp_cpu', 2, **_NO_WIN),
     ],
 )
 def test_multiple_early_stopping_callbacks(
-    tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, accelerator: Optional[str], num_processes: int
+    tmpdir,
+    callbacks: List[EarlyStopping],
+    expected_stop_epoch: int,
+    check_on_train_epoch_end: bool,
+    accelerator: Optional[str],
+    num_processes: int,
 ):
     """Ensure when using multiple early stopping callbacks we stop if any signals we should stop."""
 
-    model = EarlyStoppingModel(expected_stop_epoch)
+    model = EarlyStoppingModel(expected_stop_epoch, check_on_train_epoch_end)
 
     trainer = Trainer(
         default_root_dir=tmpdir,
         callbacks=callbacks,
         overfit_batches=0.20,
         max_epochs=20,
         accelerator=accelerator,
-        num_processes=num_processes
+        num_processes=num_processes,
     )
     trainer.fit(model)
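For orientation, a hypothetical standalone run of one of the new train-side cases, reusing the helpers defined in the diff above (single process, no ddp):

    callbacks = [EarlyStopping('abc', **_ES_CHECK), EarlyStopping('cba', **_ES_CHECK_P3)]
    model = EarlyStoppingModel(expected_end_epoch=3, early_stop_on_train=True)
    trainer = Trainer(default_root_dir='.', callbacks=callbacks, overfit_batches=0.20, max_epochs=20)
    trainer.fit(model)
    # 'cba' is logged as a constant 0 and never improves, so its patience=3
    # callback requests the stop at epoch 3, which `on_train_end` asserts.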

tests/tuner/test_lr_finder.py

Lines changed: 4 additions & 0 deletions
@@ -300,7 +300,9 @@ def __init__(self, learning_rate=0.1, batch_size=2):
 
 def test_lr_candidates_between_min_and_max(tmpdir):
     """Test that learning rate candidates are between min_lr and max_lr."""
+
     class TestModel(BoringModel):
+
         def __init__(self, learning_rate=0.1):
             super().__init__()
             self.save_hyperparameters()
@@ -322,7 +324,9 @@ def __init__(self, learning_rate=0.1):
 
 def test_lr_finder_ends_before_num_training(tmpdir):
     """Tests learning rate finder ends before `num_training` steps."""
+
     class TestModel(BoringModel):
+
         def __init__(self, learning_rate=0.1):
             super().__init__()
             self.save_hyperparameters()
