🚀 Feature
callbacks.ModelCheckpoint implements on_save_checkpoint and on_load_checkpoint to save and load its state to/from a checkpoint. Save/load all internal members so the complete state can be recovered from a checkpoint.
Motivation
I use a ModelCheckpoint callback to save my top-k (k=3) models during training. Sometimes I resume training from a checkpoint and want to train for further epochs. But the loaded callback has an empty best_k_models list, so this list is built up again from scratch. The older models are no longer tracked, and I end up with three new models in the same folder without knowing which ones are actually the top 3.
A full state save/restore for ModelCheckpoint would resolve this.
Pitch
Extend the ModelCheckpoint methods to export and load all internal state. I'm not sure whether something is still missing, but here is an attempt:
```python
from typing import Any, Dict

def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
    # Export the callback's internal bookkeeping into the checkpoint.
    return {
        "best_model_score": self.best_model_score,
        "best_model_path": self.best_model_path,
        "best_k_models": self.best_k_models,
    }

def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
    # Restore the bookkeeping so top-k tracking continues after resuming.
    self.best_model_score = checkpointed_state["best_model_score"]
    self.best_model_path = checkpointed_state["best_model_path"]
    self.best_k_models = checkpointed_state["best_k_models"]
```
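For context, a minimal usage sketch of how this would play out when resuming training. The directory and checkpoint paths are placeholders, and it assumes a Lightning version where ModelCheckpoint takes dirpath/monitor/save_top_k and Trainer accepts resume_from_checkpoint:

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the best 3 checkpoints by validation loss (paths are placeholders).
checkpoint_callback = ModelCheckpoint(dirpath="checkpoints/", monitor="val_loss", save_top_k=3)

# With the proposed hooks, resuming would also restore best_k_models,
# so the previously saved top-3 models stay tracked instead of being forgotten.
trainer = pl.Trainer(
    callbacks=[checkpoint_callback],
    resume_from_checkpoint="checkpoints/last.ckpt",
)
trainer.fit(model)  # `model` is the user's LightningModule
```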
Alternatives
Subclass ModelCheckpoint to implement this myself (rough sketch below), but I think that would complicate things.
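For completeness, this workaround would just override the same two hooks from the pitch in a subclass; the class name here is chosen only for illustration:

```python
from typing import Any, Dict
from pytorch_lightning.callbacks import ModelCheckpoint

class StatefulModelCheckpoint(ModelCheckpoint):
    # Persists and restores the top-k bookkeeping, mirroring the pitch above.
    def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
        return {
            "best_model_score": self.best_model_score,
            "best_model_path": self.best_model_path,
            "best_k_models": self.best_k_models,
        }

    def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
        self.best_model_score = checkpointed_state["best_model_score"]
        self.best_model_path = checkpointed_state["best_model_path"]
        self.best_k_models = checkpointed_state["best_k_models"]
```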
Additional context
-/-