4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -102,6 +102,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `max_time` Trainer argument to limit training time ([#6823](https://github.com/PyTorchLightning/pytorch-lightning/pull/6823))


- Added new `EarlyStopping` parameters `stopping_threshold` and `divergence_threshold` ([#6868](https://github.com/PyTorchLightning/pytorch-lightning/pull/6868))



### Changed

- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
70 changes: 61 additions & 9 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -18,7 +18,8 @@
Monitor a metric and stop training when it stops improving.

"""
import logging
from typing import Any, Dict, Optional, Tuple

import numpy as np
import torch
@@ -27,6 +28,8 @@
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException

log = logging.getLogger(__name__)


class EarlyStopping(Callback):
r"""
@@ -53,6 +56,9 @@ class EarlyStopping(Callback):
monitored has stopped decreasing and in ``'max'`` mode it will stop when the quantity
monitored has stopped increasing.
strict: whether to crash the training if `monitor` is not found in the validation metrics.
check_finite: When set to ``True``, stops training when the monitor becomes NaN or infinite.
stopping_threshold: Stop training immediately once the monitored quantity reaches this threshold.
divergence_threshold: Stop training as soon as the monitored quantity becomes worse than this threshold.
Comment on lines +60 to +61

Contributor:
We could use `stop_limit` and `stop_loss` to follow common financial terms:
https://www.investopedia.com/articles/active-trading/091813/which-order-use-stoploss-or-stoplimit-orders.asp

Contributor (Author):
@jlperla what do you think of this name suggestion?

Commenter:
But it might not be a loss, and most people don't know finance. These are basically optimizer settings, which is more universal.

I think sticking with optimizer-style lingo is ideal. "Divergence" is safe and says what it means. Normally one would call the success criteria "tolerances" for optimizers, but that is because optimizers are always comparing something (e.g. a value itself, changes in that value, or first-order conditions) to zero.

Since this could presumably compare against stopping values other than close to zero (especially if you are tracking something where a larger number is better), I think "threshold" is probably more general. But open-minded, of course.

Contributor:
Could also go super simple and use `min_threshold` and `max_threshold`.

Commenter:
No, because that implies a direction to them.

Contributor (@tchaton, Apr 12, 2021):
IMO, the names are good right now.

Collaborator:
This is not clear to me. Is it that you have some converging sequence and it stops when the sequence starts to diverge again? Should there be some patience to account for noise? Or could it instead observe the training and validation measures and stop on overfitting, i.e. when the two start to diverge?

Commenter:
If something is diverging, it is because you are in some sort of local minimum or outside of an attractor, and it could only return with some massive jumps (e.g. in simulated annealing, when you are way off in the boondocks relative to your optimum: in theory it could come back, but it might take months; you are better off just restarting). So patience isn't the right thing to think of there.
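
To make the semantics discussed above concrete, here is a minimal usage sketch; the monitor name and the threshold values are illustrative, not taken from this PR:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# 'min' mode: a lower val_loss is better.
early_stopping = EarlyStopping(
    monitor="val_loss",
    mode="min",
    stopping_threshold=0.05,    # stop as soon as val_loss drops below 0.05 ("good enough")
    divergence_threshold=10.0,  # stop as soon as val_loss exceeds 10.0 (training has diverged)
    check_finite=True,          # stop if val_loss becomes NaN or infinite
)
trainer = Trainer(callbacks=[early_stopping], max_epochs=100)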


Raises:
MisconfigurationException:
@@ -72,6 +78,11 @@ class EarlyStopping(Callback):
'max': torch.gt,
}

order_dict = {
'min': "<",
'max': ">",
}

def __init__(
self,
monitor: str = 'early_stop_on',
@@ -80,16 +91,22 @@ def __init__(
verbose: bool = False,
mode: str = 'min',
strict: bool = True,
check_finite: bool = True,
stopping_threshold: Optional[float] = None,
divergence_threshold: Optional[float] = None,
):
super().__init__()
self.monitor = monitor
self.min_delta = min_delta
self.patience = patience
self.verbose = verbose
self.mode = mode
self.strict = strict
self.check_finite = check_finite
self.stopping_threshold = stopping_threshold
self.divergence_threshold = divergence_threshold
self.wait_count = 0
self.stopped_epoch = 0

if self.mode not in self.mode_dict:
raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
@@ -160,15 +177,50 @@ def _run_early_stopping_check(self, trainer):
# when in dev debugging
trainer.dev_debugger.track_early_stopping_history(self, current)

should_stop, reason = self._evaluate_stopping_criteria(current)

# stop every ddp process if any world process decides to stop
should_stop = trainer.training_type_plugin.reduce_boolean_decision(should_stop)
trainer.should_stop = trainer.should_stop or should_stop
if should_stop:
self.stopped_epoch = trainer.current_epoch
if reason:
log.info(f"[{trainer.global_rank}] {reason}")

def _evaluate_stopping_criteria(self, current: torch.Tensor) -> Tuple[bool, Optional[str]]:
should_stop = False
reason = None
if self.check_finite and not torch.isfinite(current):
should_stop = True
reason = (
f"Monitored metric {self.monitor} = {current} is not finite."
f" Previous best value was {self.best_score:.3f}. Signaling Trainer to stop."
)
elif self.stopping_threshold is not None and self.monitor_op(current, self.stopping_threshold):
should_stop = True
reason = (
"Stopping threshold reached:"
f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.stopping_threshold}."
" Signaling Trainer to stop."
)
elif self.divergence_threshold is not None and self.monitor_op(-current, -self.divergence_threshold):
should_stop = True
reason = (
"Divergence threshold reached:"
f" {self.monitor} = {current} {self.order_dict[self.mode]} {self.divergence_threshold}."
" Signaling Trainer to stop."
)
elif self.monitor_op(current - self.min_delta, self.best_score):
should_stop = False
self.best_score = current
self.wait_count = 0
else:
self.wait_count += 1

if self.wait_count >= self.patience:
should_stop = True
reason = (
f"Monitored metric {self.monitor} did not improve in the last {self.wait_count} epochs."
f" Best score: {self.best_score:.3f}. Signaling Trainer to stop."
)

return should_stop, reason
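
A note on the divergence check in `_evaluate_stopping_criteria` above: negating both operands flips the mode's comparison. In 'min' mode `monitor_op` is `torch.lt`, so `monitor_op(-current, -self.divergence_threshold)` evaluates `-current < -threshold`, which is equivalent to `current > threshold`: the metric has become worse than the divergence bound. The same trick works unchanged in 'max' mode. A standalone sketch of the identity, with illustrative values:

import torch

monitor_op = torch.lt  # 'min' mode: "improved" means "less than"
current = torch.tensor(16.0)
divergence_threshold = 15.9

# Negating both sides flips the comparison: -16.0 < -15.9 iff 16.0 > 15.9.
diverged = monitor_op(-current, -divergence_threshold)
assert bool(diverged)  # the callback would signal the Trainer to stop here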
49 changes: 44 additions & 5 deletions tests/callbacks/test_early_stopping.py
@@ -213,25 +213,64 @@ def test_early_stopping_no_val_step(tmpdir):
assert trainer.current_epoch < trainer.max_epochs - 1


@pytest.mark.parametrize("stopping_threshold,divergence_threshold,losses,expected_epoch", [
(None, None, [8, 4, 2, 3, 4, 5, 8, 10], 5),
(2.9, None, [9, 8, 7, 6, 5, 6, 4, 3, 2, 1], 8),
(None, 15.9, [9, 4, 2, 16, 32, 64], 3),
])
def test_early_stopping_thresholds(tmpdir, stopping_threshold, divergence_threshold, losses, expected_epoch):

class CurrentModel(BoringModel):

def validation_epoch_end(self, outputs):
val_loss = losses[self.current_epoch]
self.log('abc', val_loss)

model = CurrentModel()

early_stopping = EarlyStopping(
monitor='abc',
stopping_threshold=stopping_threshold,
divergence_threshold=divergence_threshold,
)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[early_stopping],
overfit_batches=0.20,
max_epochs=20,
)
trainer.fit(model)
assert trainer.current_epoch == expected_epoch, 'early_stopping failed'


@pytest.mark.parametrize("stop_value", [
torch.tensor(np.inf),
torch.tensor(np.nan),
])
def test_early_stopping_on_non_finite_monitor(tmpdir, stop_value):

losses = [4, 3, stop_value, 2, 1]
expected_stop_epoch = 2

class CurrentModel(BoringModel):

def validation_epoch_end(self, outputs):
val_loss = losses[self.current_epoch]
self.log('val_loss', val_loss)

model = CurrentModel()
early_stopping = EarlyStopping(
monitor='val_loss',
check_finite=True,
)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[early_stopping],
overfit_batches=0.20,
max_epochs=10,
)
trainer.fit(model)
assert trainer.current_epoch == expected_stop_epoch
assert early_stopping.stopped_epoch == expected_stop_epoch
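
As a sanity check on the expected epochs in the parametrized threshold test above, the stopping logic can be traced with a simplified pure-Python re-implementation, assuming mode='min', the default patience=3 and min_delta=0.0, and ignoring the finite check:

def stop_epoch(losses, stopping_threshold=None, divergence_threshold=None, patience=3):
    best, wait = float("inf"), 0
    for epoch, loss in enumerate(losses):
        if stopping_threshold is not None and loss < stopping_threshold:
            return epoch  # "good enough" threshold crossed
        if divergence_threshold is not None and loss > divergence_threshold:
            return epoch  # diverged past the bound
        if loss < best:
            best, wait = loss, 0  # improvement resets the patience counter
        else:
            wait += 1
            if wait >= patience:
                return epoch  # no improvement for `patience` epochs

assert stop_epoch([8, 4, 2, 3, 4, 5, 8, 10]) == 5
assert stop_epoch([9, 8, 7, 6, 5, 6, 4, 3, 2, 1], stopping_threshold=2.9) == 8
assert stop_epoch([9, 4, 2, 16, 32, 64], divergence_threshold=15.9) == 3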


@pytest.mark.parametrize('step_freeze, min_steps, min_epochs', [(5, 1, 1), (5, 1, 3), (3, 15, 1)])