Skip to content

Commit c35ca64

Browse files
committed
tests
1 parent fb6dad0 commit c35ca64

File tree

3 files changed

+14
-28
lines changed

3 files changed

+14
-28
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ def _format_checkpoint_name(
335335
epoch: int,
336336
step: int,
337337
metrics: Dict[str, Any],
338+
prefix: str = "",
338339
) -> str:
339340
if not filename:
340341
# filename is not set, use default name
@@ -351,6 +352,9 @@ def _format_checkpoint_name(
351352
metrics[name] = 0
352353
filename = filename.format(**metrics)
353354

355+
if prefix:
356+
filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename])
357+
354358
return filename
355359

356360
def format_checkpoint_name(self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None) -> str:

tests/checkpointing/test_model_checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,9 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
294294
assert ckpt_name == str(Path('.').resolve() / 'epoch=3-step=4.ckpt')
295295

296296
# with version
297-
ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename='name', prefix='test')
297+
ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=tmpdir, filename='name')
298298
ckpt_name = ckpt.format_checkpoint_name(3, 2, {}, ver=3)
299-
assert ckpt_name == tmpdir / 'test-name-v3.ckpt'
299+
assert ckpt_name == tmpdir / 'name-v3.ckpt'
300300

301301
# using slashes
302302
ckpt = ModelCheckpoint(monitor='early_stop_on', dirpath=None, filename='{epoch}_{val/loss:.5f}')

tests/trainer/test_trainer.py

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -421,34 +421,17 @@ def test_dp_output_reduce():
421421

422422

423423
@pytest.mark.parametrize(
424-
["save_top_k", "save_last", "file_prefix", "expected_files"],
424+
["save_top_k", "save_last", "expected_files"],
425425
[
426-
pytest.param(
427-
-1,
428-
False,
429-
"",
430-
{"epoch=4.ckpt", "epoch=3.ckpt", "epoch=2.ckpt", "epoch=1.ckpt", "epoch=0.ckpt"},
431-
id="CASE K=-1 (all)",
432-
),
433-
pytest.param(1, False, "test_prefix", {"test_prefix-epoch=4.ckpt"}, id="CASE K=1 (2.5, epoch 4)"),
434-
pytest.param(2, False, "", {"epoch=4.ckpt", "epoch=2.ckpt"}, id="CASE K=2 (2.5 epoch 4, 2.8 epoch 2)"),
435-
pytest.param(
436-
4,
437-
False,
438-
"",
439-
{"epoch=1.ckpt", "epoch=4.ckpt", "epoch=3.ckpt", "epoch=2.ckpt"},
440-
id="CASE K=4 (save all 4 base)",
441-
),
442-
pytest.param(
443-
3,
444-
False,
445-
"", {"epoch=2.ckpt", "epoch=3.ckpt", "epoch=4.ckpt"},
446-
id="CASE K=3 (save the 2nd, 3rd, 4th model)"
447-
),
448-
pytest.param(1, True, "", {"epoch=4.ckpt", "last.ckpt"}, id="CASE K=1 (save the 4th model and the last model)"),
426+
pytest.param(-1, False, [f"epoch={i}.ckpt" for i in range(5)], id="CASE K=-1 (all)"),
427+
pytest.param(1, False, {"epoch=4.ckpt"}, id="CASE K=1 (2.5, epoch 4)"),
428+
pytest.param(2, False, [f"epoch={i}.ckpt" for i in (2, 4)], id="CASE K=2 (2.5 epoch 4, 2.8 epoch 2)"),
429+
pytest.param(4, False, [f"epoch={i}.ckpt" for i in range(1, 5)], id="CASE K=4 (save all 4 base)"),
430+
pytest.param(3, False, [f"epoch={i}.ckpt" for i in range(2, 5)], id="CASE K=3 (save the 2nd, 3rd, 4th model)"),
431+
pytest.param(1, True, {"epoch=4.ckpt", "last.ckpt"}, id="CASE K=1 (save the 4th model and the last model)"),
449432
],
450433
)
451-
def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix, expected_files):
434+
def test_model_checkpoint_options(tmpdir, save_top_k, save_last, expected_files):
452435
"""Test ModelCheckpoint options."""
453436

454437
def mock_save_function(filepath, *args):
@@ -463,7 +446,6 @@ def mock_save_function(filepath, *args):
463446
monitor='checkpoint_on',
464447
save_top_k=save_top_k,
465448
save_last=save_last,
466-
prefix=file_prefix,
467449
verbose=1
468450
)
469451
checkpoint_callback.save_function = mock_save_function

0 commit comments

Comments
 (0)