
Add callback method for on_save_checkpoint #2401

@jeremyjordan


🚀 Feature

We should allow Callback objects to optionally persist state in checkpoints so that it can be reloaded when training resumes.

Motivation

We already save the state of the early stopping and model checkpoint callbacks manually. This refactor would eliminate callback-specific code in the Trainer and extend the same capability to user-written callbacks.

Pitch

This callback method would simply return a state_dict for the Trainer to store. The one thing I am unsure about is how we should reinitialize the state for other callbacks. If we can expect the exact same callbacks to be passed to the Trainer, it should be trivial. Alternatively, we could expect that only a single instance of each callback class is passed in (e.g. callbacks=[CustomerLogger(), EarlyStopping(), ModelCheckpoint()] and not callbacks=[CustomerLogger(params_a), CustomerLogger(params_b), EarlyStopping(), ModelCheckpoint()]) and just keep a mapping from callback class to state dict. However, if the user passes multiple callback instances of the same class, I'm not sure how we would want to handle that.
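To make the pitch concrete, here is a minimal sketch of the class-to-state mapping, assuming one instance per callback class. It is not the Trainer's actual code: the Callback base class, the toy EarlyStopping, the on_load_checkpoint hook, and the two helper functions are all hypothetical names used for illustration only.

```python
from typing import Any, Dict, List


class Callback:
    """Hypothetical base class; both hooks are optional no-ops by default."""

    def on_save_checkpoint(self) -> Dict[str, Any]:
        return {}

    def on_load_checkpoint(self, state: Dict[str, Any]) -> None:
        pass


class EarlyStopping(Callback):
    """Toy stand-in for illustration, not the library's EarlyStopping."""

    def __init__(self) -> None:
        self.wait = 0
        self.best = float("inf")

    def on_save_checkpoint(self) -> Dict[str, Any]:
        return {"wait": self.wait, "best": self.best}

    def on_load_checkpoint(self, state: Dict[str, Any]) -> None:
        self.wait = state["wait"]
        self.best = state["best"]


def collect_callback_states(callbacks: List[Callback]) -> Dict[type, Dict[str, Any]]:
    # One entry per callback class; a second instance of the same class
    # would silently overwrite the first, hence the single-instance rule.
    return {type(cb): cb.on_save_checkpoint() for cb in callbacks}


def restore_callback_states(
    callbacks: List[Callback], states: Dict[type, Dict[str, Any]]
) -> None:
    # Callbacks without a saved entry are left untouched.
    for cb in callbacks:
        state = states.get(type(cb))
        if state is not None:
            cb.on_load_checkpoint(state)


# Save from one run, restore into a fresh instance in the next run.
old = EarlyStopping()
old.wait, old.best = 3, 0.25
checkpoint = {"callbacks": collect_callback_states([old])}

new = EarlyStopping()
restore_callback_states([new], checkpoint["callbacks"])
```

Keying on the class keeps the Trainer free of callback-specific code, but it is exactly what makes multiple instances of one class ambiguous.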

I would recommend that we document the following constraints:

  • All objects in the dictionary must be pickle-able.
  • You cannot persist multiple instances of the same callback class.
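
If we wanted to enforce these two constraints rather than only document them, a check along the following lines could run before the checkpoint is written. This is only a sketch: find_duplicate_classes, is_picklable, and the placeholder classes are hypothetical and not part of any existing API.

```python
import pickle
import threading
from collections import Counter
from typing import Any, List


def find_duplicate_classes(callbacks: List[Any]) -> List[str]:
    """Return the names of callback classes that appear more than once."""
    counts = Counter(type(cb) for cb in callbacks)
    return [cls.__name__ for cls, n in counts.items() if n > 1]


def is_picklable(state: Any) -> bool:
    """Return True if the state can be serialized with pickle."""
    try:
        pickle.dumps(state)
    except (pickle.PicklingError, TypeError, AttributeError):
        return False
    return True


class LoggerCallback:  # hypothetical placeholder callback
    pass


class StopperCallback:  # hypothetical placeholder callback
    pass


duplicates = find_duplicate_classes(
    [LoggerCallback(), LoggerCallback(), StopperCallback()]
)
plain_state_ok = is_picklable({"wait": 3, "best": 0.25})
lock_state_ok = is_picklable({"lock": threading.Lock()})  # locks cannot be pickled
```

Failing early with a clear error would likely be friendlier than a pickling traceback at checkpoint time.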


Labels: feature (Is an improvement or enhancement), help wanted (Open to be worked on)
