From 2014dfbe4a9c5859137552e3bea087dc3410e43d Mon Sep 17 00:00:00 2001 From: ORippler Date: Wed, 8 Dec 2021 14:37:28 +0100 Subject: [PATCH 01/14] Add required states for resumed ModelCheckpoint GC --- pytorch_lightning/callbacks/model_checkpoint.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index cd307d18bc03a..36c00cfda60ff 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -342,6 +342,10 @@ def on_save_checkpoint( "best_model_path": self.best_model_path, "current_score": self.current_score, "dirpath": self.dirpath, + "best_k_models": self.best_k_models, + "kth_best_model_path": self.kth_best_model_path, + "kth_value": self.kth_value, + "last_model_path": self.last_model_path, } def on_load_checkpoint( @@ -349,6 +353,10 @@ def on_load_checkpoint( ) -> None: self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] + self.best_k_models = callback_state["best_k_models"] + self.kth_best_model_path = callback_state["kth_best_model_path"] + self.kth_value = callback_state["kth_value"] + self.last_model_path = callback_state["last_model_path"] def save_checkpoint(self, trainer: "pl.Trainer") -> None: """Performs the main logic around saving a checkpoint. 
From 77ee16ddec2cef263bca3e8fdb7de12ceeeb5f7d Mon Sep 17 00:00:00 2001 From: ORippler Date: Wed, 8 Dec 2021 17:23:42 +0100 Subject: [PATCH 02/14] Add backwards compatibility with legacy ckpts Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- pytorch_lightning/callbacks/model_checkpoint.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 36c00cfda60ff..05444fb3d301e 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -353,10 +353,10 @@ def on_load_checkpoint( ) -> None: self.best_model_score = callback_state["best_model_score"] self.best_model_path = callback_state["best_model_path"] - self.best_k_models = callback_state["best_k_models"] - self.kth_best_model_path = callback_state["kth_best_model_path"] - self.kth_value = callback_state["kth_value"] - self.last_model_path = callback_state["last_model_path"] + self.best_k_models = callback_state.get("best_k_models", self.best_k_models) + self.kth_best_model_path = callback_state.get("kth_best_model_path", self.kth_best_model_path) + self.kth_value = callback_state.get("kth_value", self.kth_value) + self.last_model_path = callback_state.get("last_model_path", self.last_model_path) def save_checkpoint(self, trainer: "pl.Trainer") -> None: """Performs the main logic around saving a checkpoint. 
From 8d35f0ad0bbe24a4210e3ef07af87cf70b867298 Mon Sep 17 00:00:00 2001 From: ORippler Date: Thu, 9 Dec 2021 10:50:38 +0100 Subject: [PATCH 03/14] Add test to check if attrs are written to ckpt Note that we do not yet check for proper loading/reinstantiation of ModelCheckpoint based on the ckpt written to disk --- tests/checkpointing/test_model_checkpoint.py | 22 ++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index fa08057733f68..2d29fbf0a9628 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1202,3 +1202,25 @@ def test_check_val_every_n_epochs_top_k_integration(tmpdir): ) trainer.fit(model) assert set(os.listdir(tmpdir)) == {"epoch=1.ckpt", "epoch=3.ckpt"} + +def test_model_checkpoint_attributes(tmpdir): + seed_everything() + model = LogInTwoMethods() + + epochs = 2 + checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor=None, save_top_k=-1, save_last=True) + trainer = Trainer( + default_root_dir=tmpdir, + callbacks=[checkpoint_callback], + limit_train_batches=10, + limit_val_batches=10, + max_epochs=epochs, + logger=False, + ) + + trainer.fit(model) + + checkpoint = torch.load(os.path.join(tmpdir, 'last.ckpt'))['callbacks'][checkpoint_callback.state_key] + + for k in ("best_k_models", "kth_best_model_path", "kth_value", "last_model_path"): + assert checkpoint[k] == getattr(checkpoint_callback, k) \ No newline at end of file From 32d8fd8dbaee79ed7a3b290c7156a9574b8f6729 Mon Sep 17 00:00:00 2001 From: ORippler Date: Thu, 9 Dec 2021 10:54:05 +0100 Subject: [PATCH 04/14] Test if attributes are restored properly from ckpt --- tests/checkpointing/test_model_checkpoint.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 2d29fbf0a9628..e1619880d69f4 100644 --- 
a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1223,4 +1223,10 @@ def test_model_checkpoint_attributes(tmpdir): checkpoint = torch.load(os.path.join(tmpdir, 'last.ckpt'))['callbacks'][checkpoint_callback.state_key] for k in ("best_k_models", "kth_best_model_path", "kth_value", "last_model_path"): - assert checkpoint[k] == getattr(checkpoint_callback, k) \ No newline at end of file + assert checkpoint[k] == getattr(checkpoint_callback, k) + + restored_callback = ModelCheckpoint(dirpath=tmpdir, monitor=None, save_top_k=-1, save_last=True) + restored_callback.on_load_checkpoint('','', checkpoint) + + for k in ("best_k_models", "kth_best_model_path", "kth_value", "last_model_path"): + assert checkpoint[k] == getattr(restored_callback, k) \ No newline at end of file From 50e376c0050119df9df6a213c7d4d979a68733f3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Dec 2021 09:55:28 +0000 Subject: [PATCH 05/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/checkpointing/test_model_checkpoint.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index e1619880d69f4..5624fa61ec009 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1203,6 +1203,7 @@ def test_check_val_every_n_epochs_top_k_integration(tmpdir): trainer.fit(model) assert set(os.listdir(tmpdir)) == {"epoch=1.ckpt", "epoch=3.ckpt"} + def test_model_checkpoint_attributes(tmpdir): seed_everything() model = LogInTwoMethods() @@ -1220,13 +1221,13 @@ def test_model_checkpoint_attributes(tmpdir): trainer.fit(model) - checkpoint = torch.load(os.path.join(tmpdir, 'last.ckpt'))['callbacks'][checkpoint_callback.state_key] + checkpoint = 
torch.load(os.path.join(tmpdir, "last.ckpt"))["callbacks"][checkpoint_callback.state_key] for k in ("best_k_models", "kth_best_model_path", "kth_value", "last_model_path"): assert checkpoint[k] == getattr(checkpoint_callback, k) - + restored_callback = ModelCheckpoint(dirpath=tmpdir, monitor=None, save_top_k=-1, save_last=True) - restored_callback.on_load_checkpoint('','', checkpoint) + restored_callback.on_load_checkpoint("", "", checkpoint) for k in ("best_k_models", "kth_best_model_path", "kth_value", "last_model_path"): - assert checkpoint[k] == getattr(restored_callback, k) \ No newline at end of file + assert checkpoint[k] == getattr(restored_callback, k) From 16861f1c08b2b8e0b120bc6655449eb2b3752426 Mon Sep 17 00:00:00 2001 From: ORippler Date: Wed, 15 Dec 2021 18:46:04 +0100 Subject: [PATCH 06/14] Fix broken `test_callbacks_state_fit_ckpt_path` `ModelCheckpoint` is configured to save after every epoch, but `trainer.fit` is called with `max_steps = 1` Note there may be a better way of doing this, where `ModelCheckpoint` is called after `training_step` --- tests/models/test_restore.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 1139e6fb5e8ad..08888ae101a91 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -246,7 +246,7 @@ def get_trainer_args(): checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True) trainer_args = dict( default_root_dir=tmpdir, - max_steps=1, + max_epochs=1, logger=False, callbacks=[checkpoint, callback_capture], limit_val_batches=2, From 0632f23e864bf77ff489b606b0f346b84cc5da16 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Thu, 16 Dec 2021 13:07:57 +0100 Subject: [PATCH 07/14] Update test_restore.py --- tests/models/test_restore.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/models/test_restore.py 
b/tests/models/test_restore.py index 08888ae101a91..56cbb88903643 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -243,7 +243,10 @@ def test_callbacks_state_fit_ckpt_path(tmpdir): callback_capture = CaptureCallbacksBeforeTraining() def get_trainer_args(): - checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True) + # save_on_train_epoch_end must be set explicitly since otherwise it will be changed internally causing the + # state_key to not match anymore upon loading and thus not loading the callbacks state + # (this is desired behavior, it just conflicts with the place where CaptureCallbacksBeforeTraining intercepts) + checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True, save_on_train_epoch_end=False) trainer_args = dict( default_root_dir=tmpdir, max_epochs=1, From c7c6141cb9e5d1c262918834194d94f1074099d7 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Thu, 16 Dec 2021 13:09:10 +0100 Subject: [PATCH 08/14] Update test_restore.py --- tests/models/test_restore.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 56cbb88903643..566d80d112d2c 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -243,9 +243,10 @@ def test_callbacks_state_fit_ckpt_path(tmpdir): callback_capture = CaptureCallbacksBeforeTraining() def get_trainer_args(): - # save_on_train_epoch_end must be set explicitly since otherwise it will be changed internally causing the - # state_key to not match anymore upon loading and thus not loading the callbacks state - # (this is desired behavior, it just conflicts with the place where CaptureCallbacksBeforeTraining intercepts) + # save_on_train_epoch_end must be set explicitly since otherwise it will be + # changed internally causing the state_key to not match anymore upon loading + # and thus not loading the 
callbacks state (this is desired behavior, it just + # conflicts with the place where CaptureCallbacksBeforeTraining intercepts) checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True, save_on_train_epoch_end=False) trainer_args = dict( default_root_dir=tmpdir, From e64c560c4426ce70bf33275ca2cda90aebf52751 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 16 Dec 2021 12:11:15 +0000 Subject: [PATCH 09/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/models/test_restore.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 566d80d112d2c..b0f19687ce3ab 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -243,9 +243,9 @@ def test_callbacks_state_fit_ckpt_path(tmpdir): callback_capture = CaptureCallbacksBeforeTraining() def get_trainer_args(): - # save_on_train_epoch_end must be set explicitly since otherwise it will be - # changed internally causing the state_key to not match anymore upon loading - # and thus not loading the callbacks state (this is desired behavior, it just + # save_on_train_epoch_end must be set explicitly since otherwise it will be + # changed internally causing the state_key to not match anymore upon loading + # and thus not loading the callbacks state (this is desired behavior, it just # conflicts with the place where CaptureCallbacksBeforeTraining intercepts) checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True, save_on_train_epoch_end=False) trainer_args = dict( From 763a15955bd54b78c7ef25060fd7d4298cee310d Mon Sep 17 00:00:00 2001 From: ORippler Date: Thu, 16 Dec 2021 13:49:57 +0100 Subject: [PATCH 10/14] Check that all attributes are restored properly --- tests/models/test_restore.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) 
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index b0f19687ce3ab..51e3f8e3deb9f 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -272,8 +272,15 @@ def get_trainer_args(): for before, after in zip(callbacks_before_resume, callback_capture.callbacks): if isinstance(before, ModelCheckpoint): - assert before.best_model_path == after.best_model_path - assert before.best_model_score == after.best_model_score + for attribute in ( + "best_model_path", + "best_model_score", + "best_k_models", + "kth_best_model_path", + "kth_value", + "last_model_path", + ): + assert getattr(before, attribute) == getattr(after, attribute) def test_callbacks_references_fit_ckpt_path(tmpdir): From f1b66b31d820bb3a64b412339d85dd5be645e705 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 17 Dec 2021 00:23:19 +0100 Subject: [PATCH 11/14] revert changes, use fix on master --- tests/models/test_restore.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 16ff0a2bf8d0a..20e83bb070534 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -243,11 +243,7 @@ def test_callbacks_state_fit_ckpt_path(tmpdir): callback_capture = CaptureCallbacksBeforeTraining() def get_trainer_args(): - # save_on_train_epoch_end must be set explicitly since otherwise it will be - # changed internally causing the state_key to not match anymore upon loading - # and thus not loading the callbacks state (this is desired behavior, it just - # conflicts with the place where CaptureCallbacksBeforeTraining intercepts) - checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True, save_on_train_epoch_end=False) + checkpoint = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True) trainer_args = dict( default_root_dir=tmpdir, limit_train_batches=1, From ba07b8c175d36eef6131dc65c1acf5cf801f447a Mon Sep 17 
00:00:00 2001 From: ORippler Date: Fri, 17 Dec 2021 13:33:42 +0100 Subject: [PATCH 12/14] Convert to proper unit test --- tests/checkpointing/test_model_checkpoint.py | 57 +++++++++++--------- 1 file changed, 31 insertions(+), 26 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 5624fa61ec009..08792b91076a6 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1204,30 +1204,35 @@ def test_check_val_every_n_epochs_top_k_integration(tmpdir): assert set(os.listdir(tmpdir)) == {"epoch=1.ckpt", "epoch=3.ckpt"} -def test_model_checkpoint_attributes(tmpdir): - seed_everything() - model = LogInTwoMethods() - - epochs = 2 - checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor=None, save_top_k=-1, save_last=True) - trainer = Trainer( - default_root_dir=tmpdir, - callbacks=[checkpoint_callback], - limit_train_batches=10, - limit_val_batches=10, - max_epochs=epochs, - logger=False, - ) - - trainer.fit(model) - - checkpoint = torch.load(os.path.join(tmpdir, "last.ckpt"))["callbacks"][checkpoint_callback.state_key] - - for k in ("best_k_models", "kth_best_model_path", "kth_value", "last_model_path"): - assert checkpoint[k] == getattr(checkpoint_callback, k) - - restored_callback = ModelCheckpoint(dirpath=tmpdir, monitor=None, save_top_k=-1, save_last=True) - restored_callback.on_load_checkpoint("", "", checkpoint) +def test_model_checkpoint_loadsave_ckpt(tmpdir): + cb = ModelCheckpoint(dirpath=tmpdir, monitor=None, save_top_k=-1, save_last=True) + + # test restore + ckpt_for_restore = { + "best_model_path": 'epoch=0-step=0.ckpt', + "best_model_score": torch.tensor(1.1027), + "best_k_models": {'epoch=0-step=0.ckpt': torch.tensor(1.1027)}, + "kth_best_model_path": "epoch=0-step=0.ckpt", + "kth_value": torch.tensor(1.1027), + "last_model_path": "last.ckpt" + } - for k in ("best_k_models", "kth_best_model_path", "kth_value", 
"last_model_path"): - assert checkpoint[k] == getattr(restored_callback, k) + cb.on_load_checkpoint("", "", ckpt_for_restore) + for key, val in ckpt_for_restore.items(): + assert getattr(cb, key) == val + + # set attributes from 2nd checkpoint to simulate training and test write + ckpt_for_write = { + "best_model_path": 'epoch=10-step=1436.ckpt', + "best_model_score": torch.tensor(2.246), + "best_k_models": {'epoch=10-step=1436.ckpt': torch.tensor(2.246)}, + "kth_best_model_path": "epoch=10-step=1436.ckpt", + "kth_value": torch.tensor(2.246), + "last_model_path": "last2245.ckpt" + } + for key, val in ckpt_for_write.items(): + setattr(cb, key, val) + + written_ckpt = cb.on_save_checkpoint("", "", "") + for state in ckpt_for_write: + assert ckpt_for_write[state] == written_ckpt[state] \ No newline at end of file From e97a086e2df3072fc1dce6d21e1c97a4ec8e47db Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Dec 2021 12:35:04 +0000 Subject: [PATCH 13/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/checkpointing/test_model_checkpoint.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 08792b91076a6..1aad1c20a45ba 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1209,30 +1209,30 @@ def test_model_checkpoint_loadsave_ckpt(tmpdir): # test restore ckpt_for_restore = { - "best_model_path": 'epoch=0-step=0.ckpt', + "best_model_path": "epoch=0-step=0.ckpt", "best_model_score": torch.tensor(1.1027), - "best_k_models": {'epoch=0-step=0.ckpt': torch.tensor(1.1027)}, + "best_k_models": {"epoch=0-step=0.ckpt": torch.tensor(1.1027)}, "kth_best_model_path": "epoch=0-step=0.ckpt", "kth_value": torch.tensor(1.1027), - "last_model_path": "last.ckpt" + 
"last_model_path": "last.ckpt", } - cb.on_load_checkpoint("", "", ckpt_for_restore) + cb.on_load_checkpoint("", "", ckpt_for_restore) for key, val in ckpt_for_restore.items(): assert getattr(cb, key) == val # set attributes from 2nd checkpoint to simulate training and test write ckpt_for_write = { - "best_model_path": 'epoch=10-step=1436.ckpt', + "best_model_path": "epoch=10-step=1436.ckpt", "best_model_score": torch.tensor(2.246), - "best_k_models": {'epoch=10-step=1436.ckpt': torch.tensor(2.246)}, + "best_k_models": {"epoch=10-step=1436.ckpt": torch.tensor(2.246)}, "kth_best_model_path": "epoch=10-step=1436.ckpt", "kth_value": torch.tensor(2.246), - "last_model_path": "last2245.ckpt" + "last_model_path": "last2245.ckpt", } for key, val in ckpt_for_write.items(): setattr(cb, key, val) - + written_ckpt = cb.on_save_checkpoint("", "", "") for state in ckpt_for_write: - assert ckpt_for_write[state] == written_ckpt[state] \ No newline at end of file + assert ckpt_for_write[state] == written_ckpt[state] From 3d7994a418048030bb7e228c379200408f713e47 Mon Sep 17 00:00:00 2001 From: ORippler Date: Mon, 20 Dec 2021 12:08:24 +0100 Subject: [PATCH 14/14] Refactor `test_mode_checkpoint_saveload_ckpt` * First save, then load ckpt. * Instantiate ModelCheckpoint twice. 
--- tests/checkpointing/test_model_checkpoint.py | 48 ++++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 1aad1c20a45ba..a3324905852a7 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -1204,35 +1204,35 @@ def test_check_val_every_n_epochs_top_k_integration(tmpdir): assert set(os.listdir(tmpdir)) == {"epoch=1.ckpt", "epoch=3.ckpt"} -def test_model_checkpoint_loadsave_ckpt(tmpdir): - cb = ModelCheckpoint(dirpath=tmpdir, monitor=None, save_top_k=-1, save_last=True) - - # test restore - ckpt_for_restore = { - "best_model_path": "epoch=0-step=0.ckpt", - "best_model_score": torch.tensor(1.1027), - "best_k_models": {"epoch=0-step=0.ckpt": torch.tensor(1.1027)}, - "kth_best_model_path": "epoch=0-step=0.ckpt", - "kth_value": torch.tensor(1.1027), - "last_model_path": "last.ckpt", - } - - cb.on_load_checkpoint("", "", ckpt_for_restore) - for key, val in ckpt_for_restore.items(): - assert getattr(cb, key) == val - - # set attributes from 2nd checkpoint to simulate training and test write - ckpt_for_write = { +def test_model_checkpoint_saveload_ckpt(tmpdir): + ckpt = { + "monitor": "random_value", "best_model_path": "epoch=10-step=1436.ckpt", "best_model_score": torch.tensor(2.246), + "current_score": torch.tensor(1.5), + "dirpath": tmpdir, "best_k_models": {"epoch=10-step=1436.ckpt": torch.tensor(2.246)}, "kth_best_model_path": "epoch=10-step=1436.ckpt", "kth_value": torch.tensor(2.246), "last_model_path": "last2245.ckpt", } - for key, val in ckpt_for_write.items(): - setattr(cb, key, val) - written_ckpt = cb.on_save_checkpoint("", "", "") - for state in ckpt_for_write: - assert ckpt_for_write[state] == written_ckpt[state] + # test on_save_checkpoint + cb_write = ModelCheckpoint(dirpath=tmpdir, monitor="random_value", save_top_k=-1, save_last=True) + for key, val in ckpt.items(): + 
setattr(cb_write, key, val) + written_ckpt = cb_write.on_save_checkpoint("", "", "") + for state in ckpt: + assert ckpt[state] == written_ckpt[state] + + # test on_load_checkpoint + # Note: "current_score", "dirpath" and "monitor" are currently not restored by on_load_checkpoint. + # We therefore set "dirpath" and "monitor" to something different than for ckpt/cb_write so we can assert them. + # "current_score" is left as initialized, i.e. None, and can therefore also be asserted + cb_restore = ModelCheckpoint(dirpath=tmpdir + "restore", monitor=None, save_top_k=-1, save_last=True) + cb_restore.on_load_checkpoint("", "", written_ckpt) + for key, val in written_ckpt.items(): + if key not in ("current_score", "dirpath", "monitor"): + assert getattr(cb_restore, key) == val + else: + assert getattr(cb_restore, key) != val