🚀 Feature
Currently, in the training loop, at the end of each training epoch and at the end of training, we checkpoint via the `on_validation_end` hook, which can be confusing. We would like to move this logic to dedicated hooks.
Motivation
- Currently, at the end of each training epoch (`run_training_epoch` in `training_loop.py`), we call the model checkpoint and early stopping callbacks explicitly when there is no validation stage:
```python
if should_train_only:
    self.check_checkpoint_callback(True)
    self.check_early_stopping_callback(True)
```
However, inside these checks we invoke the `on_validation_end` hook:
```python
for cb in callbacks:
    cb.on_validation_end(self.trainer, model)
```
This is confusing. For example, for the `ModelCheckpoint` callback, `on_validation_end` depends on the parameter `every_n_val_epochs`. When there is no validation stage, triggering based on `every_n_val_epochs` is not ideal. Instead, we could implement a dedicated hook that is independent of `every_n_val_epochs`.
- In addition, at the end of training, we also call the model checkpoint callback explicitly, which again relies on `callback.on_validation_end`. Similarly, we should move this to the `on_train_end` hook, and the trigger should be independent of `every_n_val_epochs`.
Pitch
For 1., inside `run_training_epoch`:
At the end of the training epoch, instead of
```python
if should_train_only:
    self.check_checkpoint_callback(True)
    self.check_early_stopping_callback(True)
```
do:
```python
if should_train_only:
    self.trainer.call_hook("on_train_epoch_without_validation_end")
```
Note: why do we need a new hook instead of reusing the existing `on_train_epoch_end`? Because `EarlyStopping` and `ModelCheckpoint` rely on `log_train_epoch_end_metrics` and `update_learning_rates`, which have not yet run when `on_train_epoch_end` fires. Since the new hook is called at the end of a training epoch when no validation is enabled, we propose to name it `on_train_epoch_without_validation_end`.
Implement `on_train_epoch_without_validation_end` for the `ModelCheckpoint` and `EarlyStopping` callbacks.
Trigger change:
- Since this hook fires only when validation is not enabled, we checkpoint at the end of every training epoch instead of relying on `every_n_val_epochs` (see the sketch after this list).
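For illustration, a minimal sketch of a callback implementing the proposed hook. The hook name `on_train_epoch_without_validation_end`, the trainer-side dispatch, and the class/parameter names are assumptions of this proposal and do not exist in Lightning today; `trainer.save_checkpoint` and `Callback` are existing API.
```python
from pytorch_lightning.callbacks import Callback


class EveryTrainEpochCheckpoint(Callback):
    """Sketch only: assumes the trainer dispatches the proposed hook,
    e.g. via self.trainer.call_hook("on_train_epoch_without_validation_end")."""

    def __init__(self, dirpath: str):
        self.dirpath = dirpath

    def on_train_epoch_without_validation_end(self, trainer, pl_module):
        # This hook fires only when validation is disabled, so there is no
        # need to gate on every_n_val_epochs: save after every training epoch.
        trainer.save_checkpoint(f"{self.dirpath}/epoch={trainer.current_epoch}.ckpt")
```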
For 2., inside `on_train_end`:
At the end of the whole training run, remove
```python
# trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
# when a checkpoint was saved at the last step
self.trainer.global_step -= 1
self.check_checkpoint_callback(should_update=True, is_last=True)
self.trainer.global_step += 1
```
and instead rely on the hook:
```python
# hook
self.trainer.call_hook("on_train_end")
```
Implement `on_train_end` for `ModelCheckpoint`.
Trigger change:
- Instead of relying on `every_n_val_epochs`, we provide an option `trigger_on_train_end` to determine whether to checkpoint. By default, it is off.
Note: why default to off? If it were on, there would be a risk of a checkpointing error in the following scenario:
- The checkpoint's monitor value is set to some validation metric, such as `val_loss` or `val_auc`. If training fails in the middle and no validation has run before the failure (for example, the user configured validation to run only after a large number of batches), the callback would complain that there is no valid monitor key, and the model would not be checkpointed.
When the option is turned on, we could skip validating the monitor key at the end of training. This way, when the monitor key is missing, we cannot save a top-k checkpoint, but we can still save the last one (if `save_last` is turned on). A sketch of this behaviour follows.
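Below is a minimal standalone sketch of the proposed `trigger_on_train_end` behaviour, written as its own callback rather than as a change to `ModelCheckpoint`. The option name, class name, and file naming are assumptions of this proposal; `trainer.save_checkpoint` and `trainer.callback_metrics` are existing API.
```python
import os
from typing import Optional

from pytorch_lightning.callbacks import Callback


class SaveOnTrainEnd(Callback):
    """Sketch only: mirrors the proposed trigger_on_train_end option."""

    def __init__(self, dirpath: str, monitor: Optional[str] = None,
                 save_last: bool = True, trigger_on_train_end: bool = False):
        self.dirpath = dirpath
        self.monitor = monitor
        self.save_last = save_last
        self.trigger_on_train_end = trigger_on_train_end  # proposed default: off

    def on_train_end(self, trainer, pl_module):
        if not self.trigger_on_train_end:
            return
        if self.monitor is not None and self.monitor not in trainer.callback_metrics:
            # Monitor key missing, e.g. validation never ran before a failure:
            # we cannot rank a top-k checkpoint, but we can still save "last".
            if self.save_last:
                trainer.save_checkpoint(os.path.join(self.dirpath, "last.ckpt"))
            return
        trainer.save_checkpoint(os.path.join(self.dirpath, "train_end.ckpt"))
```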
Alternatives
For model checkpointing, instead of checkpointing at the end of every training epoch, we could introduce something like `every_n_train_epochs` and check the trigger condition based on it, as sketched below.
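A sketch of this alternative, building on the `EveryTrainEpochCheckpoint` sketch above; `every_n_train_epochs` is a hypothetical parameter and does not exist in Lightning today.
```python
class EveryNTrainEpochsCheckpoint(EveryTrainEpochCheckpoint):
    """Sketch only: adds a hypothetical every_n_train_epochs gate to the
    earlier EveryTrainEpochCheckpoint sketch."""

    def __init__(self, dirpath: str, every_n_train_epochs: int = 1):
        super().__init__(dirpath)
        self.every_n_train_epochs = every_n_train_epochs

    def on_train_epoch_without_validation_end(self, trainer, pl_module):
        # Only checkpoint every n-th training epoch, not every epoch.
        if (trainer.current_epoch + 1) % self.every_n_train_epochs != 0:
            return
        trainer.save_checkpoint(f"{self.dirpath}/epoch={trainer.current_epoch}.ckpt")
```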