[RFC] Training Loop Checkpoint Consolidation #6672

@shuyingsunshine21

Description

🚀 Feature

Currently, at the end of each training epoch and at the end of training, the training loop triggers checkpointing through the on_validation_end hook, which is confusing. We would like to move this logic to dedicated hooks.

Motivation

  1. Currently, at the end of each training epoch (run_training_epoch in training_loop.py), we explicitly call the model checkpoint callback and the early stopping callback when there is no validation stage:

  if should_train_only:
      self.check_checkpoint_callback(True)
      self.check_early_stopping_callback(True)

However, under the hood these calls invoke the on_validation_end hook:

  for cb in callbacks:
      cb.on_validation_end(self.trainer, model)

This is confusing. For example, for the ModelCheckpoint callback, on_validation_end depends on the parameter every_n_val_epochs. When there is no validation stage, a trigger that relies on every_n_val_epochs is not ideal.

Instead, we could implement a dedicated hook to control this, one that is independent of every_n_val_epochs.
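To make the coupling concrete, here is a minimal sketch of the problem. ToyCheckpoint and run_train_only are hypothetical stand-ins written for this illustration, not the real ModelCheckpoint or training loop, and the exact form of the every_n_val_epochs check is an assumption.

```python
# Toy stand-in for illustration only; not the real ModelCheckpoint.
class ToyCheckpoint:
    def __init__(self, every_n_val_epochs=1):
        self.every_n_val_epochs = every_n_val_epochs
        self.saved_epochs = []

    def on_validation_end(self, current_epoch):
        # Assumed gate: skip unless the epoch matches the validation interval.
        n = self.every_n_val_epochs
        if n < 1 or (current_epoch + 1) % n != 0:
            return
        self.saved_epochs.append(current_epoch)


def run_train_only(callback, max_epochs):
    # No validation stage runs here, yet the validation hook is what fires.
    for epoch in range(max_epochs):
        callback.on_validation_end(epoch)
    return callback.saved_epochs


saved = run_train_only(ToyCheckpoint(every_n_val_epochs=3), max_epochs=6)
print(saved)  # [2, 5]
```

In this sketch a train-only run checkpoints on only two of its six epochs, purely because of a parameter that is named after, and documented in terms of, validation.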

  2. In addition, at the end of training, we also call the model checkpoint callback explicitly, which again relies on callback.on_validation_end.
    Similarly, we should move this to the on_train_end hook, and the trigger should be independent of every_n_val_epochs.

Pitch

For 1. inside run_training_epoch

at the end of the training epoch, instead of

  if should_train_only:
      self.check_checkpoint_callback(True)
      self.check_early_stopping_callback(True)

do

  if should_train_only:
      self.trainer.call_hook("on_train_epoch_without_validation_end")

Note: we need a new hook rather than reusing the existing on_train_epoch_end because EarlyStopping and ModelCheckpoint rely on log_train_epoch_end_metrics and update_learning_rates. Since this hook is called at the end of a training epoch when no validation is enabled, we propose naming it on_train_epoch_without_validation_end.

Implement on_train_epoch_without_validation_end for the ModelCheckpoint and EarlyStopping callbacks.

Trigger Change:

  • As this hook is triggered only when validation is not enabled, we run the check on every training epoch instead of relying on every_n_val_epochs.
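The proposal for item 1 could be sketched roughly as follows. All classes here (Callback, ToyModelCheckpoint, ToyEarlyStopping, ToyTrainer) are simplified, hypothetical stand-ins written for this illustration; only the hook name on_train_epoch_without_validation_end and the call_hook dispatch come from the proposal itself. The train_loss key and the patience logic are assumptions.

```python
import math


class Callback:
    """Minimal base class: every hook is a no-op unless overridden."""

    def on_train_epoch_without_validation_end(self, trainer):
        pass


class ToyModelCheckpoint(Callback):
    def __init__(self, every_n_val_epochs=1):
        self.every_n_val_epochs = every_n_val_epochs  # deliberately unused below
        self.saved_epochs = []

    def on_train_epoch_without_validation_end(self, trainer):
        # Checkpoint on every training epoch, regardless of every_n_val_epochs.
        self.saved_epochs.append(trainer.current_epoch)


class ToyEarlyStopping(Callback):
    def __init__(self, monitor="train_loss", patience=2):
        self.monitor = monitor
        self.patience = patience
        self.best = math.inf
        self.wait = 0

    def on_train_epoch_without_validation_end(self, trainer):
        current = trainer.logged_metrics.get(self.monitor)
        if current is None:
            return  # nothing logged under the monitored key this epoch
        if current < self.best:
            self.best, self.wait = current, 0
        else:
            self.wait += 1
            trainer.should_stop = self.wait >= self.patience


class ToyTrainer:
    def __init__(self, callbacks):
        self.callbacks = callbacks
        self.current_epoch = 0
        self.logged_metrics = {}
        self.should_stop = False

    def call_hook(self, name):
        for cb in self.callbacks:
            getattr(cb, name)(self)

    def fit(self, train_losses):
        for epoch, loss in enumerate(train_losses):
            self.current_epoch = epoch
            self.logged_metrics = {"train_loss": loss}  # stands in for metric logging
            should_train_only = True  # this toy loop never validates
            if should_train_only:
                self.call_hook("on_train_epoch_without_validation_end")
            if self.should_stop:
                break


checkpoint = ToyModelCheckpoint(every_n_val_epochs=5)
trainer = ToyTrainer([checkpoint, ToyEarlyStopping(patience=2)])
trainer.fit([1.0, 0.8, 0.9, 0.95, 0.7])
print(checkpoint.saved_epochs)  # [0, 1, 2, 3]
```

Even with every_n_val_epochs=5, the toy checkpoint callback fires on every training epoch until early stopping ends the run, which is the trigger change described above.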

For 2. inside on_train_end

at the end of the whole training run, remove

  # trigger checkpoint check. need to temporarily decrease the global step to avoid saving duplicates
  # when a checkpoint was saved at the last step
  self.trainer.global_step -= 1
  self.check_checkpoint_callback(should_update=True, is_last=True)
  self.trainer.global_step += 1

and instead rely on the hook

  # hook
  self.trainer.call_hook("on_train_end")

Implement on_train_end for ModelCheckpoint

Trigger Change:

  • Instead of relying on every_n_val_epochs, we provide an option, trigger_on_train_end, that determines whether to checkpoint at the end of training.
    It is off by default.

Note: we default to off because turning it on risks a checkpointing error in the following scenario:

  • The checkpoint callback monitors a validation metric, such as val_loss or val_auc. If training fails midway and no validation has run before the failure (for example, because the user configured validation to run only after a large number of batches), the callback complains that there is no valid monitor key, and the model is not checkpointed.
    When the option is turned on, we could skip validating the monitor key at the end of training. That way, when the monitor key is missing, we cannot save the top k checkpoints, but we can still save the last one (if save_last is turned on).
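One possible shape for item 2 is sketched below. ToyModelCheckpoint and its _save helper are hypothetical and written only for this illustration; trigger_on_train_end, save_last, and the monitor key come from the text above, while the require_monitor flag and the way saves are recorded are assumptions.

```python
class ToyModelCheckpoint:
    """Toy stand-in for illustration only; not the real ModelCheckpoint."""

    def __init__(self, monitor=None, save_last=False, trigger_on_train_end=False):
        self.monitor = monitor
        self.save_last = save_last
        self.trigger_on_train_end = trigger_on_train_end
        self.saved = []  # (kind, global_step) records instead of real files

    def _save(self, metrics, global_step, require_monitor):
        monitor_missing = self.monitor is not None and self.monitor not in metrics
        if monitor_missing and require_monitor:
            raise KeyError(f"monitor key {self.monitor!r} not found in metrics")
        if not monitor_missing:
            self.saved.append(("top_k", global_step))
        # The last checkpoint does not depend on the monitored metric.
        if self.save_last:
            self.saved.append(("last", global_step))

    def on_validation_end(self, metrics, global_step):
        # Regular path: the monitor key must be present.
        self._save(metrics, global_step, require_monitor=True)

    def on_train_end(self, metrics, global_step):
        if not self.trigger_on_train_end:
            return  # off by default, as proposed
        # End of training: tolerate a missing monitor key.
        self._save(metrics, global_step, require_monitor=False)


# Training ended before any validation ran, so val_loss was never logged.
ckpt = ToyModelCheckpoint(monitor="val_loss", save_last=True, trigger_on_train_end=True)
ckpt.on_train_end(metrics={}, global_step=100)
print(ckpt.saved)  # [('last', 100)]
```

With the option left at its default, on_train_end in this sketch does nothing, so existing behavior would only change for users who opt in.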

Alternatives

For model checkpointing, instead of checkpointing at the end of every training epoch, we could introduce a parameter such as every_n_train_epochs and check the trigger condition against it.
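A rough sketch of this alternative, under the assumption that every_n_train_epochs would gate the train-only hook the same way an epoch interval normally does. ToyModelCheckpoint is a hypothetical stand-in, and treating values below 1 as "disabled" is an assumption made for this illustration.

```python
class ToyModelCheckpoint:
    """Toy stand-in for illustration only; not the real ModelCheckpoint."""

    def __init__(self, every_n_train_epochs=1):
        self.every_n_train_epochs = every_n_train_epochs
        self.saved_epochs = []

    def on_train_epoch_without_validation_end(self, current_epoch):
        n = self.every_n_train_epochs
        # Assumed convention: n < 1 disables epoch-based checkpointing.
        if n < 1 or (current_epoch + 1) % n != 0:
            return
        self.saved_epochs.append(current_epoch)


every_other = ToyModelCheckpoint(every_n_train_epochs=2)
for epoch in range(5):
    every_other.on_train_epoch_without_validation_end(epoch)
print(every_other.saved_epochs)  # [1, 3]
```

The trade-off is an extra parameter in exchange for letting train-only users control checkpoint frequency without touching every_n_val_epochs.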

Additional context
