@@ -152,7 +152,7 @@ def configure_optimizers(self):
152152 assert chk ['epoch' ] == epoch + 1
153153 assert chk ['global_step' ] == limit_train_batches * (epoch + 1 )
154154
155- mc_specific_data = chk ['callbacks' ]["ModelCheckpoint" ]
155+ mc_specific_data = chk ['callbacks' ][f "ModelCheckpoint[monitor= { monitor } ] " ]
156156 assert mc_specific_data ['dirpath' ] == checkpoint .dirpath
157157 assert mc_specific_data ['monitor' ] == monitor
158158 assert mc_specific_data ['current_score' ] == score
@@ -269,7 +269,7 @@ def _make_assertions(epoch, ix, add=''):
269269 expected_global_step = per_epoch_steps * (global_ix + 1 ) + (left_over_steps * epoch_num )
270270 assert chk ['global_step' ] == expected_global_step
271271
272- mc_specific_data = chk ['callbacks' ]["ModelCheckpoint" ]
272+ mc_specific_data = chk ['callbacks' ][f "ModelCheckpoint[monitor= { monitor } ] " ]
273273 assert mc_specific_data ['dirpath' ] == checkpoint .dirpath
274274 assert mc_specific_data ['monitor' ] == monitor
275275 assert mc_specific_data ['current_score' ] == score
@@ -870,8 +870,8 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
870870 ckpt_last = torch .load (path_last )
871871 assert all (ckpt_last_epoch [k ] == ckpt_last [k ] for k in ("epoch" , "global_step" ))
872872
873- ch_type = "ModelCheckpoint"
874- assert ckpt_last ["callbacks" ][ch_type ] == ckpt_last_epoch ["callbacks" ][ch_type ]
873+ ckpt_id = "ModelCheckpoint[monitor=early_stop_on] "
874+ assert ckpt_last ["callbacks" ][ckpt_id ] == ckpt_last_epoch ["callbacks" ][ckpt_id ]
875875
876876 # it is easier to load the model objects than to iterate over the raw dict of tensors
877877 model_last_epoch = LogInTwoMethods .load_from_checkpoint (path_last_epoch )
@@ -1128,7 +1128,7 @@ def training_step(self, *args):
11281128 trainer .fit (TestModel ())
11291129 assert model_checkpoint .current_score == 0.3
11301130 ckpts = [torch .load (str (ckpt )) for ckpt in tmpdir .listdir ()]
1131- ckpts = [ckpt ["callbacks" ]["ModelCheckpoint" ] for ckpt in ckpts ]
1131+ ckpts = [ckpt ["callbacks" ]["ModelCheckpoint[monitor=foo] " ] for ckpt in ckpts ]
11321132 assert sorted (ckpt ["current_score" ] for ckpt in ckpts ) == [0.1 , 0.2 , 0.3 ]
11331133
11341134
0 commit comments