From 2a8f348a2a0f08f34a04be7496096044474f8369 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 9 Apr 2021 22:00:18 -0700 Subject: [PATCH 01/18] Add support for early stopping during training epoch end --- pytorch_lightning/callbacks/early_stopping.py | 17 +++++-- tests/callbacks/test_early_stopping.py | 47 +++++++++++++++---- 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 9af576aafd596..9913d438cb4b8 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -146,15 +146,24 @@ def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> 'patience': self.patience } - def on_load_checkpoint(self, callback_state: Dict[str, Any]): + def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None: self.wait_count = callback_state['wait_count'] self.stopped_epoch = callback_state['stopped_epoch'] self.best_score = callback_state['best_score'] self.patience = callback_state['patience'] - def on_validation_end(self, trainer, pl_module): + def _should_skip_check(self, trainer) -> bool: from pytorch_lightning.trainer.states import TrainerState - if trainer.state != TrainerState.FITTING or trainer.sanity_checking: + return trainer.state != TrainerState.FITTING or trainer.sanity_checking + + def on_train_epoch_end(self, trainer, pl_module, outputs) -> None: + if not self.during_training or self._should_skip_check(trainer): + return + self._run_early_stopping_check(trainer) + + + def on_validation_end(self, trainer, pl_module): + if self.during_training or self._should_skip_check(trainer): return self._run_early_stopping_check(trainer) @@ -170,7 +179,7 @@ def _run_early_stopping_check(self, trainer): trainer.fast_dev_run # disable early_stopping with fast_dev_run or not self._validate_condition_metric(logs) # short circuit if metric not present ): - return # short circuit if metric not present + return current = logs.get(self.monitor) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 3844d16edb517..c685ed5a6356b 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -372,14 +372,25 @@ def test_early_stopping_mode_options(): class EarlyStoppingModel(BoringModel): - def __init__(self, expected_end_epoch): + def __init__(self, expected_end_epoch: int, during_training: bool): super().__init__() self.expected_end_epoch = expected_end_epoch + self.during_training = during_training + + def training_epoch_end(self, outputs): + if not self.during_training: + return + losses = [8, 4, 2, 3, 4, 5, 8, 10] + loss = losses[self.current_epoch] + self.log('abc', torch.tensor(loss)) + self.log('cba', torch.tensor(0)) def validation_epoch_end(self, outputs): + if self.during_training: + return losses = [8, 4, 2, 3, 4, 5, 8, 10] - val_loss = losses[self.current_epoch] - self.log('abc', torch.tensor(val_loss)) + loss = losses[self.current_epoch] + self.log('abc', torch.tensor(loss)) self.log('cba', torch.tensor(0)) def on_train_end(self) -> None: @@ -387,31 +398,51 @@ def on_train_end(self) -> None: @pytest.mark.parametrize( - "callbacks, expected_stop_epoch, accelerator, num_processes", + "callbacks, expected_stop_epoch, during_training, accelerator, num_processes", [ - ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, None, 1), + ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, False, None, 1), ([EarlyStopping(monitor='cba', patience=3), - EarlyStopping(monitor='abc')], 3, None, 1), + EarlyStopping(monitor='abc')], 3, False, None, 1), pytest.param([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, + False, 'ddp_cpu', 2, marks=RunIf(skip_windows=True)), pytest.param([EarlyStopping(monitor='cba', patience=3), EarlyStopping(monitor='abc')], 3, + False, + 'ddp_cpu', + 2, + marks=RunIf(skip_windows=True)), + ([EarlyStopping(monitor='abc', during_training=True), EarlyStopping(monitor='cba', patience=3, during_training=True)], 3, True, None, 1), + ([EarlyStopping(monitor='cba', patience=3, during_training=True), + EarlyStopping(monitor='abc', during_training=True)], 3, True, None, 1), + pytest.param([EarlyStopping(monitor='abc', during_training=True), + EarlyStopping(monitor='cba', patience=3, during_training=True)], + 3, + True, + 'ddp_cpu', + 2, + marks=RunIf(skip_windows=True)), + pytest.param([EarlyStopping(monitor='cba', patience=3, during_training=True), + EarlyStopping(monitor='abc', during_training=True)], + 3, + True, 'ddp_cpu', 2, marks=RunIf(skip_windows=True)), + ], ) def test_multiple_early_stopping_callbacks( - tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, accelerator: Optional[str], num_processes: int + tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, during_training: bool, accelerator: Optional[str], num_processes: int ): """Ensure when using multiple early stopping callbacks we stop if any signals we should stop.""" - model = EarlyStoppingModel(expected_stop_epoch) + model = EarlyStoppingModel(expected_stop_epoch, during_training) trainer = Trainer( default_root_dir=tmpdir, From 29a5d6001b2cab649c8758251f74825f4791b573 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 14 Apr 2021 22:22:03 -0700 Subject: [PATCH 02/18] Update early_stopping.py --- pytorch_lightning/callbacks/early_stopping.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 9913d438cb4b8..44103123679be 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -161,7 +161,6 @@ def on_train_epoch_end(self, trainer, pl_module, outputs) -> None: return self._run_early_stopping_check(trainer) - def on_validation_end(self, trainer, pl_module): if self.during_training or self._should_skip_check(trainer): return From 3995acd2722e1b6a3c880288683353a805e270f5 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 14 Apr 2021 22:24:53 -0700 Subject: [PATCH 03/18] Update test_early_stopping.py --- tests/callbacks/test_early_stopping.py | 27 +++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index c685ed5a6356b..251638e034a6d 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -417,28 +417,37 @@ def on_train_end(self) -> None: 'ddp_cpu', 2, marks=RunIf(skip_windows=True)), - ([EarlyStopping(monitor='abc', during_training=True), EarlyStopping(monitor='cba', patience=3, during_training=True)], 3, True, None, 1), - ([EarlyStopping(monitor='cba', patience=3, during_training=True), - EarlyStopping(monitor='abc', during_training=True)], 3, True, None, 1), - pytest.param([EarlyStopping(monitor='abc', during_training=True), - EarlyStopping(monitor='cba', patience=3, during_training=True)], + ([ + EarlyStopping(monitor='abc', during_training=True), + EarlyStopping(monitor='cba', patience=3, during_training=True) + ], 3, True, None, 1), + ([ + EarlyStopping(monitor='cba', patience=3, during_training=True), + EarlyStopping(monitor='abc', during_training=True) + ], 3, True, None, 1), + pytest.param([ + EarlyStopping(monitor='abc', during_training=True), + EarlyStopping(monitor='cba', patience=3, during_training=True) + ], 3, True, 'ddp_cpu', 2, marks=RunIf(skip_windows=True)), - pytest.param([EarlyStopping(monitor='cba', patience=3, during_training=True), - EarlyStopping(monitor='abc', during_training=True)], + pytest.param([ + EarlyStopping(monitor='cba', patience=3, during_training=True), + EarlyStopping(monitor='abc', during_training=True) + ], 3, True, 'ddp_cpu', 2, marks=RunIf(skip_windows=True)), - ], ) def test_multiple_early_stopping_callbacks( - tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, during_training: bool, accelerator: Optional[str], num_processes: int + tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, during_training: bool, accelerator: Optional[str], + num_processes: int ): """Ensure when using multiple early stopping callbacks we stop if any signals we should stop.""" From f1cbde26e070c74eb8aa62b5631340c90b1c68cb Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 14 Apr 2021 23:28:51 -0700 Subject: [PATCH 04/18] rebase --- pytorch_lightning/callbacks/early_stopping.py | 6 ++-- tests/callbacks/test_early_stopping.py | 30 +++++++++---------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 44103123679be..4f8172a77523d 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -94,6 +94,7 @@ def __init__( check_finite: bool = True, stopping_threshold: Optional[float] = None, divergence_threshold: Optional[float] = None, + check_on_train_epoch_end: bool = False, ): super().__init__() self.monitor = monitor @@ -107,6 +108,7 @@ def __init__( self.divergence_threshold = divergence_threshold self.wait_count = 0 self.stopped_epoch = 0 + self._check_on_train_epoch_end = check_on_train_epoch_end if self.mode not in self.mode_dict: raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}") @@ -157,12 +159,12 @@ def _should_skip_check(self, trainer) -> bool: return trainer.state != TrainerState.FITTING or trainer.sanity_checking def on_train_epoch_end(self, trainer, pl_module, outputs) -> None: - if not self.during_training or self._should_skip_check(trainer): + if not self._check_on_train_epoch_end or self._should_skip_check(trainer): return self._run_early_stopping_check(trainer) def on_validation_end(self, trainer, pl_module): - if self.during_training or self._should_skip_check(trainer): + if self._check_on_train_epoch_end or self._should_skip_check(trainer): return self._run_early_stopping_check(trainer) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 251638e034a6d..fffcafa7829d7 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -372,13 +372,13 @@ def test_early_stopping_mode_options(): class EarlyStoppingModel(BoringModel): - def __init__(self, expected_end_epoch: int, during_training: bool): + def __init__(self, expected_end_epoch: int, early_stop_on_train: bool): super().__init__() self.expected_end_epoch = expected_end_epoch - self.during_training = during_training + self.early_stop_on_train = early_stop_on_train def training_epoch_end(self, outputs): - if not self.during_training: + if not self.early_stop_on_train: return losses = [8, 4, 2, 3, 4, 5, 8, 10] loss = losses[self.current_epoch] @@ -386,7 +386,7 @@ def training_epoch_end(self, outputs): self.log('cba', torch.tensor(0)) def validation_epoch_end(self, outputs): - if self.during_training: + if self.early_stop_on_train: return losses = [8, 4, 2, 3, 4, 5, 8, 10] loss = losses[self.current_epoch] @@ -398,7 +398,7 @@ def on_train_end(self) -> None: @pytest.mark.parametrize( - "callbacks, expected_stop_epoch, during_training, accelerator, num_processes", + "callbacks, expected_stop_epoch, check_on_train_epoch_end, accelerator, num_processes", [ ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, False, None, 1), ([EarlyStopping(monitor='cba', patience=3), @@ -418,16 +418,16 @@ def on_train_end(self) -> None: 2, marks=RunIf(skip_windows=True)), ([ - EarlyStopping(monitor='abc', during_training=True), - EarlyStopping(monitor='cba', patience=3, during_training=True) + EarlyStopping(monitor='abc', check_on_train_epoch_end=True), + EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True) ], 3, True, None, 1), ([ - EarlyStopping(monitor='cba', patience=3, during_training=True), - EarlyStopping(monitor='abc', during_training=True) + EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), + EarlyStopping(monitor='abc', check_on_train_epoch_end=True) ], 3, True, None, 1), pytest.param([ - EarlyStopping(monitor='abc', during_training=True), - EarlyStopping(monitor='cba', patience=3, during_training=True) + EarlyStopping(monitor='abc', check_on_train_epoch_end=True), + EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True) ], 3, True, @@ -435,8 +435,8 @@ def on_train_end(self) -> None: 2, marks=RunIf(skip_windows=True)), pytest.param([ - EarlyStopping(monitor='cba', patience=3, during_training=True), - EarlyStopping(monitor='abc', during_training=True) + EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), + EarlyStopping(monitor='abc', check_on_train_epoch_end=True) ], 3, True, @@ -446,12 +446,12 @@ def on_train_end(self) -> None: ], ) def test_multiple_early_stopping_callbacks( - tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, during_training: bool, accelerator: Optional[str], + tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, check_on_train_epoch_end: bool, accelerator: Optional[str], num_processes: int ): """Ensure when using multiple early stopping callbacks we stop if any signals we should stop.""" - model = EarlyStoppingModel(expected_stop_epoch, during_training) + model = EarlyStoppingModel(expected_stop_epoch, check_on_train_epoch_end) trainer = Trainer( default_root_dir=tmpdir, From eb17b6ea8b87f5e5b3b611a9d7259972ef44c259 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 14 Apr 2021 23:29:22 -0700 Subject: [PATCH 05/18] Update test_early_stopping.py --- tests/callbacks/test_early_stopping.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index fffcafa7829d7..0ce818c4ca230 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -446,8 +446,8 @@ def on_train_end(self) -> None: ], ) def test_multiple_early_stopping_callbacks( - tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, check_on_train_epoch_end: bool, accelerator: Optional[str], - num_processes: int + tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, check_on_train_epoch_end: bool, + accelerator: Optional[str], num_processes: int ): """Ensure when using multiple early stopping callbacks we stop if any signals we should stop.""" From 99d82b32a15e69f7744d2521b5308cc0fd2efc1b Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 14 Apr 2021 23:32:33 -0700 Subject: [PATCH 06/18] Update test_early_stopping.py --- tests/callbacks/test_early_stopping.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 0ce818c4ca230..983d0f069e526 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -377,21 +377,22 @@ def __init__(self, expected_end_epoch: int, early_stop_on_train: bool): self.expected_end_epoch = expected_end_epoch self.early_stop_on_train = early_stop_on_train - def training_epoch_end(self, outputs): - if not self.early_stop_on_train: - return + def _epoch_end(self) -> None: losses = [8, 4, 2, 3, 4, 5, 8, 10] loss = losses[self.current_epoch] self.log('abc', torch.tensor(loss)) self.log('cba', torch.tensor(0)) + + def training_epoch_end(self, outputs): + if not self.early_stop_on_train: + return + self._epoch_end() + def validation_epoch_end(self, outputs): if self.early_stop_on_train: return - losses = [8, 4, 2, 3, 4, 5, 8, 10] - loss = losses[self.current_epoch] - self.log('abc', torch.tensor(loss)) - self.log('cba', torch.tensor(0)) + self._epoch_end() def on_train_end(self) -> None: assert self.trainer.current_epoch == self.expected_end_epoch, 'Early Stopping Failed' From 8351ff74f521712176454304f1d958ca89b73646 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 14 Apr 2021 23:36:07 -0700 Subject: [PATCH 07/18] doc --- CHANGELOG.md | 2 ++ pytorch_lightning/callbacks/early_stopping.py | 6 ++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8247d1eb549e3..290083a9121e8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [1.3.0] - 2021-MM-DD ### Added +- Added support for the `EarlyStopping` callback to run at the end of the training epoch ([#6944](https://github.com/PyTorchLightning/pytorch-lightning/pull/6944/)) + - Added synchronization points before and after `setup` hooks are run ([#7202](https://github.com/PyTorchLightning/pytorch-lightning/pull/7202)) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 4f8172a77523d..bbbb6523db9b2 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -39,8 +39,8 @@ class EarlyStopping(Callback): monitor: quantity to be monitored. min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than `min_delta`, will count as no improvement. - patience: number of validation checks with no improvement - after which training will be stopped. Under the default configuration, one validation check happens after + patience: number of checks with no improvement + after which training will be stopped. Under the default configuration, one check happens after every training epoch. However, the frequency of validation can be modified by setting various parameters on the ``Trainer``, for example ``check_val_every_n_epoch`` and ``val_check_interval``. @@ -59,6 +59,8 @@ class EarlyStopping(Callback): check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite. stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold. divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold. + check_on_train_epoch_end: whether to run early stopping at the end of the training epoch. + If this is ``False``, then the check runs at the end of validation. Raises: MisconfigurationException: From b6e0a1d8c352580ac347aef73291cdef53068615 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 14 Apr 2021 23:38:42 -0700 Subject: [PATCH 08/18] Update early_stopping.py --- pytorch_lightning/callbacks/early_stopping.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index bbbb6523db9b2..5de4ec64a6197 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -139,7 +139,7 @@ def _validate_condition_metric(self, logs): return True @property - def monitor_op(self): + def monitor_op(self) -> Callable: return self.mode_dict[self.mode] def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]: @@ -165,13 +165,13 @@ def on_train_epoch_end(self, trainer, pl_module, outputs) -> None: return self._run_early_stopping_check(trainer) - def on_validation_end(self, trainer, pl_module): + def on_validation_end(self, trainer, pl_module) -> None: if self._check_on_train_epoch_end or self._should_skip_check(trainer): return self._run_early_stopping_check(trainer) - def _run_early_stopping_check(self, trainer): + def _run_early_stopping_check(self, trainer) -> None: """ Checks whether the early stopping condition is met and if so tells the trainer to stop the training. From 392d7e1b366f01e5db6897f44d413a6e10c6b827 Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 14 Apr 2021 23:40:49 -0700 Subject: [PATCH 09/18] Update test_early_stopping.py --- tests/callbacks/test_early_stopping.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 983d0f069e526..abc127add5309 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -383,7 +383,6 @@ def _epoch_end(self) -> None: self.log('abc', torch.tensor(loss)) self.log('cba', torch.tensor(0)) - def training_epoch_end(self, outputs): if not self.early_stop_on_train: return From 05e5a420e6db14a186d9ec34a2088675fbe55c5d Mon Sep 17 00:00:00 2001 From: ananthsub Date: Fri, 9 Apr 2021 22:00:18 -0700 Subject: [PATCH 10/18] Add support for early stopping during training epoch end --- pytorch_lightning/callbacks/early_stopping.py | 9 +++++++++ tests/callbacks/test_early_stopping.py | 1 + 2 files changed, 10 insertions(+) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 5de4ec64a6197..628cd2f136896 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -93,10 +93,14 @@ def __init__( verbose: bool = False, mode: str = 'min', strict: bool = True, +<<<<<<< HEAD check_finite: bool = True, stopping_threshold: Optional[float] = None, divergence_threshold: Optional[float] = None, check_on_train_epoch_end: bool = False, +======= + during_training: bool = False, +>>>>>>> Add support for early stopping during training epoch end ): super().__init__() self.monitor = monitor @@ -110,7 +114,12 @@ def __init__( self.divergence_threshold = divergence_threshold self.wait_count = 0 self.stopped_epoch = 0 +<<<<<<< HEAD self._check_on_train_epoch_end = check_on_train_epoch_end +======= + self.mode = mode + self.during_training = during_training +>>>>>>> Add support for early stopping during training epoch end if self.mode not in self.mode_dict: raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}") diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index abc127add5309..a1c33bfaea83f 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -443,6 +443,7 @@ def on_train_end(self) -> None: 'ddp_cpu', 2, marks=RunIf(skip_windows=True)), + ], ) def test_multiple_early_stopping_callbacks( From 5a2d5532449814b85f87bbae77e37d6427ae07ff Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 14 Apr 2021 22:22:03 -0700 Subject: [PATCH 11/18] Update early_stopping.py --- pytorch_lightning/callbacks/early_stopping.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 628cd2f136896..5de4ec64a6197 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -93,14 +93,10 @@ def __init__( verbose: bool = False, mode: str = 'min', strict: bool = True, -<<<<<<< HEAD check_finite: bool = True, stopping_threshold: Optional[float] = None, divergence_threshold: Optional[float] = None, check_on_train_epoch_end: bool = False, -======= - during_training: bool = False, ->>>>>>> Add support for early stopping during training epoch end ): super().__init__() self.monitor = monitor @@ -114,12 +110,7 @@ def __init__( self.divergence_threshold = divergence_threshold self.wait_count = 0 self.stopped_epoch = 0 -<<<<<<< HEAD self._check_on_train_epoch_end = check_on_train_epoch_end -======= - self.mode = mode - self.during_training = during_training ->>>>>>> Add support for early stopping during training epoch end if self.mode not in self.mode_dict: raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}") From 31f63905aedd958c68eea74e6a5b62230785a2bd Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 14 Apr 2021 22:24:53 -0700 Subject: [PATCH 12/18] Update test_early_stopping.py --- tests/callbacks/test_early_stopping.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index a1c33bfaea83f..abc127add5309 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -443,7 +443,6 @@ def on_train_end(self) -> None: 'ddp_cpu', 2, marks=RunIf(skip_windows=True)), - ], ) def test_multiple_early_stopping_callbacks( From 2084d1d6a3b995d6622a2b0b4c7eb17a9005e4ac Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 26 Apr 2021 17:39:37 +0200 Subject: [PATCH 13/18] Fix yapf nonsense --- tests/callbacks/test_early_stopping.py | 130 ++++++++++++++++--------- 1 file changed, 82 insertions(+), 48 deletions(-) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index abc127add5309..685ad964f57b6 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -338,7 +338,7 @@ def validation_epoch_end(self, outputs): limit_train_batches=limit_train_batches, limit_val_batches=2, min_steps=min_steps, - min_epochs=min_epochs + min_epochs=min_epochs, ) trainer.fit(model) @@ -359,8 +359,13 @@ def validation_epoch_end(self, outputs): by_min_epochs = min_epochs * limit_train_batches # Make sure the trainer stops for the max of all minimum requirements - assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), \ - (trainer.global_step, max(min_steps, by_early_stopping, by_min_epochs), step_freeze, min_steps, min_epochs) + assert trainer.global_step == max(min_steps, by_early_stopping, by_min_epochs), ( + trainer.global_step, + max(min_steps, by_early_stopping, by_min_epochs), + step_freeze, + min_steps, + min_epochs, + ) _logger.disabled = False @@ -401,53 +406,82 @@ def on_train_end(self) -> None: "callbacks, expected_stop_epoch, check_on_train_epoch_end, accelerator, num_processes", [ ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, False, None, 1), - ([EarlyStopping(monitor='cba', patience=3), - EarlyStopping(monitor='abc')], 3, False, None, 1), - pytest.param([EarlyStopping(monitor='abc'), - EarlyStopping(monitor='cba', patience=3)], - 3, - False, - 'ddp_cpu', - 2, - marks=RunIf(skip_windows=True)), - pytest.param([EarlyStopping(monitor='cba', patience=3), - EarlyStopping(monitor='abc')], - 3, - False, - 'ddp_cpu', - 2, - marks=RunIf(skip_windows=True)), - ([ - EarlyStopping(monitor='abc', check_on_train_epoch_end=True), - EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True) - ], 3, True, None, 1), - ([ - EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), - EarlyStopping(monitor='abc', check_on_train_epoch_end=True) - ], 3, True, None, 1), - pytest.param([ - EarlyStopping(monitor='abc', check_on_train_epoch_end=True), - EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True) - ], - 3, - True, - 'ddp_cpu', - 2, - marks=RunIf(skip_windows=True)), - pytest.param([ - EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), - EarlyStopping(monitor='abc', check_on_train_epoch_end=True) - ], - 3, - True, - 'ddp_cpu', - 2, - marks=RunIf(skip_windows=True)), + ( + [EarlyStopping(monitor='cba', patience=3), + EarlyStopping(monitor='abc')], + 3, + False, + None, + 1, + ), + pytest.param( + [EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], + 3, + False, + 'ddp_cpu', + 2, + marks=RunIf(skip_windows=True), + ), + pytest.param( + [EarlyStopping(monitor='cba', patience=3), + EarlyStopping(monitor='abc')], + 3, + False, + 'ddp_cpu', + 2, + marks=RunIf(skip_windows=True), + ), + ( + [ + EarlyStopping(monitor='abc', check_on_train_epoch_end=True), + EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), + ], + 3, + True, + None, + 1, + ), + ( + [ + EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), + EarlyStopping(monitor='abc', check_on_train_epoch_end=True), + ], + 3, + True, + None, + 1, + ), + pytest.param( + [ + EarlyStopping(monitor='abc', check_on_train_epoch_end=True), + EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), + ], + 3, + True, + 'ddp_cpu', + 2, + marks=RunIf(skip_windows=True), + ), + pytest.param( + [ + EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), + EarlyStopping(monitor='abc', check_on_train_epoch_end=True), + ], + 3, + True, + 'ddp_cpu', + 2, + marks=RunIf(skip_windows=True), + ), ], ) def test_multiple_early_stopping_callbacks( - tmpdir, callbacks: List[EarlyStopping], expected_stop_epoch: int, check_on_train_epoch_end: bool, - accelerator: Optional[str], num_processes: int + tmpdir, + callbacks: List[EarlyStopping], + expected_stop_epoch: int, + check_on_train_epoch_end: bool, + accelerator: Optional[str], + num_processes: int, ): """Ensure when using multiple early stopping callbacks we stop if any signals we should stop.""" @@ -459,6 +493,6 @@ def test_multiple_early_stopping_callbacks( overfit_batches=0.20, max_epochs=20, accelerator=accelerator, - num_processes=num_processes + num_processes=num_processes, ) trainer.fit(model) From 1691f45419d3ceaac6969fb66796862803c73baa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 26 Apr 2021 17:40:44 +0200 Subject: [PATCH 14/18] Apply suggestions from code review --- CHANGELOG.md | 2 ++ pytorch_lightning/callbacks/early_stopping.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 290083a9121e8..c944ce3e64c32 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [1.3.0] - 2021-MM-DD ### Added + + - Added support for the `EarlyStopping` callback to run at the end of the training epoch ([#6944](https://github.com/PyTorchLightning/pytorch-lightning/pull/6944/)) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 5de4ec64a6197..2bbcf9a07bc85 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -60,7 +60,7 @@ class EarlyStopping(Callback): stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold. divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold. check_on_train_epoch_end: whether to run early stopping at the end of the training epoch. - If this is ``False``, then the check runs at the end of validation. + If this is ``False``, then the check runs at the end of the validation epoch. Raises: MisconfigurationException: From 77706c78dd2d8d0eed5085d6676dfd967e8467fa Mon Sep 17 00:00:00 2001 From: jirka Date: Mon, 26 Apr 2021 23:57:50 +0200 Subject: [PATCH 15/18] format --- tests/callbacks/test_early_stopping.py | 121 ++++++++++--------------- tests/tuner/test_lr_finder.py | 4 + 2 files changed, 53 insertions(+), 72 deletions(-) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 685ad964f57b6..f64bd884571e4 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -213,11 +213,13 @@ def test_early_stopping_no_val_step(tmpdir): assert trainer.current_epoch < trainer.max_epochs - 1 -@pytest.mark.parametrize("stopping_threshold,divergence_theshold,losses,expected_epoch", [ - (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5), - (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8), - (None, 15.9, [9, 4, 2, 16, 32, 64], 3), -]) +@pytest.mark.parametrize( + "stopping_threshold,divergence_theshold,losses,expected_epoch", [ + (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5), + (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8), + (None, 15.9, [9, 4, 2, 16, 32, 64], 3), + ] +) def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_theshold, losses, expected_epoch): class CurrentModel(BoringModel): @@ -406,73 +408,48 @@ def on_train_end(self) -> None: "callbacks, expected_stop_epoch, check_on_train_epoch_end, accelerator, num_processes", [ ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, False, None, 1), - ( - [EarlyStopping(monitor='cba', patience=3), - EarlyStopping(monitor='abc')], - 3, - False, - None, - 1, - ), - pytest.param( - [EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], - 3, - False, - 'ddp_cpu', - 2, - marks=RunIf(skip_windows=True), - ), - pytest.param( - [EarlyStopping(monitor='cba', patience=3), - EarlyStopping(monitor='abc')], - 3, - False, - 'ddp_cpu', - 2, - marks=RunIf(skip_windows=True), - ), - ( - [ - EarlyStopping(monitor='abc', check_on_train_epoch_end=True), - EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), - ], - 3, - True, - None, - 1, - ), - ( - [ - EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), - EarlyStopping(monitor='abc', check_on_train_epoch_end=True), - ], - 3, - True, - None, - 1, - ), - pytest.param( - [ - EarlyStopping(monitor='abc', check_on_train_epoch_end=True), - EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), - ], - 3, - True, - 'ddp_cpu', - 2, - marks=RunIf(skip_windows=True), - ), - pytest.param( - [ - EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), - EarlyStopping(monitor='abc', check_on_train_epoch_end=True), - ], - 3, - True, - 'ddp_cpu', - 2, - marks=RunIf(skip_windows=True), - ), + ([EarlyStopping(monitor='cba', patience=3), + EarlyStopping(monitor='abc')], 3, False, None, 1), + pytest.param([EarlyStopping(monitor='abc'), + EarlyStopping(monitor='cba', patience=3)], + 3, + False, + 'ddp_cpu', + 2, + marks=RunIf(skip_windows=True)), + pytest.param([EarlyStopping(monitor='cba', patience=3), + EarlyStopping(monitor='abc')], + 3, + False, + 'ddp_cpu', + 2, + marks=RunIf(skip_windows=True)), + ([ + EarlyStopping(monitor='abc', check_on_train_epoch_end=True), + EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), + ], 3, True, None, 1), + ([ + EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), + EarlyStopping(monitor='abc', check_on_train_epoch_end=True), + ], 3, True, None, 1), + pytest.param([ + EarlyStopping(monitor='abc', check_on_train_epoch_end=True), + EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), + ], + 3, + True, + 'ddp_cpu', + 2, + marks=RunIf(skip_windows=True)), + pytest.param([ + EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), + EarlyStopping(monitor='abc', check_on_train_epoch_end=True), + ], + 3, + True, + 'ddp_cpu', + 2, + marks=RunIf(skip_windows=True)), ], ) def test_multiple_early_stopping_callbacks( diff --git a/tests/tuner/test_lr_finder.py b/tests/tuner/test_lr_finder.py index 0a9fc7f1be03f..9834c1c8ad09b 100644 --- a/tests/tuner/test_lr_finder.py +++ b/tests/tuner/test_lr_finder.py @@ -300,7 +300,9 @@ def __init__(self, learning_rate=0.1, batch_size=2): def test_lr_candidates_between_min_and_max(tmpdir): """Test that learning rate candidates are between min_lr and max_lr.""" + class TestModel(BoringModel): + def __init__(self, learning_rate=0.1): super().__init__() self.save_hyperparameters() @@ -322,7 +324,9 @@ def __init__(self, learning_rate=0.1): def test_lr_finder_ends_before_num_training(tmpdir): """Tests learning rate finder ends before `num_training` steps.""" + class TestModel(BoringModel): + def __init__(self, learning_rate=0.1): super().__init__() self.save_hyperparameters() From 80f5bab477a06813273b1493f98d57d3402f992d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 27 Apr 2021 17:13:48 +0200 Subject: [PATCH 16/18] Formatting --- tests/callbacks/test_early_stopping.py | 99 +++++++++++++++----------- 1 file changed, 59 insertions(+), 40 deletions(-) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index f64bd884571e4..6caaf377a6ef5 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -410,46 +410,65 @@ def on_train_end(self) -> None: ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, False, None, 1), ([EarlyStopping(monitor='cba', patience=3), EarlyStopping(monitor='abc')], 3, False, None, 1), - pytest.param([EarlyStopping(monitor='abc'), - EarlyStopping(monitor='cba', patience=3)], - 3, - False, - 'ddp_cpu', - 2, - marks=RunIf(skip_windows=True)), - pytest.param([EarlyStopping(monitor='cba', patience=3), - EarlyStopping(monitor='abc')], - 3, - False, - 'ddp_cpu', - 2, - marks=RunIf(skip_windows=True)), - ([ - EarlyStopping(monitor='abc', check_on_train_epoch_end=True), - EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), - ], 3, True, None, 1), - ([ - EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), - EarlyStopping(monitor='abc', check_on_train_epoch_end=True), - ], 3, True, None, 1), - pytest.param([ - EarlyStopping(monitor='abc', check_on_train_epoch_end=True), - EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), - ], - 3, - True, - 'ddp_cpu', - 2, - marks=RunIf(skip_windows=True)), - pytest.param([ - EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), - EarlyStopping(monitor='abc', check_on_train_epoch_end=True), - ], - 3, - True, - 'ddp_cpu', - 2, - marks=RunIf(skip_windows=True)), + pytest.param( + [EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], + 3, + False, + 'ddp_cpu', + 2, + marks=RunIf(skip_windows=True), + ), + pytest.param( + [EarlyStopping(monitor='cba', patience=3), + EarlyStopping(monitor='abc')], + 3, + False, + 'ddp_cpu', + 2, + marks=RunIf(skip_windows=True), + ), + ( + [ + EarlyStopping(monitor='abc', check_on_train_epoch_end=True), + EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), + ], + 3, + True, + None, + 1, + ), + ( + [ + EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), + EarlyStopping(monitor='abc', check_on_train_epoch_end=True), + ], + 3, + True, + None, + 1, + ), + pytest.param( + [ + EarlyStopping(monitor='abc', check_on_train_epoch_end=True), + EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), + ], + 3, + True, + 'ddp_cpu', + 2, + marks=RunIf(skip_windows=True), + ), + pytest.param( + [ + EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), + EarlyStopping(monitor='abc', check_on_train_epoch_end=True), + ], + 3, + True, + 'ddp_cpu', + 2, + marks=RunIf(skip_windows=True), + ), ], ) def test_multiple_early_stopping_callbacks( From a59c78ab485b1c619e2bfa84b8c2b940c0262f6a Mon Sep 17 00:00:00 2001 From: jirka Date: Tue, 27 Apr 2021 17:42:41 +0200 Subject: [PATCH 17/18] format? --- tests/callbacks/test_early_stopping.py | 77 +++++--------------------- 1 file changed, 15 insertions(+), 62 deletions(-) diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py index 6caaf377a6ef5..d330955580874 100644 --- a/tests/callbacks/test_early_stopping.py +++ b/tests/callbacks/test_early_stopping.py @@ -404,71 +404,24 @@ def on_train_end(self) -> None: assert self.trainer.current_epoch == self.expected_end_epoch, 'Early Stopping Failed' +_ES_CHECK = dict(check_on_train_epoch_end=True) +_ES_CHECK_P3 = dict(patience=3, check_on_train_epoch_end=True) +_NO_WIN = dict(marks=RunIf(skip_windows=True)) + + @pytest.mark.parametrize( "callbacks, expected_stop_epoch, check_on_train_epoch_end, accelerator, num_processes", [ - ([EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], 3, False, None, 1), - ([EarlyStopping(monitor='cba', patience=3), - EarlyStopping(monitor='abc')], 3, False, None, 1), - pytest.param( - [EarlyStopping(monitor='abc'), EarlyStopping(monitor='cba', patience=3)], - 3, - False, - 'ddp_cpu', - 2, - marks=RunIf(skip_windows=True), - ), - pytest.param( - [EarlyStopping(monitor='cba', patience=3), - EarlyStopping(monitor='abc')], - 3, - False, - 'ddp_cpu', - 2, - marks=RunIf(skip_windows=True), - ), - ( - [ - EarlyStopping(monitor='abc', check_on_train_epoch_end=True), - EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), - ], - 3, - True, - None, - 1, - ), - ( - [ - EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), - EarlyStopping(monitor='abc', check_on_train_epoch_end=True), - ], - 3, - True, - None, - 1, - ), - pytest.param( - [ - EarlyStopping(monitor='abc', check_on_train_epoch_end=True), - EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), - ], - 3, - True, - 'ddp_cpu', - 2, - marks=RunIf(skip_windows=True), - ), - pytest.param( - [ - EarlyStopping(monitor='cba', patience=3, check_on_train_epoch_end=True), - EarlyStopping(monitor='abc', check_on_train_epoch_end=True), - ], - 3, - True, - 'ddp_cpu', - 2, - marks=RunIf(skip_windows=True), - ), + ([EarlyStopping('abc'), EarlyStopping('cba', patience=3)], 3, False, None, 1), + ([EarlyStopping('cba', patience=3), EarlyStopping('abc')], 3, False, None, 1), + pytest.param([EarlyStopping('abc'), EarlyStopping('cba', patience=3)], 3, False, 'ddp_cpu', 2, **_NO_WIN), + pytest.param([EarlyStopping('cba', patience=3), EarlyStopping('abc')], 3, False, 'ddp_cpu', 2, **_NO_WIN), + ([EarlyStopping('abc', **_ES_CHECK), EarlyStopping('cba', **_ES_CHECK_P3)], 3, True, None, 1), + ([EarlyStopping('cba', **_ES_CHECK_P3), EarlyStopping('abc', **_ES_CHECK)], 3, True, None, 1), + pytest.param([EarlyStopping('abc', **_ES_CHECK), + EarlyStopping('cba', **_ES_CHECK_P3)], 3, True, 'ddp_cpu', 2, **_NO_WIN), + pytest.param([EarlyStopping('cba', **_ES_CHECK_P3), + EarlyStopping('abc', **_ES_CHECK)], 3, True, 'ddp_cpu', 2, **_NO_WIN), ], ) def test_multiple_early_stopping_callbacks( From 3b6b60626ad38b90fce783c509be267ac1a235cc Mon Sep 17 00:00:00 2001 From: ananthsub Date: Wed, 28 Apr 2021 01:30:00 -0700 Subject: [PATCH 18/18] Update early_stopping.py --- pytorch_lightning/callbacks/early_stopping.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 2bbcf9a07bc85..680bc55ed2426 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -19,7 +19,7 @@ """ import logging -from typing import Any, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple import numpy as np import torch