@@ -163,14 +163,16 @@ def configure_optimizers(self):
163163
164164@mock .patch .dict (os .environ , {"PL_DEV_DEBUG" : "1" })
165165@pytest .mark .parametrize (
166- "val_check_interval,reduce_lr_on_plateau" ,
166+ "val_check_interval,reduce_lr_on_plateau,epoch_aligned " ,
167167 [
168- (0.25 , True ),
169- (0.25 , False ),
170- (0.33 , False ),
168+ (0.25 , True , True ),
169+ (0.25 , False , True ),
170+ (0.42 , False , False ),
171171 ],
172172)
173- def test_model_checkpoint_score_and_ckpt_val_check_interval (tmpdir , val_check_interval , reduce_lr_on_plateau ):
173+ def test_model_checkpoint_score_and_ckpt_val_check_interval (
174+ tmpdir , val_check_interval , reduce_lr_on_plateau , epoch_aligned
175+ ):
174176 """
175177 Test that when a model checkpoint is saved, it saves with the correct
176178 score appended to ckpt_path and checkpoint data with val_check_interval
@@ -182,6 +184,7 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_in
182184 monitor = 'val_log'
183185 per_epoch_steps = int (limit_train_batches * val_check_interval )
184186 per_epoch_call_count = limit_train_batches // per_epoch_steps
187+ left_over_steps = limit_train_batches % per_epoch_steps
185188
186189 class CustomBoringModel (BoringModel ):
187190
@@ -236,35 +239,57 @@ def configure_optimizers(self):
236239 ckpt_files = list (Path (tmpdir ).glob ('*.ckpt' ))
237240 scores = [metric [monitor ] for metric in trainer .dev_debugger .logged_metrics if monitor in metric ]
238241 lr_scheduler_debug = trainer .dev_debugger .saved_lr_scheduler_updates
239- assert len (ckpt_files ) == len (scores ) == per_epoch_call_count * max_epochs
242+
243+ # The checkpoint callback also runs in on_train_end and saves an additional ckpt if none was created at the
244+ # end of the epoch, so when val_check_interval doesn't align with the training steps we expect one extra ckpt
245+ additional_ckpt , additional_ckpt_path = 0 , None
246+ if not epoch_aligned :
247+ additional_ckpt_path = [f for f in ckpt_files if 'v1' in f .stem ][0 ]
248+ additional_ckpt = 1
249+
250+ additional_ckpt = 1 if not epoch_aligned else 0
251+ assert len (ckpt_files ) == len (scores ) + additional_ckpt == per_epoch_call_count * max_epochs + additional_ckpt
240252 assert len (lr_scheduler_debug ) == max_epochs
241253
254+ def _make_assertions (epoch , ix , add = '' ):
255+ global_ix = ix + per_epoch_call_count * epoch
256+ score = scores [global_ix ]
257+ expected_score = getattr (model , f'{ monitor } s' )[global_ix ].mean ().item ()
258+ expected_filename = f'{ monitor } ={ score :.4f} -epoch={ epoch } { add } .ckpt'
259+ assert math .isclose (score , expected_score , rel_tol = 1e-4 )
260+
261+ chk = pl_load (os .path .join (checkpoint .dirpath , expected_filename ))
262+ assert chk ['epoch' ] == epoch + 1
263+ epoch_num = epoch + (1 if add else 0 )
264+ expected_global_step = per_epoch_steps * (global_ix + 1 ) + (left_over_steps * epoch_num )
265+ assert chk ['global_step' ] == expected_global_step
266+
267+ mc_specific_data = chk ['callbacks' ][type (checkpoint )]
268+ assert mc_specific_data ['dirpath' ] == checkpoint .dirpath
269+ assert mc_specific_data ['monitor' ] == monitor
270+ assert mc_specific_data ['current_score' ] == score
271+
272+ if not reduce_lr_on_plateau :
273+ lr_scheduler_specific_data = chk ['lr_schedulers' ][0 ]
274+ did_update = 1 if (ix + 1 == per_epoch_call_count ) and (epoch_aligned or add ) else 0
275+ assert lr_scheduler_specific_data ['_step_count' ] == epoch + 1 + did_update
276+ assert lr_scheduler_specific_data ['_last_lr' ][0 ] == lr * (lr ** (epoch + did_update ))
277+
278+ return score
279+
242280 for epoch in range (max_epochs ):
243- for ix in range (per_epoch_call_count ):
244- global_ix = ix + per_epoch_call_count * epoch
245- score = scores [global_ix ]
246- expected_score = getattr (model , f'{ monitor } s' )[global_ix ].mean ().item ()
247- expected_filename = f'{ monitor } ={ score :.4f} -epoch={ epoch } .ckpt'
248- assert math .isclose (score , expected_score , rel_tol = 1e-4 )
249-
250- chk = pl_load (os .path .join (checkpoint .dirpath , expected_filename ))
251- assert chk ['epoch' ] == epoch + 1
252- assert chk ['global_step' ] == per_epoch_steps * (global_ix + 1 )
253-
254- mc_specific_data = chk ['callbacks' ][type (checkpoint )]
255- assert mc_specific_data ['dirpath' ] == checkpoint .dirpath
256- assert mc_specific_data ['monitor' ] == monitor
257- assert mc_specific_data ['current_score' ] == score
258-
259- if not reduce_lr_on_plateau :
260- lr_scheduler_specific_data = chk ['lr_schedulers' ][0 ]
261- did_update = 1 if ix + 1 == per_epoch_call_count else 0
262- assert lr_scheduler_specific_data ['_step_count' ] == epoch + 1 + did_update
263- assert lr_scheduler_specific_data ['_last_lr' ][0 ] == lr * (lr ** (epoch + did_update ))
281+ for i in range (per_epoch_call_count ):
282+ score = _make_assertions (epoch , i )
264283
265284 assert lr_scheduler_debug [epoch ]['monitor_val' ] == (score if reduce_lr_on_plateau else None )
266285 assert lr_scheduler_debug [epoch ]['monitor_key' ] == (monitor if reduce_lr_on_plateau else None )
267286
287+ # check the ckpt file saved on_train_end
288+ if additional_ckpt_path :
289+ epoch = max_epochs - 1
290+ i = per_epoch_call_count - 1
291+ _make_assertions (epoch , i , add = '-v1' )
292+
268293
269294@pytest .mark .parametrize ("save_top_k" , [- 1 , 0 , 1 , 2 ])
270295def test_model_checkpoint_with_non_string_input (tmpdir , save_top_k : int ):
0 commit comments