@@ -43,6 +43,15 @@
 from tests.helpers.runif import RunIf


+def test_model_checkpoint_state_key():
+    early_stopping = ModelCheckpoint(monitor="val_loss")
+    expected_id = (
+        "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
+        " 'train_time_interval': None, 'save_on_train_epoch_end': None}"
+    )
+    assert early_stopping.state_key == expected_id
+
+
 class LogInTwoMethods(BoringModel):
     def training_step(self, batch, batch_idx):
         out = super().training_step(batch, batch_idx)
@@ -148,7 +157,10 @@ def on_validation_epoch_end(self):
         assert chk["epoch"] == epoch + 1
         assert chk["global_step"] == limit_train_batches * (epoch + 1)

-        mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
+        mc_specific_data = chk["callbacks"][
+            f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
+            " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
+        ]
         assert mc_specific_data["dirpath"] == checkpoint.dirpath
         assert mc_specific_data["monitor"] == monitor
         assert mc_specific_data["current_score"] == score
@@ -259,7 +271,10 @@ def _make_assertions(epoch, ix, version=""):
         expected_global_step = per_val_train_batches * (global_ix + 1) + (leftover_train_batches * epoch_num)
         assert chk["global_step"] == expected_global_step

-        mc_specific_data = chk["callbacks"]["ModelCheckpoint"]
+        mc_specific_data = chk["callbacks"][
+            f"ModelCheckpoint{{'monitor': '{monitor}', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
+            " 'train_time_interval': None, 'save_on_train_epoch_end': False}"
+        ]
         assert mc_specific_data["dirpath"] == checkpoint.dirpath
         assert mc_specific_data["monitor"] == monitor
         assert mc_specific_data["current_score"] == score
@@ -857,7 +872,12 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):

     assert ckpt_last_epoch["epoch"] == ckpt_last["epoch"]
     assert ckpt_last_epoch["global_step"] == ckpt_last["global_step"]
-    assert ckpt_last["callbacks"]["ModelCheckpoint"] == ckpt_last_epoch["callbacks"]["ModelCheckpoint"]
+
+    ckpt_id = (
+        "ModelCheckpoint{'monitor': 'early_stop_on', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
+        " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
+    )
+    assert ckpt_last["callbacks"][ckpt_id] == ckpt_last_epoch["callbacks"][ckpt_id]

     # it is easier to load the model objects than to iterate over the raw dict of tensors
     model_last_epoch = LogInTwoMethods.load_from_checkpoint(path_last_epoch)
@@ -1095,7 +1115,13 @@ def training_step(self, *args):
     trainer.fit(TestModel())
     assert model_checkpoint.current_score == 0.3
     ckpts = [torch.load(str(ckpt)) for ckpt in tmpdir.listdir()]
-    ckpts = [ckpt["callbacks"]["ModelCheckpoint"] for ckpt in ckpts]
+    ckpts = [
+        ckpt["callbacks"][
+            "ModelCheckpoint{'monitor': 'foo', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
+            " 'train_time_interval': None, 'save_on_train_epoch_end': True}"
+        ]
+        for ckpt in ckpts
+    ]
     assert sorted(ckpt["current_score"] for ckpt in ckpts) == [0.1, 0.2, 0.3]

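For context beyond the diff: the expected identifiers asserted above all follow the pattern `ClassName{<repr of the identifying init kwargs>}`, so callback state in a checkpoint is keyed by both the class and the arguments that distinguish one instance from another. The snippet below is only a minimal sketch of that pattern under the stated assumption; the helper name `_make_state_key` and the simplified checkpoint dict are illustrative and are not PyTorch Lightning's actual implementation.

# Sketch only: compose a state key from a class name plus a repr of its identifying kwargs.
def _make_state_key(cls_name: str, **kwargs) -> str:
    # repr() of the kwargs dict yields "{'monitor': 'val_loss', 'mode': 'min', ...}"
    return f"{cls_name}{repr(kwargs)}"


expected = (
    "ModelCheckpoint{'monitor': 'val_loss', 'mode': 'min', 'every_n_train_steps': 0, 'every_n_epochs': 1,"
    " 'train_time_interval': None, 'save_on_train_epoch_end': None}"
)
key = _make_state_key(
    "ModelCheckpoint",
    monitor="val_loss",
    mode="min",
    every_n_train_steps=0,
    every_n_epochs=1,
    train_time_interval=None,
    save_on_train_epoch_end=None,
)
assert key == expected

# A checkpoint then stores per-callback state under that key rather than the bare class name,
# which is what the updated lookups in the tests above rely on (hypothetical dict for illustration).
checkpoint = {"callbacks": {key: {"monitor": "val_loss", "best_model_score": 0.1}}}
assert checkpoint["callbacks"][key]["monitor"] == "val_loss"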