Skip to content

Commit 37244b5 — "update tests"

Browse files
1 parent: 807f223 · commit: 37244b5

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

tests/callbacks/test_early_stopping.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ def test_resume_early_stopping_from_checkpoint(tmpdir):
8181
# the checkpoint saves "epoch + 1"
8282
early_stop_callback_state = early_stop_callback.saved_states[checkpoint["epoch"] - 1]
8383
assert 4 == len(early_stop_callback.saved_states)
84-
assert checkpoint["callbacks"]["EarlyStoppingTestRestore"] == early_stop_callback_state
84+
print(checkpoint["callbacks"])
85+
assert checkpoint["callbacks"]["EarlyStoppingTestRestore[monitor=train_loss]"] == early_stop_callback_state
8586

8687
# ensure state is reloaded properly (assertion in the callback)
8788
early_stop_callback = EarlyStoppingTestRestore(early_stop_callback_state, monitor='train_loss')

tests/callbacks/test_lambda_function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def on_train_epoch_start(self):
2828
raise KeyboardInterrupt
2929

3030
checker = set()
31-
hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction)]
31+
hooks = [m for m, _ in inspect.getmembers(Callback, predicate=inspect.isfunction) if not m.startswith("_")]
3232
hooks_args = {h: (lambda x: lambda *_: checker.add(x))(h) for h in hooks}
3333
hooks_args["on_save_checkpoint"] = (lambda x: lambda *_: [checker.add(x)])("on_save_checkpoint")
3434

tests/checkpointing/test_model_checkpoint.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def configure_optimizers(self):
152152
assert chk['epoch'] == epoch + 1
153153
assert chk['global_step'] == limit_train_batches * (epoch + 1)
154154

155-
mc_specific_data = chk['callbacks']["ModelCheckpoint"]
155+
mc_specific_data = chk['callbacks'][f"ModelCheckpoint[monitor={monitor}]"]
156156
assert mc_specific_data['dirpath'] == checkpoint.dirpath
157157
assert mc_specific_data['monitor'] == monitor
158158
assert mc_specific_data['current_score'] == score
@@ -269,7 +269,7 @@ def _make_assertions(epoch, ix, add=''):
269269
expected_global_step = per_epoch_steps * (global_ix + 1) + (left_over_steps * epoch_num)
270270
assert chk['global_step'] == expected_global_step
271271

272-
mc_specific_data = chk['callbacks']["ModelCheckpoint"]
272+
mc_specific_data = chk['callbacks'][f"ModelCheckpoint[monitor={monitor}]"]
273273
assert mc_specific_data['dirpath'] == checkpoint.dirpath
274274
assert mc_specific_data['monitor'] == monitor
275275
assert mc_specific_data['current_score'] == score
@@ -870,8 +870,8 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
870870
ckpt_last = torch.load(path_last)
871871
assert all(ckpt_last_epoch[k] == ckpt_last[k] for k in ("epoch", "global_step"))
872872

873-
ch_type = "ModelCheckpoint"
874-
assert ckpt_last["callbacks"][ch_type] == ckpt_last_epoch["callbacks"][ch_type]
873+
ckpt_id = "ModelCheckpoint[monitor=early_stop_on]"
874+
assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id]
875875

876876
# it is easier to load the model objects than to iterate over the raw dict of tensors
877877
model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch)
@@ -1128,7 +1128,7 @@ def training_step(self, *args):
11281128
trainer.fit(TestModel())
11291129
assert model_checkpoint.current_score == 0.3
11301130
ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()]
1131-
ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts]
1131+
ckpts = [ckpt["callbacks"]["ModelCheckpoint[monitor=foo]"] for ckpt in ckpts]
11321132
assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3]
11331133

11341134

0 commit comments

Comments (0)