diff --git a/CHANGELOG.md b/CHANGELOG.md
index 8247d1eb549e3..c944ce3e64c32 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+
+- Added support for the `EarlyStopping` callback to run at the end of the training epoch ([#6944](https://github.com/PyTorchLightning/pytorch-lightning/pull/6944/))
+
+
 - Added synchronization points before and after `setup` hooks are run ([#7202](https://github.com/PyTorchLightning/pytorch-lightning/pull/7202))
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 9af576aafd596..680bc55ed2426 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -19,7 +19,7 @@
 """
 import logging
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Callable, Dict, Optional, Tuple
 
 import numpy as np
 import torch
@@ -39,8 +39,8 @@ class EarlyStopping(Callback):
         monitor: quantity to be monitored.
         min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute
             change of less than `min_delta`, will count as no improvement.
-        patience: number of validation checks with no improvement
-            after which training will be stopped. Under the default configuration, one validation check happens after
+        patience: number of checks with no improvement
+            after which training will be stopped. Under the default configuration, one check happens after
             every training epoch. However, the frequency of validation can be modified by setting various parameters on
             the ``Trainer``, for example ``check_val_every_n_epoch`` and ``val_check_interval``.
 
@@ -59,6 +59,8 @@ class EarlyStopping(Callback):
         check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.
         stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold.
         divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.
+        check_on_train_epoch_end: whether to run early stopping at the end of the training epoch.
+            If this is ``False``, then the check runs at the end of the validation epoch.
 
     Raises:
         MisconfigurationException:
@@ -94,6 +96,7 @@ def __init__(
         check_finite: bool = True,
         stopping_threshold: Optional[float] = None,
         divergence_threshold: Optional[float] = None,
+        check_on_train_epoch_end: bool = False,
     ):
         super().__init__()
         self.monitor = monitor
@@ -107,6 +110,7 @@ def __init__(
         self.divergence_threshold = divergence_threshold
         self.wait_count = 0
         self.stopped_epoch = 0
+        self._check_on_train_epoch_end = check_on_train_epoch_end
 
         if self.mode not in self.mode_dict:
             raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
@@ -135,7 +139,7 @@ def _validate_condition_metric(self, logs):
         return True
 
     @property
-    def monitor_op(self):
+    def monitor_op(self) -> Callable:
         return self.mode_dict[self.mode]
 
     def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
@@ -146,20 +150,28 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) ->
             'patience': self.patience
         }
 
-    def on_load_checkpoint(self, callback_state: Dict[str, Any]):
+    def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
        self.wait_count = callback_state['wait_count']
        self.stopped_epoch = callback_state['stopped_epoch']
        self.best_score = callback_state['best_score']
        self.patience = callback_state['patience']
 
-    def on_validation_end(self, trainer, pl_module):
+    def _should_skip_check(self, trainer) -> bool:
         from pytorch_lightning.trainer.states import TrainerState
-        if trainer.state != TrainerState.FITTING or trainer.sanity_checking:
+        return trainer.state != TrainerState.FITTING or trainer.sanity_checking
+
+    def on_train_epoch_end(self, trainer, pl_module, outputs) -> None:
+        if not self._check_on_train_epoch_end or self._should_skip_check(trainer):
+            return
+        self._run_early_stopping_check(trainer)
+
+    def on_validation_end(self, trainer, pl_module) -> None:
+        if self._check_on_train_epoch_end or self._should_skip_check(trainer):
             return
         self._run_early_stopping_check(trainer)
 
-    def _run_early_stopping_check(self, trainer):
+    def _run_early_stopping_check(self, trainer) -> None:
         """
         Checks whether the early stopping condition is met
         and if so tells the trainer to stop the training.
@@ -170,7 +182,7 @@ def _run_early_stopping_check(self, trainer):
             trainer.fast_dev_run  # disable early_stopping with fast_dev_run
             or not self._validate_condition_metric(logs)  # short circuit if metric not present
         ):
-            return  # short circuit if metric not present
+            return
 
         current = logs.get(self.monitor)
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 3844d16edb517..d330955580874 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -213,11 +213,13 @@ def test_early_stopping_no_val_step(tmpdir):
     assert trainer.current_epoch < trainer.max_epochs - 1
 
 
-@pytest.mark.parametrize("stopping_threshold,divergence_theshold,losses,expected_epoch", [
-    (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
-    (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
-    (None, 15.9, [9, 4, 2, 16, 32, 64], 3),
-])
+@pytest.mark.parametrize(
+    "stopping_threshold,divergence_theshold,losses,expected_epoch", [
+        (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
+        (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
+        (None, 15.9, [9, 4, 2, 16, 32, 64], 3),
+    ]
+)
 def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_theshold, losses, expected_epoch):
 
     class CurrentModel(BoringModel):
@@ -338,7 +340,7 @@ def validation_epoch_end(self, outputs):
         limit_train_batches=limit_train_batches,
         limit_val_batches=2,
         min_steps=min_steps,
-        min_epochs=min_epochs
+        min_epochs=min_epochs,
     )
     trainer.fit(model)
@@ -359,8 +361,13 @@ def validation_epoch_end(self, outputs):
     by_min_epochs = min_epochs * limit_train_batches
 
     # Make sure the trainer stops for the max of all minimum requirements
-    assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), \
-        (trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs)
+    assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), (
+        trainer.global_step,
+        max(min_steps, by_early_stopping, by_min_epochs),
+        step_freeze,
+        min_steps,
+        min_epochs,
+    )
 
     _logger.disabled = False
 
@@ -372,46 +379,62 @@ def test_early_stopping_mode_options():
 
 class EarlyStoppingModel(BoringModel):
 
-    def __init__(self, expected_end_epoch):
+    def __init__(self, expected_end_epoch: int, early_stop_on_train: bool):
         super().__init__()
         self.expected_end_epoch = expected_end_epoch
+        self.early_stop_on_train = early_stop_on_train
 
-    def validation_epoch_end(self, outputs):
+    def _epoch_end(self) -> None:
         losses = [8, 4, 2, 3, 4, 5, 8, 10]
-        val_loss = losses[self.current_epoch]
-        self.log('abc', torch.tensor(val_loss))
+        loss = losses[self.current_epoch]
+        self.log('abc', torch.tensor(loss))
         self.log('cba', torch.tensor(0))
 
+    def training_epoch_end(self, outputs):
+        if not self.early_stop_on_train:
+            return
+        self._epoch_end()
+
+    def validation_epoch_end(self, outputs):
+        if self.early_stop_on_train:
+            return
+        self._epoch_end()
+
     def on_train_end(self) -> None:
         assert self.trainer.current_epoch == self.expected_end_epoch, 'Early Stopping Failed'
 
 
+_ES_CHECK = dict(check_on_train_epoch_end=True)
+_ES_CHECK_P3 = dict(patience=3, check_on_train_epoch_end=True)
+_NO_WIN = dict(marks=RunIf(skip_windows=True))
+
+
 @pytest.mark.parametrize(
-    "callbacks, expected_stop_epoch, accelerator, num_processes",
+    "callbacks, expected_stop_epoch, check_on_train_epoch_end, accelerator, num_processes",
     [
-        ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, None, 1),
-        ([EarlyStopping(monitor='cba', patience=3),
-          EarlyStopping(monitor='abc')], 3, None, 1),
-        pytest.param([EarlyStopping(monitor='abc'),
-                      EarlyStopping(monitor='cba', patience=3)],
-                     3,
-                     'ddp_cpu',
-                     2,
-                     marks=RunIf(skip_windows=True)),
-        pytest.param([EarlyStopping(monitor='cba', patience=3),
-                      EarlyStopping(monitor='abc')],
-                     3,
-                     'ddp_cpu',
-                     2,
-                     marks=RunIf(skip_windows=True)),
+        ([EarlyStopping('abc'), EarlyStopping('cba', patience=3)], 3, False, None, 1),
+        ([EarlyStopping('cba', patience=3), EarlyStopping('abc')], 3, False, None, 1),
+        pytest.param([EarlyStopping('abc'), EarlyStopping('cba', patience=3)], 3, False, 'ddp_cpu', 2, **_NO_WIN),
+        pytest.param([EarlyStopping('cba', patience=3), EarlyStopping('abc')], 3, False, 'ddp_cpu', 2, **_NO_WIN),
+        ([EarlyStopping('abc', **_ES_CHECK), EarlyStopping('cba', **_ES_CHECK_P3)], 3, True, None, 1),
+        ([EarlyStopping('cba', **_ES_CHECK_P3), EarlyStopping('abc', **_ES_CHECK)], 3, True, None, 1),
+        pytest.param([EarlyStopping('abc', **_ES_CHECK),
+                      EarlyStopping('cba', **_ES_CHECK_P3)], 3, True, 'ddp_cpu', 2, **_NO_WIN),
+        pytest.param([EarlyStopping('cba', **_ES_CHECK_P3),
+                      EarlyStopping('abc', **_ES_CHECK)], 3, True, 'ddp_cpu', 2, **_NO_WIN),
     ],
 )
 def test_multiple_early_stopping_callbacks(
-    tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, accelerator: Optional[str], num_processes: int
+    tmpdir,
+    callbacks: List[EarlyStopping],
+    expected_stop_epoch: int,
+    check_on_train_epoch_end: bool,
+    accelerator: Optional[str],
+    num_processes: int,
 ):
     """Ensure when using multiple early stopping callbacks we stop if any signals we should stop."""
 
-    model = EarlyStoppingModel(expected_stop_epoch)
+    model = EarlyStoppingModel(expected_stop_epoch, check_on_train_epoch_end)
 
     trainer = Trainer(
         default_root_dir=tmpdir,
@@ -419,6 +442,6 @@ def test_multiple_early_stopping_callbacks(
         overfit_batches=0.20,
         max_epochs=20,
         accelerator=accelerator,
-        num_processes=num_processes
+        num_processes=num_processes,
     )
     trainer.fit(model)
diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py
index 0a9fc7f1be03f..9834c1c8ad09b 100644
--- a/tests/tuner/test_lr_finder.py
+++ b/tests/tuner/test_lr_finder.py
@@ -300,7 +300,9 @@ def __init__(self, learning_rate=0.1, batch_size=2):
 
 def test_lr_candidates_between_min_and_max(tmpdir):
     """Test that learning rate candidates are between min_lr and max_lr."""
+
     class TestModel(BoringModel):
+
         def __init__(self, learning_rate=0.1):
             super().__init__()
             self.save_hyperparameters()
@@ -322,7 +324,9 @@ def __init__(self, learning_rate=0.1):
 
 def test_lr_finder_ends_before_num_training(tmpdir):
     """Tests learning rate finder ends before `num_training` steps."""
+
     class TestModel(BoringModel):
+
         def __init__(self, learning_rate=0.1):
             super().__init__()
             self.save_hyperparameters()
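
Not part of the patch: a minimal usage sketch of the new `check_on_train_epoch_end` flag introduced above. The `LitModel` module, the `train_loss` metric name, and the hyperparameter values are illustrative assumptions; only the `EarlyStopping(..., check_on_train_epoch_end=True)` argument comes from this diff. With the flag enabled, the stopping check runs in `on_train_epoch_end` instead of `on_validation_end`, so a metric logged during training is enough and no validation loop is required.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping


class LitModel(pl.LightningModule):
    """Illustrative model that logs an epoch-level training metric for EarlyStopping to monitor."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        # Log on_epoch so the aggregated "train_loss" is available when the
        # callback runs its check at the end of the training epoch.
        self.log("train_loss", loss, on_step=False, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

    def train_dataloader(self):
        dataset = TensorDataset(torch.randn(64, 32), torch.randn(64, 1))
        return DataLoader(dataset, batch_size=8)


# check_on_train_epoch_end=True moves the stopping check from on_validation_end
# to on_train_epoch_end, so this setup works without any validation loop.
early_stopping = EarlyStopping(monitor="train_loss", patience=3, check_on_train_epoch_end=True)
trainer = pl.Trainer(max_epochs=20, callbacks=[early_stopping])
trainer.fit(LitModel())

Because the flag is stored per instance (`self._check_on_train_epoch_end`), each `EarlyStopping` callback can independently choose which hook it checks on; the new parametrized cases in `test_multiple_early_stopping_callbacks` exercise both variants.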