Add more early stopping options #6795

@jlperla

Description

🚀 Feature

Additional early stopping features

First, a max_time option, which should probably sit in parallel to max_epochs in the main trainer loop. Why an additional stopping criterion? Because (1) you never know in advance how long an epoch will take, especially while tinkering with hyperparameters; and (2) sometimes you want to fix a time budget and see which version of the model does best within it. A rough sketch of this behaviour follows.
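
To make the proposal concrete, here is a minimal sketch of a wall-clock stopping callback. The class name MaxTimeStopping and the choice to implement it as a callback (rather than a trainer flag) are assumptions for illustration, not existing Lightning API; hook signatures also vary a little across Lightning versions.

    import time

    import pytorch_lightning as pl


    class MaxTimeStopping(pl.Callback):
        """Stop training once a wall-clock budget is exhausted (hypothetical)."""

        def __init__(self, max_seconds: float):
            self.max_seconds = max_seconds
            self.start_time = None

        def on_train_start(self, trainer, pl_module):
            # Start the clock when training begins.
            self.start_time = time.monotonic()

        def on_train_epoch_end(self, trainer, pl_module):
            # Check the budget at each epoch boundary and ask the trainer to stop.
            if time.monotonic() - self.start_time > self.max_seconds:
                trainer.should_stop = True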

Second, a few variations on the metric-based EarlyStopping callback:

  1. A way to stop because training has diverged completely and you doubt it can recover (e.g. the monitored metric is too big, too small, or NaN).
  2. A way to stop because the approximation has converged to sufficient quality, so there is no point in doing further iterations. This is distinct from stopping because the metric has ceased to improve, which is what the callback currently does.

For both of these, I think it is useful to have a min_epochs or similar option to ensure that training doesn't stop right away. I think that is what #6705 is supposed to do, though, so perhaps it isn't needed here?

Finally, I think it would be great if PL had a way to log the reason for stopping, so that it can be seen in the logs and be available within the Grid experiments view. I'm not sure of the best way to do that, but maybe the callback could save a string in the logs? One possibility is sketched below.
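
As one hedged sketch (the subclass and the stop_reason attribute are hypothetical, not existing Lightning API), the callback could keep the reason as a string and push it into the run's hyperparameters, which most Lightning loggers will display:

    class EarlyStoppingWithReason(EarlyStopping):
        def _record_stop_reason(self, trainer, reason: str):
            # Keep the reason on the callback for later inspection ...
            self.stop_reason = reason
            # ... and, if a logger is attached, surface it alongside the
            # run's hyperparameters so it shows up in experiment views.
            if trainer.logger is not None:
                trainer.logger.log_hyperparams({"early_stop_reason": reason})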

Implementation

I implemented these two features in the current callback with something like:

    # Required imports (at module level in the callback's file):
    from time import time

    import numpy as np
    import torch


    def __init__(
        self,
        monitor: str = 'early_stop_on',
        min_delta: float = 0.0,
        patience: int = 3,
        verbose: bool = False,
        mode: str = 'min',
        strict: bool = True,
        stopping_threshold: float = 0.0,
        divergence_threshold: float = 1e6,
    ):
        super().__init__()
        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.strict = strict
        self.min_delta = min_delta
        self.wait_count = 0
        self.stopped_epoch = 0
        self.mode = mode
        self.warned_result_obj = False

        # Sets self.monitor_op (torch.lt for 'min', torch.gt for 'max').
        self.__init_monitor_mode()

        # Flip the sign of min_delta so it always points in the "better" direction.
        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
        torch_inf = torch.tensor(np.inf)
        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

        self.stopping_threshold = stopping_threshold
        self.divergence_threshold = divergence_threshold
        self.last_time = time()
        self.elapsed_time = 0.0
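
For completeness, __init_monitor_mode() is the existing private helper that resolves mode into a torch comparison op; here is a minimal sketch consistent with how the current callback behaves (the exact body in Lightning may differ):

    def __init_monitor_mode(self):
        # 'min' means lower metric values are better, 'max' the opposite.
        mode_dict = {'min': torch.lt, 'max': torch.gt}
        if self.mode not in mode_dict:
            raise ValueError(f"mode must be one of {list(mode_dict)}, got {self.mode!r}")
        self.monitor_op = mode_dict[self.mode]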

Then I added something like the following to the stopping check:

    # ANSI colour codes for the status messages below.
    OKCYAN = '\033[96m'
    ENDC = '\033[0m'


    def _run_early_stopping_check(self, trainer, pl_module):
        """
        Checks whether the early stopping condition is met
        and if so tells the trainer to stop the training.
        """
        # ADDED: track wall time (not yet used for stopping).
        self.elapsed_time += time() - self.last_time
        self.last_time = time()

        logs = trainer.callback_metrics

        if (
            trainer.fast_dev_run  # disable early_stopping with fast_dev_run
            or not self._validate_condition_metric(logs)  # metric not present
        ):
            return  # short circuit

        current = logs.get(self.monitor)

        # when in dev debugging
        trainer.dev_debugger.track_early_stopping_history(self, current)

        if self.monitor_op(current - self.min_delta, self.best_score):
            self.best_score = current
            self.wait_count = 0
        else:
            self.wait_count += 1

        # ADDED: three distinct stopping conditions, each with its own message.
        if self.wait_count >= self.patience:
            # No improvement of at least min_delta for `patience` checks.
            self.stopped_epoch = trainer.current_epoch
            trainer.should_stop = True
            print(f"\n{OKCYAN}Stopping. Above patience of {self.patience} epochs without improvement of {self.min_delta}{ENDC}")
        elif self.monitor_op(current, self.stopping_threshold):
            # Converged: the metric has crossed the "good enough" threshold.
            self.stopped_epoch = trainer.current_epoch
            trainer.should_stop = True
            print(f"\n{OKCYAN}Stopping. Below tolerance {self.monitor} = {logs[self.monitor]} <= {self.stopping_threshold}{ENDC}")
        elif self.monitor_op(-current, -self.divergence_threshold) or torch.isnan(current):
            # Diverged: negating both sides flips the comparison, so in 'min'
            # mode this fires when current > divergence_threshold (or is NaN).
            self.stopped_epoch = trainer.current_epoch
            trainer.should_stop = True
            print(f"\n{OKCYAN}Stopping. Divergence {self.monitor} = {logs[self.monitor]} >= {self.divergence_threshold}{ENDC}")

        # stop every ddp process if any world process decides to stop
        trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop)
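
Hypothetical usage of the extended callback (the argument names follow the sketch above; they are proposals, not the released EarlyStopping API):

    early_stop = EarlyStopping(
        monitor='val_loss',
        patience=5,
        stopping_threshold=1e-4,   # converged: the approximation is good enough
        divergence_threshold=1e6,  # diverged: the run has blown up or gone NaN
    )
    # `model` is your LightningModule; `pl` as imported above.
    trainer = pl.Trainer(callbacks=[early_stop], min_epochs=10)
    trainer.fit(model)

Passing min_epochs alongside the thresholds covers the "don't stop right away" concern from above.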

Labels

feature (Is an improvement or enhancement), help wanted (Open to be worked on)
