
Commit 14e6b46

rohitgr7 and carmocca authored
Update test for ckpt+val_check_interval (#7084)
* update test

* Apply suggestions from code review

Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 9beec26 commit 14e6b46

File tree

1 file changed: +52 -27 lines changed


tests/checkpointing/test_model_checkpoint.py

Lines changed: 52 additions & 27 deletions
@@ -163,14 +163,16 @@ def configure_optimizers(self):
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
 @pytest.mark.parametrize(
-    "val_check_interval,reduce_lr_on_plateau",
+    "val_check_interval,reduce_lr_on_plateau,epoch_aligned",
     [
-        (0.25, True),
-        (0.25, False),
-        (0.33, False),
+        (0.25, True, True),
+        (0.25, False, True),
+        (0.42, False, False),
     ],
 )
-def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, reduce_lr_on_plateau):
+def test_model_checkpoint_score_and_ckpt_val_check_interval(
+    tmpdir, val_check_interval, reduce_lr_on_plateau, epoch_aligned
+):
     """
     Test that when a model checkpoint is saved, it saves with the correct
     score appended to ckpt_path and checkpoint data with val_check_interval
@@ -182,6 +184,7 @@ def test_model_checkpoint_score_and_ckpt_val_check_interval(tmpdir, val_check_interval, reduce_lr_on_plateau):
     monitor = 'val_log'
     per_epoch_steps = int(limit_train_batches * val_check_interval)
     per_epoch_call_count = limit_train_batches // per_epoch_steps
+    left_over_steps = limit_train_batches % per_epoch_steps
 
     class CustomBoringModel(BoringModel):
 
@@ -236,35 +239,57 @@ def configure_optimizers(self):
     ckpt_files = list(Path(tmpdir).glob('*.ckpt'))
     scores = [metric[monitor] for metric in trainer.dev_debugger.logged_metrics if monitor in metric]
     lr_scheduler_debug = trainer.dev_debugger.saved_lr_scheduler_updates
-    assert len(ckpt_files) == len(scores) == per_epoch_call_count * max_epochs
+
+    # on_train_end ckpt callback is called which creates an additional ckpt in case no ckpt is created at the
+    # end of epoch, thus if val_check_interval doesn't align with the training steps we create an additional ckpt
+    additional_ckpt, additional_ckpt_path = 0, None
+    if not epoch_aligned:
+        additional_ckpt_path = [f for f in ckpt_files if 'v1' in f.stem][0]
+        additional_ckpt = 1
+
+    additional_ckpt = 1 if not epoch_aligned else 0
+    assert len(ckpt_files) == len(scores) + additional_ckpt == per_epoch_call_count * max_epochs + additional_ckpt
     assert len(lr_scheduler_debug) == max_epochs
 
+    def _make_assertions(epoch, ix, add=''):
+        global_ix = ix + per_epoch_call_count * epoch
+        score = scores[global_ix]
+        expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
+        expected_filename = f'{monitor}={score:.4f}-epoch={epoch}{add}.ckpt'
+        assert math.isclose(score, expected_score, rel_tol=1e-4)
+
+        chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
+        assert chk['epoch'] == epoch + 1
+        epoch_num = epoch + (1 if add else 0)
+        expected_global_step = per_epoch_steps * (global_ix + 1) + (left_over_steps * epoch_num)
+        assert chk['global_step'] == expected_global_step
+
+        mc_specific_data = chk['callbacks'][type(checkpoint)]
+        assert mc_specific_data['dirpath'] == checkpoint.dirpath
+        assert mc_specific_data['monitor'] == monitor
+        assert mc_specific_data['current_score'] == score
+
+        if not reduce_lr_on_plateau:
+            lr_scheduler_specific_data = chk['lr_schedulers'][0]
+            did_update = 1 if (ix + 1 == per_epoch_call_count) and (epoch_aligned or add) else 0
+            assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
+            assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
+
+        return score
+
     for epoch in range(max_epochs):
-        for ix in range(per_epoch_call_count):
-            global_ix = ix + per_epoch_call_count * epoch
-            score = scores[global_ix]
-            expected_score = getattr(model, f'{monitor}s')[global_ix].mean().item()
-            expected_filename = f'{monitor}={score:.4f}-epoch={epoch}.ckpt'
-            assert math.isclose(score, expected_score, rel_tol=1e-4)
-
-            chk = pl_load(os.path.join(checkpoint.dirpath, expected_filename))
-            assert chk['epoch'] == epoch + 1
-            assert chk['global_step'] == per_epoch_steps * (global_ix + 1)
-
-            mc_specific_data = chk['callbacks'][type(checkpoint)]
-            assert mc_specific_data['dirpath'] == checkpoint.dirpath
-            assert mc_specific_data['monitor'] == monitor
-            assert mc_specific_data['current_score'] == score
-
-            if not reduce_lr_on_plateau:
-                lr_scheduler_specific_data = chk['lr_schedulers'][0]
-                did_update = 1 if ix + 1 == per_epoch_call_count else 0
-                assert lr_scheduler_specific_data['_step_count'] == epoch + 1 + did_update
-                assert lr_scheduler_specific_data['_last_lr'][0] == lr * (lr**(epoch + did_update))
+        for i in range(per_epoch_call_count):
+            score = _make_assertions(epoch, i)
 
         assert lr_scheduler_debug[epoch]['monitor_val'] == (score if reduce_lr_on_plateau else None)
         assert lr_scheduler_debug[epoch]['monitor_key'] == (monitor if reduce_lr_on_plateau else None)
 
+    # check the ckpt file saved on_train_end
+    if additional_ckpt_path:
+        epoch = max_epochs - 1
+        i = per_epoch_call_count - 1
+        _make_assertions(epoch, i, add='-v1')
+
 
 @pytest.mark.parametrize("save_top_k", [-1, 0, 1, 2])
 def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k: int):
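
To see what the new epoch_aligned flag captures, it helps to run the test's own arithmetic on concrete numbers. A minimal sketch follows; limit_train_batches is defined earlier in the test file and is not visible in this diff, so the value 12 below is an assumption chosen purely for illustration.

# Sketch of the alignment arithmetic behind the new epoch_aligned parameter.
# ASSUMPTION: limit_train_batches = 12 is illustrative only; the real value
# is set earlier in the test file, outside this diff.
limit_train_batches = 12

for val_check_interval in (0.25, 0.42):
    # training steps between two validation runs (and thus two checkpoint saves)
    per_epoch_steps = int(limit_train_batches * val_check_interval)
    # number of validation-driven checkpoints written during one epoch
    per_epoch_call_count = limit_train_batches // per_epoch_steps
    # steps trained after the last in-epoch validation checkpoint
    left_over_steps = limit_train_batches % per_epoch_steps
    epoch_aligned = left_over_steps == 0
    print(val_check_interval, per_epoch_steps, per_epoch_call_count,
          left_over_steps, epoch_aligned)

# with 12 batches: 0.25 -> steps=3, calls=4, left_over=0  (aligned)
#                  0.42 -> steps=5, calls=2, left_over=2  (not aligned)

The leftover steps are also why the patch changes expected_global_step to per_epoch_steps * (global_ix + 1) + left_over_steps * epoch_num: each completed epoch trains left_over_steps extra steps that get no checkpoint of their own, so later checkpoints sit that much further along in global_step. Only the on_train_end checkpoint (add='-v1') is saved after the current epoch's leftover steps, hence epoch_num = epoch + 1 when add is set. The '-v1' suffix itself comes from ModelCheckpoint's filename versioning, since the end-of-training save would otherwise collide with the name of the last validation save.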

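The did_update bookkeeping follows how torch schedulers count their steps. The scheduler is configured in configure_optimizers, which is outside this diff, so the sketch below assumes an ExponentialLR whose gamma equals the initial lr; that choice is an assumption, made because it reproduces the asserted decay law _last_lr == lr * lr**(epoch + did_update).

import math

import torch

# Toy reproduction of the _step_count / _last_lr bookkeeping asserted above.
# ASSUMPTION: the scheduler type and gamma are not shown in this diff; an
# ExponentialLR with gamma == lr matches the asserted values.
lr = 0.1  # assumed initial learning rate
optimizer = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=lr)

# torch counts the implicit step taken inside the scheduler constructor,
# so _step_count is already 1 before any training happens
assert scheduler._step_count == 1

for epoch in range(3):
    # Lightning steps an epoch-interval scheduler once per epoch
    scheduler.step()
    # after epoch + 1 explicit steps: _step_count == epoch + 2, matching the
    # test's epoch + 1 + did_update with did_update == 1
    assert scheduler._step_count == epoch + 2
    assert math.isclose(scheduler._last_lr[0], lr * lr ** (epoch + 1))

This also explains the new (epoch_aligned or add) condition: when the epoch is aligned, the last in-epoch checkpoint is written at the epoch boundary, after the scheduler has stepped (did_update = 1); when it is not aligned, that checkpoint lands mid-epoch before the step (did_update = 0), and only the extra on_train_end checkpoint, saved with add='-v1', sees the stepped scheduler.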