diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index f58ff768759e8..54ba24b5e3253 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -163,14 +163,16 @@ def configure_optimizers(self):
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 @pytest.mark.parametrize(
-    "val_check_interval,reduce_lr_on_plateau",
+    "val_check_interval,reduce_lr_on_plateau,epoch_aligned",
     [
-        (0.25, True),
-        (0.25, False),
-        (0.33, False),
+        (0.25, True, True),
+        (0.25, False, True),
+        (0.42, False, False),
     ],
 )
-def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, reduce_lr_on_plateau):
+def test_model_checkpoint_score_and_ckpt_val_check_interval(
+    tmpdir, val_check_interval, reduce_lr_on_plateau, epoch_aligned
+):
     """
     Test that when a model checkpoint is saved, it saves with the correct
     score appended to ckpt_path and checkpoint data with val_check_interval
@@ -182,6 +184,7 @@ def test_model_checkpoint_score_and_ckpt_val_check_in
     monitor = 'val_log'
     per_epoch_steps = int(limit_train_batches * val_check_interval)
     per_epoch_call_count = limit_train_batches // per_epoch_steps
+    left_over_steps = limit_train_batches % per_epoch_steps
 
     class CustomBoringModel(BoringModel):
 
@@ -236,35 +239,57 @@ def configure_optimizers(self):
     ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
     scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
     lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
-    assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs
+
+    # on_train_end ckpt callback is called which creates an additional ckpt in case no ckpt is created at the
+    # end of epoch, thus if val_check_interval doesn't align with the training steps we create an additional ckpt
+    additional_ckpt, additional_ckpt_path = 0, None
+    if not epoch_aligned:
+        additional_ckpt_path = [f for f in ckpt_files if 'v1' in f.stem][0]
+        additional_ckpt = 1
+
+    additional_ckpt = 1 if not epoch_aligned else 0
+    assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_call_count * max_epochs + additional_ckpt
     assert len(lr_scheduler_debug) == max_epochs
 
+    def _make_assertions(epoch, ix, add=''):
+        global_ix = ix + per_epoch_call_count * epoch
+        score = scores[global_ix]
+        expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
+        expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{add}.ckpt'
+        assert math.isclose(score, expected_score, rel_tol=1e-4)
+
+        chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
+        assert chk['epoch'] == epoch + 1
+        epoch_num = epoch + (1 if add else 0)
+        expected_global_step = per_epoch_steps * (global_ix + 1) + (left_over_steps * epoch_num)
+        assert chk['global_step'] == expected_global_step
+
+        mc_specific_data = chk['callbacks'][type(checkpoint)]
+        assert mc_specific_data['dirpath'] == checkpoint.dirpath
+        assert mc_specific_data['monitor'] == monitor
+        assert mc_specific_data['current_score'] == score
+
+        if not reduce_lr_on_plateau:
+            lr_scheduler_specific_data = chk['lr_schedulers'][0]
+            did_update = 1 if (ix + 1 == per_epoch_call_count) and (epoch_aligned or add) else 0
+            assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
+            assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
+
+        return score
+
     for epoch in range(max_epochs):
-        for ix in range(per_epoch_call_count):
-            global_ix = ix + per_epoch_call_count * epoch
-            score = scores[global_ix]
-            expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
-            expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
-            assert math.isclose(score, expected_score, rel_tol=1e-4)
-
-            chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
-            assert chk['epoch'] == epoch + 1
-            assert chk['global_step'] == per_epoch_steps * (global_ix + 1)
-
-            mc_specific_data = chk['callbacks'][type(checkpoint)]
-            assert mc_specific_data['dirpath'] == checkpoint.dirpath
-            assert mc_specific_data['monitor'] == monitor
-            assert mc_specific_data['current_score'] == score
-
-            if not reduce_lr_on_plateau:
-                lr_scheduler_specific_data = chk['lr_schedulers'][0]
-                did_update = 1 if ix + 1 == per_epoch_call_count else 0
-                assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
-                assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
+        for i in range(per_epoch_call_count):
+            score = _make_assertions(epoch, i)
 
         assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
         assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
 
+    # check the ckpt file saved on_train_end
+    if additional_ckpt_path:
+        epoch = max_epochs - 1
+        i = per_epoch_call_count - 1
+        _make_assertions(epoch, i, add='-v1')
+
 
 @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
 def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k: int):