diff --git a/pytorch_lightning/callbacks/lambda_early_stopping.py b/pytorch_lightning/callbacks/lambda_early_stopping.py
new file mode 100644
index 0000000000000..61cd4d0213e2b
--- /dev/null
+++ b/pytorch_lightning/callbacks/lambda_early_stopping.py
@@ -0,0 +1,171 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+r"""
+Early Stopping
+^^^^^^^^^^^^^^
+
+Monitor a metric and stop training when it stops improving.
+
+"""
+from typing import Any, Dict
+
+import numpy as np
+import torch
+
+from pytorch_lightning.callbacks.base import Callback
+from pytorch_lightning.utilities import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+
+class EarlyStopping(Callback):
+    r"""
+    Monitor a metric and stop training when it stops improving.
+
+    Args:
+        monitor: quantity to be monitored.
+        min_delta: minimum change in the monitored quantity to qualify as an improvement,
+            i.e. an absolute change of less than ``min_delta`` will count as no improvement.
+        patience: number of validation checks with no improvement after which training
+            will be stopped. Under the default configuration, one validation check happens
+            after every training epoch. However, the frequency of validation can be modified
+            by setting various parameters on the ``Trainer``, for example
+            ``check_val_every_n_epoch`` and ``val_check_interval``.
+
+            .. note::
+
+                The ``patience`` parameter counts the number of validation checks with
+                no improvement, not the number of training epochs. Therefore, with
+                ``check_val_every_n_epoch=10`` and ``patience=3``, the trainer will
+                perform at least 40 training epochs before being stopped.
+
+        verbose: verbosity mode.
+        mode: one of ``'min'``, ``'max'``. In ``'min'`` mode, training will stop when the
+            quantity monitored has stopped decreasing; in ``'max'`` mode it will stop when
+            the quantity monitored has stopped increasing.
+        strict: whether to crash the training if ``monitor`` is not found in the
+            validation metrics.
+
+    Raises:
+        MisconfigurationException:
+            If ``mode`` is none of ``"min"`` or ``"max"``.
+        RuntimeError:
+            If the metric ``monitor`` is not available.
+
+    Example::
+
+        >>> from pytorch_lightning import Trainer
+        >>> from pytorch_lightning.callbacks import EarlyStopping
+        >>> early_stopping = EarlyStopping('val_loss')
+        >>> trainer = Trainer(callbacks=[early_stopping])
+    """
+    mode_dict = {
+        "min": torch.lt,
+        "max": torch.gt,
+    }
+
+    def __init__(
+        self,
+        monitor: str = "early_stop_on",
+        min_delta: float = 0.0,
+        patience: int = 3,
+        verbose: bool = False,
+        mode: str = "min",
+        strict: bool = True,
+    ):
+        super().__init__()
+        self.monitor = monitor
+        self.patience = patience
+        self.verbose = verbose
+        self.strict = strict
+        self.min_delta = min_delta
+        self.wait_count = 0
+        self.stopped_epoch = 0
+        self.mode = mode
+
+        if self.mode not in self.mode_dict:
+            raise MisconfigurationException(f"`mode` can be {', '.join(self.mode_dict.keys())}, got {self.mode}")
+
+        # flip the sign of min_delta in "min" mode so the improvement test
+        # `monitor_op(current - min_delta, best_score)` works for both modes
+        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
+        torch_inf = torch.tensor(np.inf)
+        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
+
+    def _validate_condition_metric(self, logs):
+        monitor_val = logs.get(self.monitor)
+
+        error_msg = (
+            f"Early stopping conditioned on metric `{self.monitor}` which is not available."
+            " Pass in or modify your `EarlyStopping` callback to use any of the following:"
+            f' `{"`, `".join(list(logs.keys()))}`'
+        )
+
+        if monitor_val is None:
+            if self.strict:
+                raise RuntimeError(error_msg)
+            if self.verbose:
+                rank_zero_warn(error_msg, RuntimeWarning)
+
+            return False
+
+        return True
+
+    @property
+    def monitor_op(self):
+        return self.mode_dict[self.mode]
+
+    def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
+        return {
+            "wait_count": self.wait_count,
+            "stopped_epoch": self.stopped_epoch,
+            "best_score": self.best_score,
+            "patience": self.patience,
+        }
+
+    def on_load_checkpoint(self, callback_state: Dict[str, Any]):
+        self.wait_count = callback_state["wait_count"]
+        self.stopped_epoch = callback_state["stopped_epoch"]
+        self.best_score = callback_state["best_score"]
+        self.patience = callback_state["patience"]
+
+    def on_validation_end(self, trainer, pl_module):
+        from pytorch_lightning.trainer.states import TrainerState
+
+        if trainer.state != TrainerState.FITTING or trainer.sanity_checking:
+            return
+
+        self._run_early_stopping_check(trainer)
+
+    def _run_early_stopping_check(self, trainer):
+        """Checks whether the early stopping condition is met and if so tells the trainer to stop the training."""
+        logs = trainer.callback_metrics
+
+        # disable early stopping with fast_dev_run; short circuit if the metric is not present
+        if trainer.fast_dev_run or not self._validate_condition_metric(logs):
+            return
+
+        current = logs.get(self.monitor)
+
+        # when in dev debugging
+        trainer.dev_debugger.track_early_stopping_history(self, current)
+
+        if self.monitor_op(current - self.min_delta, self.best_score):
+            self.best_score = current
+            self.wait_count = 0
+        else:
+            self.wait_count += 1
+
+            if self.wait_count >= self.patience:
+                self.stopped_epoch = trainer.current_epoch
+                trainer.should_stop = True
+
+        # stop every ddp process if any world process decides to stop
+        trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop)
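
A minimal standalone sketch (not part of the diff) of the improvement test in `_run_early_stopping_check`. The numbers are illustrative; the point is that `__init__` flips the sign of `min_delta` in "min" mode, so `monitor_op(current - min_delta, best_score)` only counts the metric as improved when it drops by more than `min_delta`:

    import torch

    # illustrative values, not taken from the diff
    best_score = torch.tensor(0.500)
    min_delta = 0.01 * -1  # the sign flip __init__ applies in "min" mode

    # a drop of 0.005 is smaller than min_delta -> no improvement, wait_count += 1
    print(torch.lt(torch.tensor(0.495) - min_delta, best_score))  # tensor(False)
    # a drop of 0.020 exceeds min_delta -> improvement, best_score is updated
    print(torch.lt(torch.tensor(0.480) - min_delta, best_score))  # tensor(True)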
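
The `on_save_checkpoint` / `on_load_checkpoint` pair makes the early-stopping state survive a checkpoint save and restore. A small sketch of the round-trip, calling the hooks directly rather than through a `Trainer` (neither hook uses its `trainer`/`pl_module` arguments in this implementation) and assuming the class is importable as in the docstring example:

    import torch
    from pytorch_lightning.callbacks import EarlyStopping

    cb = EarlyStopping(monitor="val_loss", patience=5)
    cb.wait_count = 2  # pretend two validation checks passed without improvement
    cb.best_score = torch.tensor(0.31)

    # the Trainer would invoke these hooks while writing/reading a checkpoint
    state = cb.on_save_checkpoint(trainer=None, pl_module=None, checkpoint={})
    restored = EarlyStopping(monitor="val_loss")
    restored.on_load_checkpoint(state)
    assert restored.wait_count == 2 and restored.patience == 5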