10 changes: 4 additions & 6 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -368,12 +368,8 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None:
         """
         self._validate_monitor_key(trainer)
 
-        # track epoch when ckpt was last checked
-        global_step = trainer.global_step
-        self._last_global_step_saved = global_step
-
         # what can be monitored
-        monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=global_step)
+        monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=trainer.global_step)
 
         # callback supports multiple simultaneous modes
         # here we call each mode sequentially
@@ -638,6 +634,7 @@ def _monitor_candidates(self, trainer: "pl.Trainer", epoch: int, step: int) -> D
     def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
         if not self.save_last:
             return
+        self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
 
         filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)
         # set the last model path before saving because it will be part of the state.
@@ -649,9 +646,9 @@ def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
     def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
         if self.monitor is None or self.save_top_k == 0:
             return
+        self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
 
         current = monitor_candidates.get(self.monitor)
-
         if self.check_monitor_top_k(trainer, current):
             self._update_best_and_save(current, trainer, monitor_candidates)
         elif self.verbose:
@@ -662,6 +659,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict
     def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
         if self.monitor is not None or self.save_top_k == 0:
             return
+        self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
 
         filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
         # set the best model path before saving because it will be part of the state.
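Taken together, the changes above move the `_last_global_step_saved` update out of `save_checkpoint` and into each of the three save paths, so the tracker only advances when a checkpoint is actually written. The sketch below is a simplified illustration of that control flow, not the real Lightning implementation; the `Trainer` and `candidates` handling is reduced to the minimum needed to show the early returns.

# Simplified sketch of the post-change behavior (assumption: the real methods
# also format checkpoint file paths and call trainer.save_checkpoint).
class CheckpointSketch:
    def __init__(self, save_top_k=0, save_last=False, monitor=None):
        self.save_top_k = save_top_k
        self.save_last = save_last
        self.monitor = monitor
        self._last_global_step_saved = -1  # nothing saved yet

    def save_checkpoint(self, trainer):
        candidates = {"step": trainer.global_step}
        # each path updates the tracker only if it does not return early
        self._save_top_k_checkpoint(trainer, candidates)
        self._save_none_monitor_checkpoint(trainer, candidates)
        self._save_last_checkpoint(trainer, candidates)

    def _save_top_k_checkpoint(self, trainer, candidates):
        if self.monitor is None or self.save_top_k == 0:
            return
        self._last_global_step_saved = candidates.get("step", trainer.global_step)
        # ... rank the monitored value and save/replace the top-k files ...

    def _save_none_monitor_checkpoint(self, trainer, candidates):
        if self.monitor is not None or self.save_top_k == 0:
            return
        self._last_global_step_saved = candidates.get("step", trainer.global_step)
        # ... save a checkpoint without a monitored metric ...

    def _save_last_checkpoint(self, trainer, candidates):
        if not self.save_last:
            return
        self._last_global_step_saved = candidates.get("step", trainer.global_step)
        # ... save/overwrite the "last" checkpoint ...

With `save_top_k=0`, `save_last=False` and a monitor set, every path returns early and `_last_global_step_saved` stays at its default of -1, which is what the new `test_last_global_step_saved` below asserts.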
9 changes: 9 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
@@ -1268,3 +1268,12 @@ def test_none_monitor_saves_correct_best_model_path(tmpdir):
     full_path = str(tmpdir / expected)
     ckpt = torch.load(full_path)
     assert ckpt["callbacks"][mc.state_key]["best_model_path"] == full_path
+
+
+def test_last_global_step_saved():
+    # this should not save anything
+    model_checkpoint = ModelCheckpoint(save_top_k=0, save_last=False, monitor="foo")
+    trainer = Mock()
+    trainer.callback_metrics = {"foo": 123}
+    model_checkpoint.save_checkpoint(trainer)
+    assert model_checkpoint._last_global_step_saved == -1
4 changes: 2 additions & 2 deletions tests/trainer/test_trainer.py
@@ -331,11 +331,11 @@ def mock_save_function(filepath, *args):
     trainer.save_checkpoint = mock_save_function
 
     # emulate callback's calls during the training
-    for i, loss in enumerate(losses):
-        trainer.fit_loop.epoch_progress.current.completed = i  # sets `trainer.current_epoch`
+    for i, loss in enumerate(losses, 1):
         trainer.fit_loop.global_step = i
         trainer.callback_metrics.update({"checkpoint_on": torch.tensor(loss)})
         checkpoint_callback.on_validation_end(trainer, trainer.lightning_module)
+        trainer.fit_loop.epoch_progress.current.completed = i  # sets `trainer.current_epoch`
 
     file_lists = set(os.listdir(tmpdir))
 