From 7ff7d9603d9de78f0ce7596cebbbe965ead9d624 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 6 Apr 2021 11:29:59 +0200
Subject: [PATCH 01/14] stopping with NaN

---
 pytorch_lightning/callbacks/early_stopping.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 24ebcdf807357..57d951615b591 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -18,6 +18,7 @@
 Monitor a metric and stop training when it stops improving.

 """
+import logging
 from typing import Any, Dict

 import numpy as np
@@ -27,6 +28,8 @@
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

+log = logging.getLogger(__name__)
+

 class EarlyStopping(Callback):
     r"""
@@ -53,6 +56,8 @@ class EarlyStopping(Callback):
             monitored has stopped decreasing and in ``'max'`` mode it will stop when the quantity
             monitored has stopped increasing.
         strict: whether to crash the training if `monitor` is not found in the validation metrics.
+        check_finite: Stops training when the monitor becomes NaN or infinite. Set this argument to ``False``
+            if this behavior is undesired.

     Raises:
         MisconfigurationException:
@@ -80,16 +85,18 @@ def __init__(
         self,
         verbose: bool = False,
         mode: str = 'min',
         strict: bool = True,
+        check_finite: bool = True,
     ):
         super().__init__()
         self.monitor = monitor
+        self.min_delta = min_delta
         self.patience = patience
         self.verbose = verbose
+        self.mode = mode
         self.strict = strict
-        self.min_delta = min_delta
+        self.check_finite = check_finite
         self.wait_count = 0
         self.stopped_epoch = 0
-        self.mode = mode

         if self.mode not in self.mode_dict:
             raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
@@ -160,6 +167,13 @@ def _run_early_stopping_check(self, trainer):
         # when in dev debugging
         trainer.dev_debugger.track_early_stopping_history(self, current)

+        if self.check_finite and not torch.isfinite(current):
+            trainer.should_stop = True
+            log.info(
+                f"[{trainer.global_rank}] Monitored metric {self.monitor} is not finite."
+                f" Current value is {current:.3f}, best value was {self.best_score:.3f}."
+            )
+
         if self.monitor_op(current - self.min_delta, self.best_score):
             self.best_score = current
             self.wait_count = 0

From 91184936e268c28a70c1caca17f8207bc7414a58 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Tue, 6 Apr 2021 11:34:50 +0200
Subject: [PATCH 02/14] improve message

---
 pytorch_lightning/callbacks/early_stopping.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 57d951615b591..d635627ca71c6 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -170,8 +170,8 @@ def _run_early_stopping_check(self, trainer):
         if self.check_finite and not torch.isfinite(current):
             trainer.should_stop = True
             log.info(
-                f"[{trainer.global_rank}] Monitored metric {self.monitor} is not finite."
-                f" Current value is {current:.3f}, best value was {self.best_score:.3f}."
+                f"[{trainer.global_rank}] Monitored metric {self.monitor} = {current:.3f} is not finite."
+                f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
             )

         if self.monitor_op(current - self.min_delta, self.best_score):
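Note on the check introduced above: torch.isfinite returns False for NaN and for positive or
negative infinity alike, so the single predicate `not torch.isfinite(current)` covers a diverging
loss as well as a numerically broken one. A minimal standalone sketch of that behavior, not part
of the patch:

    import torch

    # torch.isfinite is False for NaN and +/-inf, so one predicate
    # flags every non-finite monitor value.
    for value in [1.0, float("nan"), float("inf"), -float("inf")]:
        current = torch.tensor(value)
        print(value, "-> stop:", not torch.isfinite(current))
    # prints False for 1.0 and True for the three non-finite values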
From 7f301b821beedf364cfdaf3d3ff94a7530b3faef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 7 Apr 2021 13:36:17 +0200
Subject: [PATCH 03/14] initial commit

---
 pytorch_lightning/callbacks/early_stopping.py | 41 ++++++++++++++-----
 1 file changed, 31 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index d635627ca71c6..ced96a5cb688b 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -19,7 +19,7 @@

 """
 import logging
-from typing import Any, Dict
+from typing import Any, Dict, Optional, Tuple

 import numpy as np
 import torch
@@ -86,6 +86,8 @@ def __init__(
         mode: str = 'min',
         strict: bool = True,
         check_finite: bool = True,
+        stopping_threshold: Optional[float] = None,
+        divergence_threshold: Optional[float] = None,
     ):
         super().__init__()
         self.monitor = monitor
@@ -95,6 +97,8 @@ def __init__(
         self.mode = mode
         self.strict = strict
         self.check_finite = check_finite
+        self.stopping_threshold = stopping_threshold
+        self.divergence_threshold = divergence_threshold
         self.wait_count = 0
         self.stopped_epoch = 0

@@ -167,22 +171,39 @@ def _run_early_stopping_check(self, trainer):
         # when in dev debugging
         trainer.dev_debugger.track_early_stopping_history(self, current)

+        should_stop, reason = self._evalute_stopping_criteria(current)
+        if should_stop:
+            self.stopped_epoch = trainer.current_epoch
+        if reason:
+            log.info(f"[{trainer.global_rank}] {reason}")
+
+        # stop every ddp process if any world process decides to stop
+        should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
+        trainer.should_stop = trainer.should_stop or should_stop
+
+    def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
+        should_stop = False
+        reason = None
         if self.check_finite and not torch.isfinite(current):
-            trainer.should_stop = True
-            log.info(
-                f"[{trainer.global_rank}] Monitored metric {self.monitor} = {current:.3f} is not finite."
+            should_stop = True
+            reason = (
+                f"Monitored metric {self.monitor} = {current:.3f} is not finite."
                 f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
             )
-
         if self.monitor_op(current - self.min_delta, self.best_score):
+            should_stop = False
             self.best_score = current
             self.wait_count = 0
+        elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
+            should_stop = True
+            reason = ""
+        elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
+            should_stop = True
+            reason = ""
        else:
             self.wait_count += 1
-
             if self.wait_count >= self.patience:
-                self.stopped_epoch = trainer.current_epoch
-                trainer.should_stop = True
+                should_stop = True
+                reason = ""

-        # stop every ddp process if any world process decides to stop
-        trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop)
+        return should_stop, reason
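The `self.monitor_op(-current, -self.divergence_threshold)` comparison above deserves a remark:
negating both operands flips the direction of the mode's operator, so the same operator that
detects improvement also detects divergence, without a second operator table. A standalone sketch
of the idea for 'min' mode, with illustrative values:

    import torch

    monitor_op = torch.lt  # 'min' mode: improvement means current < best
    current = torch.tensor(42.0)
    divergence_threshold = torch.tensor(10.0)

    # lt(-current, -threshold) is True exactly when current > threshold,
    # i.e. when the monitored metric got worse than the divergence threshold.
    diverged = monitor_op(-current, -divergence_threshold)
    print(bool(diverged))  # True -> signal the Trainer to stop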
From 189f939ab1250eac73282bc9e4017cf873641421 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 7 Apr 2021 13:58:00 +0200
Subject: [PATCH 04/14] added stopping reason

---
 pytorch_lightning/callbacks/early_stopping.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index ced96a5cb688b..02a6357997585 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -196,14 +196,23 @@ def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
             self.wait_count = 0
         elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
             should_stop = True
-            reason = ""
+            reason = (
+                f"Below tolerance {self.monitor} = {current} <= {self.stopping_threshold}"
+                " Signaling Trainer to stop."
+            )
         elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
             should_stop = True
-            reason = ""
+            reason = (
+                f"Divergence: {self.monitor} = {current} > {self.divergence_threshold}."
+                " Signaling Trainer to stop."
+            )
         else:
             self.wait_count += 1
             if self.wait_count >= self.patience:
                 should_stop = True
-                reason = ""
+                reason = (
+                    f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} epochs."
+                    f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
+                )

         return should_stop, reason

From d81b6168fb2c98e2c322f79cfe205fac29733226 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 7 Apr 2021 14:57:05 +0200
Subject: [PATCH 05/14] add docs

---
 pytorch_lightning/callbacks/early_stopping.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 02a6357997585..293c36aac4c9c 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -58,6 +58,8 @@ class EarlyStopping(Callback):
         strict: whether to crash the training if `monitor` is not found in the validation metrics.
         check_finite: Stops training when the monitor becomes NaN or infinite. Set this argument to ``False``
             if this behavior is undesired.
+        stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold.
+        divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.

     Raises:
         MisconfigurationException:
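With the documentation in place, the intended usage looks roughly like the following sketch; the
metric name and threshold values are made up for illustration:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import EarlyStopping

    early_stopping = EarlyStopping(
        monitor='val_loss',
        stopping_threshold=0.02,   # good enough: stop once val_loss reaches 0.02
        divergence_threshold=9.0,  # hopeless: stop as soon as val_loss exceeds 9.0
    )
    trainer = Trainer(callbacks=[early_stopping])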
From 76fe93c545eb636a08ccac69630ad41cf5997646 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 8 Apr 2021 12:13:13 +0200
Subject: [PATCH 06/14] make test

---
 pytorch_lightning/callbacks/early_stopping.py |  8 ++++----
 tests/callbacks/test_early_stopping.py        | 18 +++++++++++++-----
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 293c36aac4c9c..cd7a2a8ca12c4 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -192,10 +192,6 @@ def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
                 f"Monitored metric {self.monitor} = {current:.3f} is not finite."
                 f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
             )
-        if self.monitor_op(current - self.min_delta, self.best_score):
-            should_stop = False
-            self.best_score = current
-            self.wait_count = 0
         elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
             should_stop = True
             reason = (
@@ -208,6 +204,10 @@ def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
                 f"Divergence: {self.monitor} = {current} > {self.divergence_threshold}."
                 " Signaling Trainer to stop."
             )
+        elif self.monitor_op(current - self.min_delta, self.best_score):
+            should_stop = False
+            self.best_score = current
+            self.wait_count = 0
         else:
             self.wait_count += 1
             if self.wait_count >= self.patience:
diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index cc619077ee136..7bf0f53499542 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -213,25 +213,33 @@ def test_early_stopping_no_val_step(tmpdir):
     assert trainer.current_epoch < trainer.max_epochs - 1


-def test_early_stopping_functionality(tmpdir):
+@pytest.mark.parametrize("stopping_threshold,divergence_theshold,losses,expected_epoch", [
+    (None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
+    (2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
+    (None, 15.9, [9, 4, 2, 16, 32, 64], 3),
+])
+def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_theshold, losses, expected_epoch):

     class CurrentModel(BoringModel):

         def validation_epoch_end(self, outputs):
-            losses = [8, 4, 2, 3, 4, 5, 8, 10]
             val_loss = losses[self.current_epoch]
             self.log('abc', val_loss)

     model = CurrentModel()
-
+    early_stopping = EarlyStopping(
+        monitor='abc',
+        stopping_threshold=stopping_threshold,
+        divergence_threshold=divergence_theshold,
+    )
     trainer = Trainer(
         default_root_dir=tmpdir,
-        callbacks=[EarlyStopping(monitor='abc')],
+        callbacks=[early_stopping],
         overfit_batches=0.20,
         max_epochs=20,
     )
     trainer.fit(model)
-    assert trainer.current_epoch == 5, 'early_stopping failed'
+    assert trainer.current_epoch == expected_epoch, 'early_stopping failed'


 @pytest.mark.parametrize('step_freeze, min_steps, min_epochs', [(5, 1, 1), (5, 1, 3), (3, 15, 1)])
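To see where the expected epochs in the parametrization come from, here is a plain-Python
re-enactment of the patience bookkeeping for the first case (EarlyStopping defaults: patience=3,
mode='min', min_delta=0.0); this is a sketch, not Lightning code:

    best, wait_count, patience = float("inf"), 0, 3
    for epoch, loss in enumerate([8, 4, 2, 3, 4, 5, 8, 10]):
        if loss < best:          # epochs 0-2 improve the best score: 8 -> 4 -> 2
            best, wait_count = loss, 0
        else:                    # epochs 3-5 do not improve
            wait_count += 1
            if wait_count >= patience:
                print(f"stop at epoch {epoch}")  # stop at epoch 5
                break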
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -189,7 +189,7 @@ def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
         if self.check_finite and not torch.isfinite(current):
             should_stop = True
             reason = (
-                f"Monitored metric {self.monitor} = {current:.3f} is not finite."
+                f"Monitored metric {self.monitor} = {current} is not finite."
                 f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
             )
         elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):

From 569a7faa190c91c6c2b81e4deb559d79c54c8469 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 8 Apr 2021 15:50:10 +0200
Subject: [PATCH 08/14] Update pytorch_lightning/callbacks/early_stopping.py
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Carlos Mocholí
---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index c383ecffb940c..526f8c0de3122 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -195,7 +195,7 @@ def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
         elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
             should_stop = True
             reason = (
-                f"Below tolerance {self.monitor} = {current} <= {self.stopping_threshold}"
+                f"Below tolerance {self.monitor} = {current} <= {self.stopping_threshold}."
                 " Signaling Trainer to stop."
             )
         elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):

From 1da0324940e16ae506f31febdd0ccc1c470918f3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 8 Apr 2021 16:28:57 +0200
Subject: [PATCH 09/14] add changelog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1d2629d3928f0..1febd67f742eb 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -88,6 +88,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `model` parameter to precision plugins' `clip_gradients` signature ([#6764](https://github.com/PyTorchLightning/pytorch-lightning/pull/6764))


+- Added new `EarlyStopping` parameters `stopping_threshold` and `divergence_threshold` ([#6868](https://github.com/PyTorchLightning/pytorch-lightning/pull/6868))
+
+
 ### Changed

 - Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))

From b6a6a7c2993bbb4a9b82c49219d9e2aef22c5f15 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 8 Apr 2021 21:31:20 +0200
Subject: [PATCH 10/14] rearrange

---
 pytorch_lightning/callbacks/early_stopping.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 526f8c0de3122..92a994269130c 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -174,14 +174,14 @@ def _run_early_stopping_check(self, trainer):
         trainer.dev_debugger.track_early_stopping_history(self, current)

         should_stop, reason = self._evalute_stopping_criteria(current)
-        if should_stop:
-            self.stopped_epoch = trainer.current_epoch
-        if reason:
-            log.info(f"[{trainer.global_rank}] {reason}")

         # stop every ddp process if any world process decides to stop
         should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
         trainer.should_stop = trainer.should_stop or should_stop
+        if should_stop:
+            self.stopped_epoch = trainer.current_epoch
+        if reason:
+            log.info(f"[{trainer.global_rank}] {reason}")

     def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
         should_stop = False

From 948e5a00c9a78400086e35e11f4c18522326bc2e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Thu, 8 Apr 2021 21:41:09 +0200
Subject: [PATCH 11/14] test for inf value

---
 tests/callbacks/test_early_stopping.py | 31 ++++++++++++++++++++++++++
 1 file changed, 31 insertions(+)

diff --git a/tests/callbacks/test_early_stopping.py b/tests/callbacks/test_early_stopping.py
index 7bf0f53499542..3844d16edb517 100644
--- a/tests/callbacks/test_early_stopping.py
+++ b/tests/callbacks/test_early_stopping.py
@@ -242,6 +242,37 @@ def validation_epoch_end(self, outputs):
     assert trainer.current_epoch == expected_epoch, 'early_stopping failed'


+@pytest.mark.parametrize("stop_value", [
+    torch.tensor(np.inf),
+    torch.tensor(np.nan),
+])
+def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value):
+
+    losses = [4, 3, stop_value, 2, 1]
+    expected_stop_epoch = 2
+
+    class CurrentModel(BoringModel):
+
+        def validation_epoch_end(self, outputs):
+            val_loss = losses[self.current_epoch]
+            self.log('val_loss', val_loss)
+
+    model = CurrentModel()
+    early_stopping = EarlyStopping(
+        monitor='val_loss',
+        check_finite=True,
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        callbacks=[early_stopping],
+        overfit_batches=0.20,
+        max_epochs=10,
+    )
+    trainer.fit(model)
+    assert trainer.current_epoch == expected_stop_epoch
+    assert early_stopping.stopped_epoch == expected_stop_epoch
+
+
 @pytest.mark.parametrize('step_freeze, min_steps, min_epochs', [(5, 1, 1), (5, 1, 3), (3, 15, 1)])
 def test_min_steps_override_early_stopping_functionality(tmpdir, step_freeze: int, min_steps: int, min_epochs: int):
     """Excepted Behaviour:
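The `reduce_boolean_decision` call rearranged in PATCH 10 above is what keeps DDP ranks in sync:
every process contributes its local flag and all of them adopt the OR of the answers, so a single
rank seeing a non-finite metric stops the whole world. A conceptual standalone sketch of those
semantics (Lightning performs the reduction through its training-type plugin; this is not that
implementation):

    # OR-reduction over per-rank decisions: stop everywhere if any rank stops.
    def reduce_boolean_decision(local_decisions):
        return any(local_decisions)

    print(reduce_boolean_decision([False, True, False, False]))  # True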
From 853c2d7e9aeabafe708e85d0d5f22ca64b9f3ebe Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Wed, 14 Apr 2021 14:40:22 +0200
Subject: [PATCH 12/14] change default for check_finite

---
 pytorch_lightning/callbacks/early_stopping.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 92a994269130c..fac19fe37f1a4 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -56,8 +56,7 @@ class EarlyStopping(Callback):
             monitored has stopped decreasing and in ``'max'`` mode it will stop when the quantity
             monitored has stopped increasing.
         strict: whether to crash the training if `monitor` is not found in the validation metrics.
-        check_finite: Stops training when the monitor becomes NaN or infinite. Set this argument to ``False``
-            if this behavior is undesired.
+        check_finite: When set ``True``, stops training when the monitor becomes NaN or infinite.
         stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold.
         divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.


From cb90c76d510bb7083282b6349baff8158e4384e6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 16 Apr 2021 11:05:26 +0200
Subject: [PATCH 13/14] typo

---
 pytorch_lightning/callbacks/early_stopping.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index fac19fe37f1a4..357118e8a1a2d 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -194,7 +194,7 @@ def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
         elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
             should_stop = True
             reason = (
-                f"Below tolerance {self.monitor} = {current} <= {self.stopping_threshold}."
+                f"Below tolerance: {self.monitor} = {current} <= {self.stopping_threshold}."
                 " Signaling Trainer to stop."
             )
         elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):

From c70666873121a805d743a7b6b6e8cc869a0fa263 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adrian=20W=C3=A4lchli?=
Date: Fri, 16 Apr 2021 14:57:25 +0200
Subject: [PATCH 14/14] message

---
 pytorch_lightning/callbacks/early_stopping.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index 357118e8a1a2d..9af576aafd596 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -78,6 +78,11 @@ class EarlyStopping(Callback):
         'max': torch.gt,
     }

+    order_dict = {
+        'min': "<",
+        'max': ">",
+    }
+
     def __init__(
         self,
         monitor: str = 'early_stop_on',
@@ -194,13 +199,15 @@ def _evalute_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, str]:
         elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
             should_stop = True
             reason = (
-                f"Below tolerance: {self.monitor} = {current} <= {self.stopping_threshold}."
+                "Stopping threshold reached:"
+                f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
                 " Signaling Trainer to stop."
             )
         elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
             should_stop = True
             reason = (
-                f"Divergence: {self.monitor} = {current} > {self.divergence_threshold}."
+                "Divergence threshold reached:"
+                f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
                 " Signaling Trainer to stop."
             )
         elif self.monitor_op(current - self.min_delta, self.best_score):
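For reference, the effect of the `order_dict` lookup in the final messages can be reproduced with
a standalone sketch; the names mirror the patch and the values are illustrative:

    order_dict = {'min': "<", 'max': ">"}
    mode, monitor = 'min', 'val_loss'
    current, stopping_threshold = 0.018, 0.02

    reason = (
        "Stopping threshold reached:"
        f" {monitor} = {current} {order_dict[mode]} {stopping_threshold}."
        " Signaling Trainer to stop."
    )
    print(reason)
    # Stopping threshold reached: val_loss = 0.018 < 0.02. Signaling Trainer to stop.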