
Commit e9c61cc

Borda, awaelchli, rohitgr7, and carmocca authored and committed
Add step index in checkpoint name (#3807)
* true final value of global step
* ch check
* tests
* save each validation interval
* wip
* add test
* add test
* wip
* fix tests, revert old edits, fix merge conflicts, update doctests
* test + bugfix
* sort files
* format test
* suggestion by ananth
* added changelog
* naming
* docs
* example
* suggestion

Co-authored-by: Carlos Mocholí <[email protected]>

* fix test
* pep
* pep

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>

(cherry picked from commit ef03c39)
1 parent 48ed664 commit e9c61cc
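In user terms: with the default ``filename=None``, checkpoints are now named like ``epoch=0-step=24.ckpt`` instead of ``epoch=0.ckpt``, so saving more than once per epoch no longer produces colliding names. A minimal sketch of typical usage after this change (the ``dirpath``, ``monitor``, and ``val_check_interval`` values here are illustrative assumptions, not part of the commit):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Default filename template is now '{epoch}-{step}', so sub-epoch saves get
# unique names, e.g. 'epoch=0-step=24.ckpt', 'epoch=0-step=49.ckpt', ...
checkpoint_cb = ModelCheckpoint(dirpath="checkpoints/", monitor="val_loss", save_top_k=-1)

# val_check_interval=0.25 runs validation (and thus checkpointing) 4x per epoch.
trainer = Trainer(callbacks=[checkpoint_cb], val_check_interval=0.25)
```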

8 files changed: +117 −67 lines changed

CHANGELOG.md
Lines changed: 2 additions & 0 deletions

```diff
@@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236))
 
+- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))
+
 ### Changed
 
 - W&B log in sync with `Trainer` step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405))
```

pytorch_lightning/callbacks/model_checkpoint.py
Lines changed: 35 additions & 30 deletions
```diff
@@ -101,7 +101,7 @@ class ModelCheckpoint(Callback):
         ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
         ... )
 
-    By default, filename is ``None`` and will be set to ``'{epoch}'``.
+    By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``.
 
 
     Example::
```
```diff
@@ -223,16 +223,16 @@ def save_checkpoint(self, trainer, pl_module):
         monitor_candidates = self._monitor_candidates(trainer)
 
         # ie: path/val_loss=0.5.ckpt
-        filepath = self._get_metric_interpolated_filepath_name(epoch, monitor_candidates)
+        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, global_step)
 
         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
         # Mode 1: save all checkpoints OR only the top k
         if self.save_top_k:
-            self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath)
+            self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath)
 
         # Mode 2: save the last checkpoint
-        self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)
+        self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath)
 
     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
```
```diff
@@ -361,16 +361,17 @@ def _format_checkpoint_name(
         cls,
         filename: Optional[str],
         epoch: int,
+        step: int,
         metrics: Dict[str, Any],
         prefix: str = "",
     ) -> str:
         if not filename:
             # filename is not set, use default name
-            filename = "{epoch}"
+            filename = "{epoch}-{step}"
         # check and parse user passed keys in the string
         groups = re.findall(r"(\{.*?)[:\}]", filename)
         if len(groups) >= 0:
-            metrics["epoch"] = epoch
+            metrics.update({"epoch": epoch, 'step': step})
             for group in groups:
                 name = group[1:]
                 filename = filename.replace(group, name + "={" + name)
```
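To make the interpolation above concrete, here is a standalone sketch of what the template expansion does once ``step`` is in play. ``interpolate`` is a hypothetical name used only for illustration (the real logic lives in ``_format_checkpoint_name``), and the default-to-0 handling of missing keys is inferred from the ``'missing=0.ckpt'`` doctest in the next hunk:

```python
import re
from typing import Any, Dict, Optional

def interpolate(filename: Optional[str], epoch: int, step: int, metrics: Dict[str, Any]) -> str:
    # Hypothetical re-implementation for illustration only.
    if not filename:
        filename = "{epoch}-{step}"  # the new default template
    groups = re.findall(r"(\{.*?)[:\}]", filename)
    metrics.update({"epoch": epoch, "step": step})
    for group in groups:
        name = group[1:]
        # turn '{epoch}' into 'epoch={epoch}' so .format() emits 'epoch=0'
        filename = filename.replace(group, name + "={" + name)
        if name not in metrics:
            metrics[name] = 0  # assumed fallback, per the 'missing=0.ckpt' doctest
    return filename.format(**metrics)

print(interpolate(None, epoch=3, step=250, metrics={}))                      # epoch=3-step=250
print(interpolate("{step}-{val_loss:.2f}", 3, 250, {"val_loss": 0.12345}))   # step=250-val_loss=0.12
```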
```diff
@@ -380,32 +381,32 @@ def _format_checkpoint_name(
         return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt])
 
     def format_checkpoint_name(
-        self, epoch: int, metrics: Dict[str, Any], ver: Optional[int] = None
+        self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None
     ) -> str:
         """Generate a filename according to the defined template.
 
         Example::
 
             >>> tmpdir = os.path.dirname(__file__)
             >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={}))
             'epoch=0.ckpt'
             >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
+            >>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={}))
             'epoch=005.ckpt'
             >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
+            >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
             'epoch=2-val_loss=0.12.ckpt'
             >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
             'missing=0.ckpt'
-            >>> ckpt = ModelCheckpoint(filename='{epoch}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
-            'epoch=0.ckpt'
+            >>> ckpt = ModelCheckpoint(filename='{step}')
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {}))
+            'step=0.ckpt'
 
         """
         filename = self._format_checkpoint_name(
-            self.filename, epoch, metrics, prefix=self.prefix
+            self.filename, epoch, step, metrics, prefix=self.prefix
         )
         if ver is not None:
             filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
```
```diff
@@ -480,13 +481,11 @@ def _validate_monitor_key(self, trainer):
             )
             raise MisconfigurationException(m)
 
-    def _get_metric_interpolated_filepath_name(self, epoch, ckpt_name_metrics):
-        filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
+    def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
+        filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
         version_cnt = 0
         while self._fs.exists(filepath):
-            filepath = self.format_checkpoint_name(
-                epoch, ckpt_name_metrics, ver=version_cnt
-            )
+            filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
             # this epoch called before
             version_cnt += 1
         return filepath
```
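The version counter is what guards against sub-epoch collisions when the template itself cannot disambiguate (e.g. a user-supplied ``filename='{epoch}'``). A hypothetical sequence of names, assuming the default ``CHECKPOINT_JOIN_CHAR`` of ``-``:

```python
# With filename='{epoch}' and two saves inside epoch 0:
#
#   1st save: 'epoch=0.ckpt'                     # no collision on disk
#   2nd save: 'epoch=0.ckpt' already exists,     # while-loop kicks in
#             so ver=0 is appended               # -> 'epoch=0-v0.ckpt'
```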
```diff
@@ -495,9 +494,10 @@ def _monitor_candidates(self, trainer):
         ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
         ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
         ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
+        ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
         return ckpt_name_metrics
 
-    def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
+    def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath):
         should_save_last = self.monitor is None or self.save_last
         if not should_save_last:
             return
```
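Note the ordering above: because the ``step``/``epoch`` update is applied last, the trainer's values override any user-logged metric with the same name, which keeps the ``{step}`` filename token trustworthy. A small illustration with made-up values:

```python
candidates = {}
candidates.update({"val_loss": 0.5, "step": 999})  # logged_metrics (user could log a 'step')
candidates.update({"val_acc": 0.8})                # callback_metrics
candidates.update({})                              # progress_bar_metrics
candidates.update({"step": 100, "epoch": 0})       # trainer truth, applied last -> wins
assert candidates["step"] == 100
```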
```diff
@@ -507,7 +507,11 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
         # when user ALSO asked for the 'last.ckpt' change the name
         if self.save_last:
             last_filepath = self._format_checkpoint_name(
-                self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
+                self.CHECKPOINT_NAME_LAST,
+                trainer.current_epoch,
+                trainer.global_step,
+                ckpt_name_metrics,
+                prefix=self.prefix
             )
             last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")
```
```diff
@@ -524,17 +528,19 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
         if self.monitor is None:
             self.best_model_path = self.last_model_path
 
-    def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
+    def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
         current = metrics.get(self.monitor)
+        epoch = metrics.get("epoch")
+        step = metrics.get("step")
 
         if not isinstance(current, torch.Tensor) and current is not None:
             current = torch.tensor(current, device=pl_module.device)
 
         if self.check_monitor_top_k(current):
-            self._update_best_and_save(filepath, current, epoch, trainer, pl_module)
+            self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module)
         elif self.verbose:
             rank_zero_info(
-                f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}"
+                f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}"
             )
 
     def _is_valid_monitor_key(self, metrics):
```
```diff
@@ -545,11 +551,11 @@ def _update_best_and_save(
         filepath: str,
         current: torch.Tensor,
         epoch: int,
+        step: int,
         trainer,
         pl_module,
     ):
-
-        k = epoch + 1 if self.save_top_k == -1 else self.save_top_k
+        k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
 
         del_list = []
         if len(self.best_k_models) == k and k > 0:
```
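The ``save_top_k == -1`` change above fixes a real interaction with sub-epoch checkpointing: the old ``k = epoch + 1`` assumed at most one checkpoint per epoch, so a second save within the same epoch could evict an earlier one. Growing ``k`` past the number of tracked models keeps everything. A standalone sketch of the difference, with made-up bookkeeping:

```python
best_k_models = {"epoch=0-step=24.ckpt": 0.5}  # one model already kept during epoch 0

epoch, save_top_k = 0, -1

# old logic: k caps the kept set at epoch + 1 = 1, so saving a second
# checkpoint in epoch 0 would first delete 'epoch=0-step=24.ckpt'
k_old = epoch + 1 if save_top_k == -1 else save_top_k                # 1 -> eviction

# new logic: k always exceeds the current set size, so nothing is evicted
k_new = len(best_k_models) + 1 if save_top_k == -1 else save_top_k   # 2 -> keep all
```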
```diff
@@ -576,9 +582,8 @@ def _update_best_and_save(
 
         if self.verbose:
             rank_zero_info(
-                f"Epoch {epoch:d}: {self.monitor} reached"
-                f" {current:0.5f} (best {self.best_model_score:0.5f}),"
-                f" saving model to {filepath} as top {k}"
+                f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
+                f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
             )
         self._save_model(filepath, trainer, pl_module)
```
pytorch_lightning/trainer/evaluation_loop.py
Lines changed: 4 additions & 3 deletions

```diff
@@ -250,9 +250,10 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
         # depre warning
         if eval_results is not None and user_reduced:
             step = 'testing_epoch_end' if self.testing else 'validation_epoch_end'
-            m = f'The {step} should not return anything as of 9.1.' \
-                f'to log, use self.log(...) or self.write(...) directly in the LightningModule'
-            self.warning_cache.warn(m)
+            self.warning_cache.warn(
+                f'The {step} should not return anything as of 9.1.'
+                ' To log, use self.log(...) or self.write(...) directly in the LightningModule'
+            )
 
         if using_eval_result and not user_reduced:
             eval_results = self.__auto_reduce_result_objs(outputs)
```
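This rewrite also appears to fix a spacing bug: with the old backslash continuation, the two adjacent f-string literals were joined with no separator, producing "…as of 9.1.to log…". Adjacent string literals in Python concatenate exactly as written, so the space has to live inside one of them. A small demonstration (variable names here are illustrative):

```python
version, step = "9.1", "validation_epoch_end"

old = f'The {step} should not return anything as of {version}.' \
      f'to log, use self.log(...)'
new = (f'The {step} should not return anything as of {version}.'
       ' To log, use self.log(...)')

print(old)  # ...as of 9.1.to log, ...   (words run together)
print(new)  # ...as of 9.1. To log, ...
```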
