Skip to content

Commit b00991e

Browse files
Added changeable extension variable for model checkpoints (#4977)
* Added changeable extension variable for model checkpoints * Removed whitespace * Removed the last bit of whitespace * Wrote tests for FILE_EXTENSION * Fixed formatting issues * More formatting issues * Simplify test by just using defaults * Formatting to PEP8 * Added dummy class that inherits ModelCheckpoint; run only one batch instead of epoch for integration test * Fixed too much whitespace formatting * some changes Co-authored-by: rohitgr7 <[email protected]>
1 parent 2e838e6 commit b00991e

File tree

2 files changed

+27
-3
lines changed

2 files changed

+27
-3
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ class ModelCheckpoint(Callback):
9090
Example::
9191
9292
# custom path
93-
# saves a file like: my/path/epoch=0.ckpt
93+
# saves a file like: my/path/epoch=0-step=10.ckpt
9494
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
9595
9696
By default, dirpath is ``None`` and will be set at runtime to the location
@@ -140,6 +140,7 @@ class ModelCheckpoint(Callback):
140140

141141
CHECKPOINT_JOIN_CHAR = "-"
142142
CHECKPOINT_NAME_LAST = "last"
143+
FILE_EXTENSION = ".ckpt"
143144

144145
def __init__(
145146
self,
@@ -442,7 +443,7 @@ def format_checkpoint_name(
442443
)
443444
if ver is not None:
444445
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
445-
ckpt_name = f"{filename}.ckpt"
446+
ckpt_name = f"{filename}{self.FILE_EXTENSION}"
446447
return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name
447448

448449
def __resolve_ckpt_dir(self, trainer, pl_module):
@@ -545,7 +546,7 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath)
545546
ckpt_name_metrics,
546547
prefix=self.prefix
547548
)
548-
last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")
549+
last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")
549550

550551
self._save_model(last_filepath, trainer, pl_module)
551552
if (

tests/checkpointing/test_model_checkpoint.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,29 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
261261
assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt'
262262

263263

264+
class ModelCheckpointExtensionTest(ModelCheckpoint):
265+
FILE_EXTENSION = '.tpkc'
266+
267+
268+
def test_model_checkpoint_file_extension(tmpdir):
269+
"""
270+
Test ModelCheckpoint with different file extension.
271+
"""
272+
273+
model = LogInTwoMethods()
274+
model_checkpoint = ModelCheckpointExtensionTest(monitor='early_stop_on', dirpath=tmpdir, save_top_k=1, save_last=True)
275+
trainer = Trainer(
276+
default_root_dir=tmpdir,
277+
callbacks=[model_checkpoint],
278+
max_steps=1,
279+
logger=False,
280+
)
281+
trainer.fit(model)
282+
283+
expected = ['epoch=0-step=0.tpkc', 'last.tpkc']
284+
assert set(expected) == set(os.listdir(tmpdir))
285+
286+
264287
def test_model_checkpoint_save_last(tmpdir):
265288
"""Tests that save_last produces only one last checkpoint."""
266289
seed_everything()

0 commit comments

Comments
 (0)