🚀 Feature
Additional early stopping features
First, a max_time option, which should probably sit alongside max_epochs in the main trainer loop. Why an additional stopping criterion? Because (1) you never know in advance how long an epoch will take, especially if you tinker with hyperparameters; and (2) sometimes you want to give training a fixed time budget and see which version of the model does best within it.
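In the meantime, a wall-clock budget can be prototyped as a small callback. The sketch below is only an assumption of how that might look: the `MaxTimeStopping` class, the `max_time_seconds` argument, and the choice of hook are all hypothetical, and hook signatures differ between Lightning versions.

```python
import time

from pytorch_lightning.callbacks import Callback


class MaxTimeStopping(Callback):
    """Hypothetical sketch: stop training once a wall-clock budget (in seconds) is spent."""

    def __init__(self, max_time_seconds: float):
        self.max_time_seconds = max_time_seconds
        self.start_time = None

    def on_train_start(self, trainer, pl_module):
        self.start_time = time.monotonic()

    def on_validation_end(self, trainer, pl_module):
        # checked after each validation run, mirroring where EarlyStopping hooks in
        if time.monotonic() - self.start_time > self.max_time_seconds:
            trainer.should_stop = True  # the trainer loop honours this flag
```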
Second, a few variations on the metric-based EarlyStopping callback (a usage sketch follows this list):
- A way to stop because things have diverged completely and are unlikely to recover (e.g. the metric is far too big, far too small, or NaN).
- A way to stop because things have converged completely in terms of the quality of the approximation, so there is no point doing further iterations. This is distinct from stopping because the metric has ceased to improve, which is what the callback currently does.
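For concreteness, here is a hedged sketch of how the extended callback might be configured. The `stopping_threshold` and `divergence_threshold` arguments are the proposed additions from the implementation below, not part of the released EarlyStopping at the time of writing.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

# Proposed usage; stopping_threshold / divergence_threshold are the new arguments
early_stop = EarlyStopping(
    monitor="val_loss",
    patience=5,
    stopping_threshold=1e-4,   # "good enough": stop once the metric crosses this
    divergence_threshold=1e6,  # "blown up": stop if the metric gets this bad (NaN also stops)
)
trainer = Trainer(max_epochs=100, callbacks=[early_stop])
```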
For both of these, I think it is useful to have a min_epochs option (or something similar) to ensure training doesn't stop right away. I think that is what #6705 is supposed to do, though, so perhaps it isn't needed here?
Finally, I think it would be great for PL to have a way to log the reason for stopping, so that it can be seen in the logs and be available within the grid experiments view. I'm not sure of the best way to do that, but maybe the callback could save a string in the logs?
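One possible sketch: keep a human-readable reason on a callback and push it through the logger at the end of training. The `StopReasonLogger` name and `stop_reason` attribute are made up for illustration; `log_hyperparams` accepts a dict, but how a string value surfaces depends on the particular logger backend.

```python
from pytorch_lightning.callbacks import Callback


class StopReasonLogger(Callback):
    """Hypothetical sketch: record why training stopped so it is visible in the logs."""

    def __init__(self):
        self.stop_reason = "completed"  # overwritten by whatever sets trainer.should_stop

    def on_train_end(self, trainer, pl_module):
        if trainer.logger is not None:
            # log_hyperparams takes a dict; string values are accepted by most loggers
            trainer.logger.log_hyperparams({"stop_reason": self.stop_reason})
```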
Implementation
I implemented these two features in the current callback with something like:

```python
from time import time

import numpy as np
import torch


def __init__(
    self,
    monitor: str = 'early_stop_on',
    min_delta: float = 0.0,
    patience: int = 3,
    verbose: bool = False,
    mode: str = 'min',
    strict: bool = True,
    stopping_threshold: float = 0.0,
    divergence_threshold: float = 1e6,
):
    super().__init__()
    self.monitor = monitor
    self.patience = patience
    self.verbose = verbose
    self.strict = strict
    self.min_delta = min_delta
    self.wait_count = 0
    self.stopped_epoch = 0
    self.mode = mode
    self.warned_result_obj = False
    self.__init_monitor_mode()

    self.min_delta *= 1 if self.monitor_op == torch.gt else -1
    torch_inf = torch.tensor(np.Inf)
    self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

    # added: thresholds for "good enough" and "diverged" stopping
    self.stopping_threshold = stopping_threshold
    self.divergence_threshold = divergence_threshold

    # added: wall-clock bookkeeping
    self.last_time = time()
    self.elapsed_time = 0.0
```

Then I added something like:
```python
def _run_early_stopping_check(self, trainer, pl_module):
    """
    Checks whether the early stopping condition is met
    and if so tells the trainer to stop the training.
    """
    # ADDED: accumulate wall-clock time since the last check
    self.elapsed_time += time() - self.last_time
    self.last_time = time()

    logs = trainer.callback_metrics

    if (
        trainer.fast_dev_run  # disable early_stopping with fast_dev_run
        or not self._validate_condition_metric(logs)  # short circuit if metric not present
    ):
        return

    current = logs.get(self.monitor)

    # when in dev debugging
    trainer.dev_debugger.track_early_stopping_history(self, current)

    if self.monitor_op(current - self.min_delta, self.best_score):
        self.best_score = current
        self.wait_count = 0
    else:
        self.wait_count += 1

        # ADDED (OKCYAN / ENDC below are ANSI colour codes defined elsewhere)
        if self.wait_count >= self.patience:
            self.stopped_epoch = trainer.current_epoch
            trainer.should_stop = True
            print(f"\n{OKCYAN}Stopping. Above patience of {self.patience} epochs without improvement of {self.min_delta}{ENDC}")
        elif self.monitor_op(current, self.stopping_threshold):
            self.stopped_epoch = trainer.current_epoch
            trainer.should_stop = True
            print(f"\n{OKCYAN}Stopping. Below tolerance {self.monitor} = {logs[self.monitor]} <= {self.stopping_threshold}{ENDC}")
        elif self.monitor_op(-current, -self.divergence_threshold) or torch.isnan(current):
            self.stopped_epoch = trainer.current_epoch
            trainer.should_stop = True
            print(f"\n{OKCYAN}Stopping. Divergence {self.monitor} = {logs[self.monitor]} >= {self.divergence_threshold}{ENDC}")

    # stop every ddp process if any world process decides to stop
    trainer.should_stop = trainer.training_type_plugin.reduce_boolean_decision(trainer.should_stop)
```
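The snippet above only accumulates elapsed_time; the corresponding stop is not shown. Presumably it would be another branch along these lines, assuming a hypothetical max_time argument (in seconds) on the callback:

```python
# Hypothetical extra branch: stop once the accumulated wall-clock time exceeds max_time
if self.max_time is not None and self.elapsed_time >= self.max_time:
    self.stopped_epoch = trainer.current_epoch
    trainer.should_stop = True
    print(f"\n{OKCYAN}Stopping. Elapsed time {self.elapsed_time:.1f}s exceeded max_time of {self.max_time}s{ENDC}")
```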