Merged
8 changes: 8 additions & 0 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -344,13 +344,21 @@ def on_save_checkpoint(
"best_model_path": self.best_model_path,
"current_score": self.current_score,
"dirpath": self.dirpath,
"best_k_models": self.best_k_models,
"kth_best_model_path": self.kth_best_model_path,
"kth_value": self.kth_value,
"last_model_path": self.last_model_path,
Comment on lines +347 to +350
Contributor
I just noticed an issue with doing this.

Since we save each "ModelCheckpoint" mode sequentially, these attributes will not be correct (depending on the order) if more than one mode triggers a save for the same global step:

https://github.com/PyTorchLightning/pytorch-lightning/blob/fe940e195dceb18eb9f3bd512cea56ae3405d464/pytorch_lightning/callbacks/model_checkpoint.py#L366-L373

Currently, a "top-k" checkpoint will not include the last_model_path path even if it's saved right after for this global step.

I'm not sure what the best solution here would be. I think we should start recommending multiple ModelCheckpoint instances as a best practice, because these interactions between flags can be unintuitive (see the sketch below).

cc @awaelchli @ananthsub @jjenniferdai
Related to #4335 and #11805 (comment)
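
For illustration, roughly what I mean, a minimal sketch; the dirpath/monitor values are made up for this example, not taken from the PR:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# One instance per concern, so each mode keeps its own independent state:
# a "top-k" instance that only tracks a monitored metric...
top_k_ckpt = ModelCheckpoint(dirpath="checkpoints/top_k", monitor="val_loss", save_top_k=3)
# ...and a separate instance that only keeps the most recent checkpoint.
last_ckpt = ModelCheckpoint(dirpath="checkpoints/last", save_top_k=0, save_last=True)

trainer = Trainer(callbacks=[top_k_ckpt, last_ckpt])

With separate instances, last_model_path and the top-k attributes never share one instance's state, so the save order for a given global step no longer matters.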

}

def on_load_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
) -> None:
self.best_model_score = callback_state["best_model_score"]
self.best_model_path = callback_state["best_model_path"]
self.best_k_models = callback_state.get("best_k_models", self.best_k_models)
self.kth_best_model_path = callback_state.get("kth_best_model_path", self.kth_best_model_path)
self.kth_value = callback_state.get("kth_value", self.kth_value)
self.last_model_path = callback_state.get("last_model_path", self.last_model_path)
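
The .get(key, default) fallbacks above are what keep resuming from checkpoints written before this PR working. A minimal standalone sketch of that behavior; the old_state dict is a made-up stand-in for an old callback_state, and the defaults are assumed to match a fresh ModelCheckpoint:

# Hypothetical callback_state written before this PR introduced the new keys.
old_state = {"best_model_score": 0.618, "best_model_path": "epoch=3.ckpt"}

# Mirrors the restore logic above: missing keys fall back to the attribute's
# current value, i.e. its __init__ default on a freshly constructed callback.
best_k_models = old_state.get("best_k_models", {})       # assumed default: {}
last_model_path = old_state.get("last_model_path", "")   # assumed default: ""
assert best_k_models == {} and last_model_path == ""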

def save_checkpoint(self, trainer: "pl.Trainer") -> None:
"""Performs the main logic around saving a checkpoint.
34 changes: 34 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
@@ -1202,3 +1202,37 @@ def test_check_val_every_n_epochs_top_k_integration(tmpdir):
)
trainer.fit(model)
assert set(os.listdir(tmpdir)) == {"epoch=1.ckpt", "epoch=3.ckpt"}


def test_model_checkpoint_saveload_ckpt(tmpdir):
ckpt = {
"monitor": "random_value",
"best_model_path": "epoch=10-step=1436.ckpt",
"best_model_score": torch.tensor(2.246),
"current_score": torch.tensor(1.5),
"dirpath": tmpdir,
"best_k_models": {"epoch=10-step=1436.ckpt": torch.tensor(2.246)},
"kth_best_model_path": "epoch=10-step=1436.ckpt",
"kth_value": torch.tensor(2.246),
"last_model_path": "last2245.ckpt",
}

# test on_save_checkpoint
cb_write = ModelCheckpoint(dirpath=tmpdir, monitor="random_value", save_top_k=-1, save_last=True)
for key, val in ckpt.items():
setattr(cb_write, key, val)
written_ckpt = cb_write.on_save_checkpoint("", "", "")
for state in ckpt:
assert ckpt[state] == written_ckpt[state]

# test on_load_checkpoint
# Note: "current_score", "dirpath" and "monitor" are currently not restored by on_load_checkpoint.
# We therefore set "dirpath" and "monitor" to something different than for ckpt/cb_write so we can assert them.
# "current_score" is left as initialized, i.e. None, and can therefore also be asserted
cb_restore = ModelCheckpoint(dirpath=tmpdir + "restore", monitor=None, save_top_k=-1, save_last=True)
cb_restore.on_load_checkpoint("", "", written_ckpt)
for key, val in written_ckpt.items():
if key not in ("current_score", "dirpath", "monitor"):
assert getattr(cb_restore, key) == val
else:
assert getattr(cb_restore, key) != val
11 changes: 9 additions & 2 deletions tests/models/test_restore.py
@@ -269,8 +269,15 @@ def get_trainer_args():

for before, after in zip(callbacks_before_resume, callback_capture.callbacks):
if isinstance(before, ModelCheckpoint):
assert before.best_model_path == after.best_model_path
assert before.best_model_score == after.best_model_score
for attribute in (
"best_model_path",
"best_model_score",
"best_k_models",
"kth_best_model_path",
"kth_value",
"last_model_path",
):
assert getattr(before, attribute) == getattr(after, attribute)


def test_callbacks_references_fit_ckpt_path(tmpdir):