@@ -421,34 +421,17 @@ def test_dp_output_reduce():
421421
422422
423423@pytest .mark .parametrize (
424- ["save_top_k" , "save_last" , "file_prefix" , " expected_files" ],
424+ ["save_top_k" , "save_last" , "expected_files" ],
425425 [
426- pytest .param (
427- - 1 ,
428- False ,
429- "" ,
430- {"epoch=4.ckpt" , "epoch=3.ckpt" , "epoch=2.ckpt" , "epoch=1.ckpt" , "epoch=0.ckpt" },
431- id = "CASE K=-1 (all)" ,
432- ),
433- pytest .param (1 , False , "test_prefix" , {"test_prefix-epoch=4.ckpt" }, id = "CASE K=1 (2.5, epoch 4)" ),
434- pytest .param (2 , False , "" , {"epoch=4.ckpt" , "epoch=2.ckpt" }, id = "CASE K=2 (2.5 epoch 4, 2.8 epoch 2)" ),
435- pytest .param (
436- 4 ,
437- False ,
438- "" ,
439- {"epoch=1.ckpt" , "epoch=4.ckpt" , "epoch=3.ckpt" , "epoch=2.ckpt" },
440- id = "CASE K=4 (save all 4 base)" ,
441- ),
442- pytest .param (
443- 3 ,
444- False ,
445- "" , {"epoch=2.ckpt" , "epoch=3.ckpt" , "epoch=4.ckpt" },
446- id = "CASE K=3 (save the 2nd, 3rd, 4th model)"
447- ),
448- pytest .param (1 , True , "" , {"epoch=4.ckpt" , "last.ckpt" }, id = "CASE K=1 (save the 4th model and the last model)" ),
426+ pytest .param (- 1 , False , [f"epoch={ i } .ckpt" for i in range (5 )], id = "CASE K=-1 (all)" ),
427+ pytest .param (1 , False , {"epoch=4.ckpt" }, id = "CASE K=1 (2.5, epoch 4)" ),
428+ pytest .param (2 , False , [f"epoch={ i } .ckpt" for i in (2 , 4 )], id = "CASE K=2 (2.5 epoch 4, 2.8 epoch 2)" ),
429+ pytest .param (4 , False , [f"epoch={ i } .ckpt" for i in range (1 , 5 )], id = "CASE K=4 (save all 4 base)" ),
430+ pytest .param (3 , False , [f"epoch={ i } .ckpt" for i in range (2 , 5 )], id = "CASE K=3 (save the 2nd, 3rd, 4th model)" ),
431+ pytest .param (1 , True , {"epoch=4.ckpt" , "last.ckpt" }, id = "CASE K=1 (save the 4th model and the last model)" ),
449432 ],
450433)
451- def test_model_checkpoint_options (tmpdir , save_top_k , save_last , file_prefix , expected_files ):
434+ def test_model_checkpoint_options (tmpdir , save_top_k , save_last , expected_files ):
452435 """Test ModelCheckpoint options."""
453436
454437 def mock_save_function (filepath , * args ):
@@ -463,7 +446,6 @@ def mock_save_function(filepath, *args):
463446 monitor = 'checkpoint_on' ,
464447 save_top_k = save_top_k ,
465448 save_last = save_last ,
466- prefix = file_prefix ,
467449 verbose = 1
468450 )
469451 checkpoint_callback .save_function = mock_save_function
0 commit comments