Commit ef03c39

Borda, awaelchli, rohitgr7, and carmocca authored
Add step index in checkpoint name (#3807)
* true final value of global step
* ch check
* tests
* save each validation interval
* wip
* add test
* add test
* wip
* fix tests, revert old edits, fix merge conflicts, update doctests
* test + bugfix
* sort files
* format test
* suggestion by ananth
* added changelog
* naming
* docs
* example
* suggestion

Co-authored-by: Carlos Mocholí <[email protected]>

* fix test
* pep
* pep

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent f40d086 commit ef03c39

File tree

8 files changed: +117 -67 lines changed


CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added multiclass AUROC metric ([#4236](https://github.com/PyTorchLightning/pytorch-lightning/pull/4236))

+- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))
+
 ### Changed

 - W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405))
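In practice, the new default template means mid-epoch saves no longer collide on the same `epoch=N` filename. A minimal usage sketch follows; the `dirpath`, monitored metric, and `val_check_interval` values here are illustrative, not taken from this commit:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# With filename=None, the default template is now '{epoch}-{step}', so two
# saves inside the same epoch produce distinct files, e.g.
# 'epoch=0-step=499.ckpt' and 'epoch=0-step=999.ckpt'.
checkpoint_cb = ModelCheckpoint(
    dirpath="checkpoints/",  # illustrative path
    monitor="val_loss",
    save_top_k=3,
)
trainer = Trainer(callbacks=[checkpoint_cb], val_check_interval=0.5)
```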

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 35 additions & 30 deletions

@@ -101,7 +101,7 @@ class ModelCheckpoint(Callback):
         ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
         ... )

-    By default, filename is ``None`` and will be set to ``'{epoch}'``.
+    By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``.


     Example::
@@ -222,16 +222,16 @@ def save_checkpoint(self, trainer, pl_module):
         monitor_candidates = self._monitor_candidates(trainer)

         # ie: path/val_loss=0.5.ckpt
-        filepath = self._get_metric_interpolated_filepath_name(epoch, monitor_candidates)
+        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, epoch, global_step)

         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
         # Mode 1: save all checkpoints OR only the top k
         if self.save_top_k:
-            self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, epoch, filepath)
+            self._save_top_k_checkpoints(monitor_candidates, trainer, pl_module, filepath)

         # Mode 2: save the last checkpoint
-        self._save_last_checkpoint(trainer, pl_module, epoch, monitor_candidates, filepath)
+        self._save_last_checkpoint(trainer, pl_module, monitor_candidates, filepath)

     def __validate_init_configuration(self):
         if self.save_top_k is not None and self.save_top_k < -1:
@@ -360,16 +360,17 @@ def _format_checkpoint_name(
         cls,
         filename: Optional[str],
         epoch: int,
+        step: int,
         metrics: Dict[str, Any],
         prefix: str = "",
     ) -> str:
         if not filename:
             # filename is not set, use default name
-            filename = "{epoch}"
+            filename = "{epoch}-{step}"
         # check and parse user passed keys in the string
         groups = re.findall(r"(\{.*?)[:\}]", filename)
         if len(groups) >= 0:
-            metrics["epoch"] = epoch
+            metrics.update({"epoch": epoch, 'step': step})
         for group in groups:
             name = group[1:]
             filename = filename.replace(group, name + "={" + name)
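For readers tracing the templating above, here is a standalone sketch of the same interpolation, a simplified re-implementation for illustration rather than the library function; the fallback of unknown keys to 0 mirrors the `'missing=0.ckpt'` doctest in the next hunk:

```python
import re

def format_name(template: str, **metrics) -> str:
    # Find the opening of each template group, e.g. ['{epoch', '{step']
    groups = re.findall(r"(\{.*?)[:\}]", template)
    for group in groups:
        name = group[1:]
        # Rewrite '{epoch}' as 'epoch={epoch}' so the value is labeled
        template = template.replace(group, name + "={" + name)
        if name not in metrics:
            metrics[name] = 0  # unknown keys fall back to 0
    return template.format(**metrics)

print(format_name("{epoch}-{step}", epoch=3, step=1250))  # epoch=3-step=1250
print(format_name("{epoch:03d}", epoch=5))                # epoch=005
```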
@@ -379,32 +380,32 @@ def _format_checkpoint_name(
         return cls.CHECKPOINT_JOIN_CHAR.join([txt for txt in (prefix, filename) if txt])

     def format_checkpoint_name(
-        self, epoch: int, metrics: Dict[str, Any], ver: Optional[int] = None
+        self, epoch: int, step: int, metrics: Dict[str, Any], ver: Optional[int] = None
     ) -> str:
         """Generate a filename according to the defined template.

         Example::

             >>> tmpdir = os.path.dirname(__file__)
             >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, 1, metrics={}))
             'epoch=0.ckpt'
             >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
+            >>> os.path.basename(ckpt.format_checkpoint_name(5, 2, metrics={}))
             'epoch=005.ckpt'
             >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
+            >>> os.path.basename(ckpt.format_checkpoint_name(2, 3, metrics=dict(val_loss=0.123456)))
             'epoch=2-val_loss=0.12.ckpt'
             >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, 4, metrics={}))
             'missing=0.ckpt'
-            >>> ckpt = ModelCheckpoint(filename='{epoch}')
-            >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
-            'epoch=0.ckpt'
+            >>> ckpt = ModelCheckpoint(filename='{step}')
+            >>> os.path.basename(ckpt.format_checkpoint_name(0, 0, {}))
+            'step=0.ckpt'

         """
         filename = self._format_checkpoint_name(
-            self.filename, epoch, metrics, prefix=self.prefix
+            self.filename, epoch, step, metrics, prefix=self.prefix
         )
         if ver is not None:
             filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
@@ -479,13 +480,11 @@ def _validate_monitor_key(self, trainer):
             )
             raise MisconfigurationException(m)

-    def _get_metric_interpolated_filepath_name(self, epoch, ckpt_name_metrics):
-        filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
+    def _get_metric_interpolated_filepath_name(self, ckpt_name_metrics: Dict[str, Any], epoch: int, step: int):
+        filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics)
         version_cnt = 0
         while self._fs.exists(filepath):
-            filepath = self.format_checkpoint_name(
-                epoch, ckpt_name_metrics, ver=version_cnt
-            )
+            filepath = self.format_checkpoint_name(epoch, step, ckpt_name_metrics, ver=version_cnt)
             # this epoch called before
             version_cnt += 1
         return filepath
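The `while` loop above resolves on-disk collisions by appending a version suffix rather than overwriting. A sketch of the behavior, written against an assumed `name_fn` helper (hypothetical, for illustration only):

```python
import os

def next_free_path(name_fn, exists=os.path.exists):
    # First try the plain name, e.g. 'epoch=0-step=100.ckpt'; on each
    # collision retry with a version suffix: '...-v0.ckpt', '...-v1.ckpt', ...
    filepath = name_fn(ver=None)
    version_cnt = 0
    while exists(filepath):
        filepath = name_fn(ver=version_cnt)
        version_cnt += 1
    return filepath
```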
@@ -494,9 +493,10 @@ def _monitor_candidates(self, trainer):
494493
ckpt_name_metrics = deepcopy(trainer.logger_connector.logged_metrics)
495494
ckpt_name_metrics.update(trainer.logger_connector.callback_metrics)
496495
ckpt_name_metrics.update(trainer.logger_connector.progress_bar_metrics)
496+
ckpt_name_metrics.update({"step": trainer.global_step, "epoch": trainer.current_epoch})
497497
return ckpt_name_metrics
498498

499-
def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
499+
def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath):
500500
should_save_last = self.monitor is None or self.save_last
501501
if not should_save_last:
502502
return
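Because the added `update` runs last and `dict.update` lets later sources win, the trainer's counters take precedence over any user-logged metric that happens to be called `step` or `epoch`. A toy illustration with made-up values:

```python
candidates = {}
candidates.update({"val_loss": 0.42, "step": 7.0})  # e.g. logged metrics
candidates.update({"step": 1250, "epoch": 3})       # trainer counters last
print(candidates)  # {'val_loss': 0.42, 'step': 1250, 'epoch': 3}
```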
@@ -506,7 +506,11 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
         # when user ALSO asked for the 'last.ckpt' change the name
         if self.save_last:
             last_filepath = self._format_checkpoint_name(
-                self.CHECKPOINT_NAME_LAST, epoch, ckpt_name_metrics, prefix=self.prefix
+                self.CHECKPOINT_NAME_LAST,
+                trainer.current_epoch,
+                trainer.global_step,
+                ckpt_name_metrics,
+                prefix=self.prefix
             )
             last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")
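Since the `epoch` parameter was dropped from `_save_last_checkpoint`, the counters are now read straight off the trainer. Assuming `CHECKPOINT_NAME_LAST` is still the literal `'last'` (the class default, not shown in this diff), that template contains no `{...}` groups, so the extra arguments go unused and the file keeps its fixed name:

```python
from pytorch_lightning.callbacks import ModelCheckpoint

# Sketch, assuming CHECKPOINT_NAME_LAST == 'last': nothing to interpolate,
# so the formatted name is always 'last' regardless of epoch/step.
name = ModelCheckpoint._format_checkpoint_name('last', epoch=4, step=1999, metrics={})
print(f"{name}.ckpt")  # last.ckpt
```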

@@ -523,17 +527,19 @@ def _save_last_checkpoint(self, trainer, pl_module, epoch, ckpt_name_metrics, filepath):
         if self.monitor is None:
             self.best_model_path = self.last_model_path

-    def _save_top_k_checkpoints(self, metrics, trainer, pl_module, epoch, filepath):
+    def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath):
         current = metrics.get(self.monitor)
+        epoch = metrics.get("epoch")
+        step = metrics.get("step")

         if not isinstance(current, torch.Tensor) and current is not None:
             current = torch.tensor(current, device=pl_module.device)

         if self.check_monitor_top_k(current):
-            self._update_best_and_save(filepath, current, epoch, trainer, pl_module)
+            self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module)
         elif self.verbose:
             rank_zero_info(
-                f"Epoch {epoch:d}: {self.monitor} was not in top {self.save_top_k}"
+                f"Epoch {epoch:d}, step {step:d}: {self.monitor} was not in top {self.save_top_k}"
             )

     def _is_valid_monitor_key(self, metrics):
@@ -544,11 +550,11 @@ def _update_best_and_save(
         filepath: str,
         current: torch.Tensor,
         epoch: int,
+        step: int,
         trainer,
         pl_module,
     ):
-
-        k = epoch + 1 if self.save_top_k == -1 else self.save_top_k
+        k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

         del_list = []
         if len(self.best_k_models) == k and k > 0:
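The change to `k` fixes bookkeeping for `save_top_k == -1` (keep everything): the old `epoch + 1` bound assumed at most one checkpoint per epoch, so a second sub-epoch save inside epoch 0 would hit the eviction branch below. Counting the tracked models instead always leaves room for one more. A worked example with assumed numbers:

```python
# save_top_k == -1; second validation checkpoint inside epoch 0 (hypothetical)
best_k_models = {"epoch=0-step=100.ckpt": 0.50}

old_k = 0 + 1                    # old formula: k = epoch + 1 = 1
new_k = len(best_k_models) + 1   # new formula: k = 2

print(len(best_k_models) == old_k)  # True  -> old code would evict a model
print(len(best_k_models) == new_k)  # False -> new code keeps them all
```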
@@ -575,9 +581,8 @@

         if self.verbose:
             rank_zero_info(
-                f"Epoch {epoch:d}: {self.monitor} reached"
-                f" {current:0.5f} (best {self.best_model_score:0.5f}),"
-                f" saving model to {filepath} as top {k}"
+                f"Epoch {epoch:d}, global step {step:d}: {self.monitor} reached {current:0.5f}"
+                f' (best {self.best_model_score:0.5f}), saving model to "{filepath}" as top {k}'
             )
         self._save_model(filepath, trainer, pl_module)

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 4 additions & 3 deletions

@@ -250,9 +250,10 @@ def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
             # depre warning
             if eval_results is not None and user_reduced:
                 step = 'testing_epoch_end' if self.testing else 'validation_epoch_end'
-                m = f'The {step} should not return anything as of 9.1.' \
-                    f'to log, use self.log(...) or self.write(...) directly in the LightningModule'
-                self.warning_cache.warn(m)
+                self.warning_cache.warn(
+                    f'The {step} should not return anything as of 9.1.'
+                    ' To log, use self.log(...) or self.write(...) directly in the LightningModule'
+                )

             if using_eval_result and not user_reduced:
                 eval_results = self.__auto_reduce_result_objs(outputs)
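Besides restructuring, this hunk fixes a wording bug: adjacent string literals concatenate with no separator, so the old message rendered as "...as of 9.1.to log, ...". A quick before/after demonstration:

```python
step = 'validation_epoch_end'

old = (
    f'The {step} should not return anything as of 9.1.'
    f'to log, use self.log(...) or self.write(...) directly in the LightningModule'
)
new = (
    f'The {step} should not return anything as of 9.1.'
    ' To log, use self.log(...) or self.write(...) directly in the LightningModule'
)
print(old)  # '...as of 9.1.to log, ...'  <- missing space, lowercase start
print(new)  # '...as of 9.1. To log, ...'
```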
