Add required states for resumed ModelCheckpoint GC #10995
Changes from all commits
```diff
@@ -344,13 +344,21 @@ def on_save_checkpoint(
             "best_model_path": self.best_model_path,
             "current_score": self.current_score,
             "dirpath": self.dirpath,
+            "best_k_models": self.best_k_models,
+            "kth_best_model_path": self.kth_best_model_path,
+            "kth_value": self.kth_value,
+            "last_model_path": self.last_model_path,
```
Comment on lines +347 to +350

Contributor:

I just noticed an issue with doing this. Since we save each `ModelCheckpoint` mode sequentially, these attributes will not be correct depending on the order if more than one mode triggers a save for the same global step. Currently, a "top-k" checkpoint will not include the updated `last_model_path`.

I'm not sure what would be the best solution here. I think we should start recommending multiple `ModelCheckpoint` instances instead.

cc @awaelchli @ananthsub @jjenniferdai
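If that recommendation means one `ModelCheckpoint` per saving mode, the setup would look something like the following (a sketch; the `monitor`, `dirpath`, and `save_top_k` values are illustrative). Each callback then saves and restores its own consistent state, so no mode overwrites another's attributes within the same global step:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# One callback per saving mode, so each checkpoint file carries state that is
# consistent for that mode, regardless of the order saves happen in a step.
top_k_checkpoint = ModelCheckpoint(
    monitor="val_loss",  # illustrative metric name
    save_top_k=3,
    dirpath="checkpoints/top_k",
)
last_checkpoint = ModelCheckpoint(
    save_last=True,  # only tracks the "last" checkpoint
    dirpath="checkpoints/last",
)

trainer = Trainer(callbacks=[top_k_checkpoint, last_checkpoint])
```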
```diff
         }
 
     def on_load_checkpoint(
         self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
     ) -> None:
         self.best_model_score = callback_state["best_model_score"]
         self.best_model_path = callback_state["best_model_path"]
+        self.best_k_models = callback_state.get("best_k_models", self.best_k_models)
+        self.kth_best_model_path = callback_state.get("kth_best_model_path", self.kth_best_model_path)
+        self.kth_value = callback_state.get("kth_value", self.kth_value)
+        self.last_model_path = callback_state.get("last_model_path", self.last_model_path)
 
     def save_checkpoint(self, trainer: "pl.Trainer") -> None:
         """Performs the main logic around saving a checkpoint.
```
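For context on why these states matter: `best_k_models`, `kth_best_model_path`, and `kth_value` are the bookkeeping that drives deletion of superseded top-k files, so a resumed run needs them to keep garbage-collecting correctly rather than tracking from scratch. Here is a stand-alone toy sketch of that bookkeeping (an illustration only, not Lightning's actual implementation; `TopKTracker` and its methods are invented for this example):

```python
import os
from typing import Dict

class TopKTracker:
    """Toy tracker mirroring the persisted fields: best_k_models, kth_best_model_path, kth_value."""

    def __init__(self, k: int = 2, mode: str = "min") -> None:
        self.k = k
        self.mode = mode
        self.best_k_models: Dict[str, float] = {}  # checkpoint path -> monitored score
        self.kth_best_model_path: str = ""
        self.kth_value: float = float("inf") if mode == "min" else float("-inf")

    def update(self, path: str, score: float) -> None:
        if len(self.best_k_models) >= self.k:
            better = score < self.kth_value if self.mode == "min" else score > self.kth_value
            if not better:
                return  # not in the top k: nothing to write, nothing to delete
            # Garbage-collect the current worst of the kept checkpoints.
            self.best_k_models.pop(self.kth_best_model_path, None)
            if os.path.exists(self.kth_best_model_path):
                os.remove(self.kth_best_model_path)
        self.best_k_models[path] = score
        # Recompute which kept checkpoint is now the "kth best" (i.e. the worst one kept).
        worst = max if self.mode == "min" else min
        self.kth_best_model_path = worst(self.best_k_models, key=self.best_k_models.get)
        self.kth_value = self.best_k_models[self.kth_best_model_path]
```

Restoring these three fields on resume is what lets `update` keep deleting the right files; dropping any of them would leave stale checkpoints orphaned on disk.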
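Because `on_load_checkpoint` uses `.get(...)` with the current value as the default, checkpoints written before this change still load; the new keys are simply absent and the callback keeps its initial state. A minimal resume sketch under that assumption (`MyModel` is a placeholder `LightningModule` subclass and the checkpoint path is made up; `ckpt_path` is the standard `Trainer.fit` argument):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=2)
trainer = Trainer(callbacks=[checkpoint_callback], max_epochs=10)

# MyModel stands in for your LightningModule subclass.
# Resuming restores best_k_models / kth_* / last_model_path from the checkpoint,
# so top-k garbage collection continues where the previous run stopped.
trainer.fit(MyModel(), ckpt_path="checkpoints/epoch=4-step=500.ckpt")
```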