tests/checkpointing/test_model_checkpoint.py (79 changes: 52 additions, 27 deletions)
@@ -163,14 +163,16 @@ def configure_optimizers(self):

@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@pytest.mark.parametrize(
-    "val_check_interval,reduce_lr_on_plateau",
+    "val_check_interval,reduce_lr_on_plateau,epoch_aligned",
    [
-        (0.25, True),
-        (0.25, False),
-        (0.33, False),
+        (0.25, True, True),
+        (0.25, False, True),
+        (0.42, False, False),
    ],
)
-def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, reduce_lr_on_plateau):
+def test_model_checkpoint_score_and_ckpt_val_check_interval(
+    tmpdir, val_check_interval, reduce_lr_on_plateau, epoch_aligned
+):
    """
    Test that when a model checkpoint is saved, it saves with the correct
    score appended to ckpt_path and checkpoint data with val_check_interval
@@ -182,6 +184,7 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(
    monitor = 'val_log'
    per_epoch_steps = int(limit_train_batches * val_check_interval)
    per_epoch_call_count = limit_train_batches // per_epoch_steps
+    left_over_steps = limit_train_batches % per_epoch_steps

    class CustomBoringModel(BoringModel):

@@ -236,35 +239,57 @@ def configure_optimizers(self):
    ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
    scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
    lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
-    assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs

+    # on_train_end ckpt callback is called which creates an additional ckpt in case no ckpt is created at the
+    # end of epoch, thus if val_check_interval doesn't align with the training steps we create an additional ckpt
+    additional_ckpt, additional_ckpt_path = 0, None
+    if not epoch_aligned:
+        additional_ckpt_path = [f for f in ckpt_files if 'v1' in f.stem][0]
+        additional_ckpt = 1
+
+    assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_call_count * max_epochs + additional_ckpt
    assert len(lr_scheduler_debug) == max_epochs

+    def _make_assertions(epoch, ix, add=''):
+        global_ix = ix + per_epoch_call_count * epoch
+        score = scores[global_ix]
+        expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
+        expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{add}.ckpt'
+        assert math.isclose(score, expected_score, rel_tol=1e-4)
+
+        chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
+        assert chk['epoch'] == epoch + 1
+        epoch_num = epoch + (1 if add else 0)
+        expected_global_step = per_epoch_steps * (global_ix + 1) + (left_over_steps * epoch_num)
+        assert chk['global_step'] == expected_global_step
+
+        mc_specific_data = chk['callbacks'][type(checkpoint)]
+        assert mc_specific_data['dirpath'] == checkpoint.dirpath
+        assert mc_specific_data['monitor'] == monitor
+        assert mc_specific_data['current_score'] == score
+
+        if not reduce_lr_on_plateau:
+            lr_scheduler_specific_data = chk['lr_schedulers'][0]
+            did_update = 1 if (ix + 1 == per_epoch_call_count) and (epoch_aligned or add) else 0
+            assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
+            assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
+
+        return score

    for epoch in range(max_epochs):
-        for ix in range(per_epoch_call_count):
-            global_ix = ix + per_epoch_call_count * epoch
-            score = scores[global_ix]
-            expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
-            expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
-            assert math.isclose(score, expected_score, rel_tol=1e-4)
-
-            chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
-            assert chk['epoch'] == epoch + 1
-            assert chk['global_step'] == per_epoch_steps * (global_ix + 1)
-
-            mc_specific_data = chk['callbacks'][type(checkpoint)]
-            assert mc_specific_data['dirpath'] == checkpoint.dirpath
-            assert mc_specific_data['monitor'] == monitor
-            assert mc_specific_data['current_score'] == score
-
-            if not reduce_lr_on_plateau:
-                lr_scheduler_specific_data = chk['lr_schedulers'][0]
-                did_update = 1 if ix + 1 == per_epoch_call_count else 0
-                assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
-                assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
+        for i in range(per_epoch_call_count):
+            score = _make_assertions(epoch, i)

        assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
        assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)

+    # check the ckpt file saved on_train_end
+    if additional_ckpt_path:
+        epoch = max_epochs - 1
+        i = per_epoch_call_count - 1
+        _make_assertions(epoch, i, add='-v1')
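A quick sanity check of the step bookkeeping the new assertions rely on. The concrete limit_train_batches and max_epochs are set earlier in the test, outside this hunk, so the numbers below are purely illustrative for the non-aligned case (val_check_interval=0.42): each epoch fits a whole number of validation intervals plus some left-over steps, so no checkpoint lands exactly on the epoch boundary and on_train_end writes the extra '-v1' file.

# Illustrative values only (assumed, not taken from the test):
limit_train_batches = 64
max_epochs = 3
val_check_interval = 0.42

per_epoch_steps = int(limit_train_batches * val_check_interval)  # 26 train steps between validation checks
per_epoch_call_count = limit_train_batches // per_epoch_steps    # 2 checkpoints written per epoch
left_over_steps = limit_train_batches % per_epoch_steps          # 12 steps after the last check of an epoch

epoch_aligned = left_over_steps == 0                              # False here, so on_train_end adds one '-v1' ckpt

def expected_global_step(epoch, ix, at_train_end=False):
    # all validation intervals seen so far, plus the left-over tail of every epoch already finished
    global_ix = ix + per_epoch_call_count * epoch
    epoch_num = epoch + (1 if at_train_end else 0)
    return per_epoch_steps * (global_ix + 1) + left_over_steps * epoch_num

assert expected_global_step(epoch=0, ix=0) == 26
assert expected_global_step(epoch=2, ix=1, at_train_end=True) == limit_train_batches * max_epochs  # 192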

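The `_last_lr` expectation above reads oddly (`lr * (lr ** (epoch + did_update))`) until you notice that the scheduler's decay factor equals the initial learning rate. A minimal sketch of where those numbers come from, assuming an ExponentialLR with gamma=lr stepped once per finished epoch; the scheduler actually used is configured in configure_optimizers, outside this hunk:

import pytest
import torch

lr = 1e-1
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr)  # assumption: gamma == initial lr

# torch schedulers initialise _step_count to 1, hence the `epoch + 1 + did_update` expectation above
assert scheduler._step_count == 1

completed_epochs = 2
for _ in range(completed_epochs):
    optimizer.step()
    scheduler.step()  # one scheduler update per finished epoch

assert scheduler._step_count == completed_epochs + 1
assert scheduler.get_last_lr()[0] == pytest.approx(lr * lr ** completed_epochs)  # 0.1 * 0.1 ** 2 == 1e-3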

@pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k: int):