Description
🚀 Feature
We should allow Callback objects to optionally persist state that can be reloaded from checkpoints.
Motivation
We already save the state of the early stopping and model checkpoint callbacks manually. This refactor would remove that callback-specific code from the Trainer and extend the same ability to user-written callbacks.
Pitch
Each callback would simply return a state_dict that the Trainer could store in the checkpoint. The part I am unclear on is how we should reinitialize the state of other callbacks. If we can expect the exact same callbacks to be passed to the Trainer, restoring them is trivial. Alternatively, we could expect users to pass only a single instance of each callback class (e.g., callbacks=[CustomerLogger(), EarlyStopping(), ModelCheckpoint()], not callbacks=[CustomerLogger(params_a), CustomerLogger(params_b), EarlyStopping(), ModelCheckpoint()]) and keep a mapping from callback class to state dict. However, if the user passes multiple instances of the same callback class, I'm not sure how we would want to handle that.
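A minimal sketch of the class-keyed mapping option, assuming a hypothetical `Callback` base class with opt-in `state_dict()`/`load_state_dict()` hooks. `EarlyStopping` here is a toy stand-in, and `dump_callback_states`/`restore_callback_states` are hypothetical Trainer-side helpers, not existing Lightning API.

```python
from typing import Any, Dict, List


class Callback:
    """Hypothetical base class: a callback opts in to persistence by overriding these hooks."""

    def state_dict(self) -> Dict[str, Any]:
        return {}  # stateless callbacks contribute an empty dict

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        pass  # nothing to restore by default


class EarlyStopping(Callback):
    """Hypothetical stand-in for an early-stopping callback with resumable state."""

    def __init__(self, patience: int = 3) -> None:
        self.patience = patience
        self.wait_count = 0
        self.best_score = float("inf")

    def state_dict(self) -> Dict[str, Any]:
        return {"wait_count": self.wait_count, "best_score": self.best_score}

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.wait_count = state["wait_count"]
        self.best_score = state["best_score"]


def _state_key(callback: Callback) -> str:
    # Fully qualified class name, so same-named classes from different modules don't collide.
    cls = type(callback)
    return f"{cls.__module__}.{cls.__qualname__}"


def dump_callback_states(callbacks: List[Callback]) -> Dict[str, Dict[str, Any]]:
    """Collect one state dict per callback class for the Trainer to store in the checkpoint."""
    return {_state_key(cb): cb.state_dict() for cb in callbacks}


def restore_callback_states(callbacks: List[Callback], states: Dict[str, Dict[str, Any]]) -> None:
    """Feed saved state back into freshly constructed callbacks, matched by class."""
    for cb in callbacks:
        state = states.get(_state_key(cb))
        if state is not None:  # callbacks with no saved entry keep their defaults
            cb.load_state_dict(state)
```

Because the mapping is keyed by class, passing two instances of the same class makes the later one silently overwrite the earlier one in `dump_callback_states`, which is exactly the ambiguity described in the Pitch.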
I would recommend that we document the following constraints:
- All objects in the state dictionary must be picklable.
- You cannot persist multiple instances of the same callback class.
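Both constraints could be checked up front, so a bad callback fails with a clear message instead of partway through saving a checkpoint. A minimal sketch, assuming the Trainer calls the hypothetical helpers `check_unique_callback_classes` and `check_picklable` before writing the checkpoint; neither is an existing Lightning function.

```python
import pickle
from collections import Counter
from typing import Any, Dict, List


def check_unique_callback_classes(callbacks: List[Any]) -> None:
    """Reject callback lists that contain more than one instance of a class."""
    counts = Counter(type(cb) for cb in callbacks)
    duplicates = sorted(cls.__name__ for cls, n in counts.items() if n > 1)
    if duplicates:
        raise ValueError(
            "Cannot persist multiple instances of the same callback class: "
            + ", ".join(duplicates)
        )


def check_picklable(name: str, state: Dict[str, Any]) -> None:
    """Fail early, naming the callback, if its state dict cannot be pickled."""
    try:
        pickle.dumps(state)  # serialize once just to validate; the result is discarded
    except (pickle.PicklingError, TypeError, AttributeError) as err:
        raise TypeError(f"State of callback {name!r} is not picklable: {err}") from err
```

Validating with `pickle.dumps` serializes each state dict an extra time; for small callback states that cost should be negligible compared to saving the model weights.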