ModelCheckpoint Callback save and restore extension #4911

@daandres

🚀 Feature

callbacks.ModelCheckpoint implements on_save_checkpoint and on_load_checkpoint to save and load its state to/from a checkpoint. These hooks should save and load all internal members so that the complete state can be recovered from a checkpoint.

Motivation

I use a ModelCheckpoint callback to save my Top-K (K = 3) models during training. Sometimes I resume training from a checkpoint and want to train for further epochs. But the loaded callback has an empty best_k_models list, so this list is built up again from scratch. The older models are no longer tracked, so I end up with three new models in the same folder and cannot tell which ones are the actual Top-3.
Saving and restoring the full state of ModelCheckpoint would resolve this.

Pitch

Extend the ModelCheckpoint methods to export and load all internal state. I'm not sure whether something is missing, but here is a first attempt:

from typing import Any, Dict

# Proposed methods for callbacks.ModelCheckpoint:

def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
    # Export the Top-K bookkeeping so it is stored with the checkpoint.
    return {
        "best_model_score": self.best_model_score,
        "best_model_path": self.best_model_path,
        "best_k_models": self.best_k_models,
    }

def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None:
    # Restore the bookkeeping that on_save_checkpoint exported.
    self.best_model_score = checkpointed_state["best_model_score"]
    self.best_model_path = checkpointed_state["best_model_path"]
    self.best_k_models = checkpointed_state["best_k_models"]

Alternatives

Subclass ModelCheckpoint to implement this, but I think that would complicate things for users.
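For illustration, the subclass route might look roughly like the sketch below. To keep it runnable without the library, it subclasses a hypothetical stand-in (`StandInCheckpoint`) rather than the real callbacks.ModelCheckpoint, and it assumes best_k_models is a mapping from checkpoint path to monitored score (if it is really a list, `list(...)` would replace `dict(...)`). The class names, file names and scores are made up, and the hook signatures are simply the ones from the pitch above.

```python
from typing import Any, Dict, Optional


class StandInCheckpoint:
    """Hypothetical stand-in for callbacks.ModelCheckpoint (not the real class).

    It only carries the three attributes named in the pitch.
    """

    def __init__(self) -> None:
        self.best_model_score: Optional[float] = None
        self.best_model_path: str = ""
        self.best_k_models: Dict[str, float] = {}


class StatefulCheckpoint(StandInCheckpoint):
    """Subclass adding the save/load hooks proposed in the pitch."""

    def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
        return {
            "best_model_score": self.best_model_score,
            "best_model_path": self.best_model_path,
            # Copy so later updates to the live dict do not alter the saved state.
            "best_k_models": dict(self.best_k_models),
        }

    def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None:
        self.best_model_score = checkpointed_state["best_model_score"]
        self.best_model_path = checkpointed_state["best_model_path"]
        self.best_k_models = dict(checkpointed_state["best_k_models"])


# First run: pretend three checkpoints were tracked (made-up paths and scores).
first_run = StatefulCheckpoint()
first_run.best_k_models = {
    "epoch=1.ckpt": 0.42,
    "epoch=2.ckpt": 0.37,
    "epoch=3.ckpt": 0.35,
}
first_run.best_model_path = "epoch=3.ckpt"
first_run.best_model_score = 0.35
saved_state = first_run.on_save_checkpoint(trainer=None, pl_module=None)

# Resumed run: a fresh callback starts empty, then recovers the Top-K bookkeeping.
resumed = StatefulCheckpoint()
resumed.on_load_checkpoint(saved_state)
```

With the real callback, the same two methods would sit on a subclass of ModelCheckpoint. Every user who resumes training would then have to remember to use that subclass, which is the complication mentioned above.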

Additional context

-/-

Labels: feature (is an improvement or enhancement), help wanted (open to be worked on)