6 changes: 5 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -285,8 +285,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed an attribute error when running the tuner together with the `StochasticWeightAveraging` callback ([#14836](https://github.com/Lightning-AI/lightning/pull/14836))


- Fixed MissingFieldException in offline mode for the `NeptuneLogger()` ([#14919](https://github.com/Lightning-AI/lightning/pull/14919))


- Fixed wandb `save_dir` being overridden by a `None` `dir` when using the CLI ([#14878](https://github.com/Lightning-AI/lightning/pull/14878))


@@ -296,14 +298,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed torchscript error with containers of LightningModules ([#14904](https://github.com/Lightning-AI/lightning/pull/14904))


- Fixed reloading of the last checkpoint on run restart ([#14907](https://github.com/Lightning-AI/lightning/pull/14907))

- `SaveConfigCallback` instances should only save the config once to allow having the `overwrite=False` safeguard when using `LightningCLI(..., run=False)` ([#14927](https://github.com/Lightning-AI/lightning/pull/14927))


- Fixed an issue with terminating the trainer profiler when a `StopIteration` exception is raised while using an `IterableDataset` ([#14945](https://github.com/Lightning-AI/lightning/pull/14945))



## [1.7.7] - 2022-09-22

### Fixed
27 changes: 19 additions & 8 deletions src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -25,7 +25,7 @@
import warnings
from copy import deepcopy
from datetime import timedelta
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Set
from weakref import proxy

import numpy as np
@@ -238,6 +238,7 @@ def __init__(
self.last_model_path = ""

self.kth_value: Tensor
self.dirpath: Optional[_PATH]
self.__init_monitor_mode(mode)
self.__init_ckpt_dir(dirpath, filename)
self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
@@ -255,8 +256,9 @@ def state_key(self) -> str:
)

def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
self.__resolve_ckpt_dir(trainer)
assert self.dirpath is not None
dirpath = self.__resolve_ckpt_dir(trainer)
dirpath = trainer.strategy.broadcast(dirpath)
self.dirpath = dirpath
if trainer.is_global_zero and stage == "fit":
self.__warn_if_dir_not_empty(self.dirpath)
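Note on this hunk: the cross-rank broadcast moves out of `__resolve_ckpt_dir` (see the hunks below) and into `setup`, so the directory is resolved, agreed on across ranks, and committed to `self.dirpath` exactly once. A minimal sketch of the pattern; `agree_on_dirpath` and `resolve_fn` are hypothetical names, not PR code:

```python
# Hypothetical illustration of the agreement pattern used in `setup` above.
# Ranks may resolve different paths (e.g. a timestamped logger version), so
# every rank adopts rank 0's result via the strategy's broadcast collective.
def agree_on_dirpath(trainer, resolve_fn):
    dirpath = resolve_fn(trainer)               # possibly rank-dependent
    return trainer.strategy.broadcast(dirpath)  # all ranks take rank 0's value
```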

@@ -571,7 +573,7 @@ def format_checkpoint_name(
ckpt_name = f"{filename}{self.FILE_EXTENSION}"
return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name

def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> _PATH:
"""Determines model checkpoint save directory at runtime. Reference attributes from the trainer's logger to
determine where to save checkpoints. The path for saving weights is set in this priority:

@@ -583,7 +585,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
"""
if self.dirpath is not None:
# short circuit if dirpath was passed to ModelCheckpoint
return
return self.dirpath

if len(trainer.loggers) > 0:
if trainer.loggers[0].save_dir is not None:
@@ -600,9 +602,18 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
# if no loggers, use default_root_dir
ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")

ckpt_path = trainer.strategy.broadcast(ckpt_path)

self.dirpath = ckpt_path
return ckpt_path

def _find_last_checkpoints(self, trainer: "pl.Trainer") -> Set[str]:
# find all checkpoints in the folder
ckpt_path = self.__resolve_ckpt_dir(trainer)
if self._fs.exists(ckpt_path):
return {
os.path.normpath(p)
for p in self._fs.ls(ckpt_path, detail=False)
if self.CHECKPOINT_NAME_LAST in os.path.split(p)[1]
}
return set()

def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
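The new `_find_last_checkpoints` reuses the resolved directory and matches on the basename only, so `last.ckpt`, `last-v1.ckpt`, and similar are all collected. A standalone sketch of the same filtering over a local directory; this is a hypothetical helper that assumes the default `CHECKPOINT_NAME_LAST == "last"` and uses plain `os` in place of the callback's fsspec filesystem:

```python
import os
from typing import Set

def find_last_checkpoints(ckpt_dir: str, last_name: str = "last") -> Set[str]:
    # Mirror of the callback's filter: keep files whose basename contains
    # the "last" marker, e.g. "last.ckpt" or "last-v1.ckpt".
    if not os.path.isdir(ckpt_dir):
        return set()
    return {
        os.path.normpath(os.path.join(ckpt_dir, name))
        for name in os.listdir(ckpt_dir)
        if last_name in name
    }
```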
src/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -26,6 +26,7 @@
from lightning_lite.plugins.environments.slurm_environment import SLURMEnvironment
from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.types import _PATH
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins.precision import ApexMixedPrecisionPlugin, NativeMixedPrecisionPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE
@@ -158,9 +159,10 @@ def _set_ckpt_path(
ckpt_path = getattr(self.trainer.checkpoint_callback, "best_model_path", None)

elif ckpt_path == "last":
candidates = [getattr(ft, "ckpt_path", None) for ft in ft_checkpoints] + [
getattr(cb, "last_model_path", None) for cb in self.trainer.checkpoint_callbacks
]
candidates = {getattr(ft, "ckpt_path", None) for ft in ft_checkpoints}
for callback in self.trainer.checkpoint_callbacks:
if isinstance(callback, ModelCheckpoint):
candidates |= callback._find_last_checkpoints(self.trainer)
candidates_fs = {path: get_filesystem(path) for path in candidates if path}
candidates_ts = {path: fs.modified(path) for path, fs in candidates_fs.items() if fs.exists(path)}
if not candidates_ts:
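For context, the lines collapsed just below this hunk pick the newest of the surviving candidates by modification time. A hedged sketch of that selection step, simplified to a local filesystem; not the connector's exact code:

```python
import os
from typing import Optional, Set

def pick_latest(candidates: Set[Optional[str]]) -> Optional[str]:
    # Keep candidates that exist, then take the most recently modified one;
    # the real connector handles the empty case with a warning/fallback.
    timestamps = {p: os.path.getmtime(p) for p in candidates if p and os.path.exists(p)}
    if not timestamps:
        return None
    return max(timestamps, key=timestamps.__getitem__)
```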
34 changes: 34 additions & 0 deletions tests/tests_pytorch/trainer/test_trainer.py
@@ -776,6 +776,40 @@ def test_checkpoint_path_input_last(tmpdir, ckpt_path, save_last, fn):
assert trainer.ckpt_path == final_path


def test_checkpoint_find_last(tmpdir):
"""Test that the last checkpoint is found correctly."""
model = BoringModel()
mc = ModelCheckpoint(save_last=True)
trainer = Trainer(
max_epochs=1,
limit_train_batches=1,
limit_val_batches=0,
enable_model_summary=False,
enable_progress_bar=False,
logger=False,
default_root_dir=tmpdir,
callbacks=[mc],
)
assert trainer.ckpt_path is None
trainer.fit(model)

model = BoringModel()
mc = ModelCheckpoint()
trainer = Trainer(
max_epochs=1,
limit_train_batches=1,
limit_val_batches=0,
enable_model_summary=False,
enable_progress_bar=False,
logger=False,
default_root_dir=tmpdir,
callbacks=[mc],
)
assert trainer.ckpt_path is None
trainer.fit(model, ckpt_path="last")
assert trainer.ckpt_path == str(tmpdir / "checkpoints" / "last.ckpt")


@pytest.mark.parametrize("ckpt_path", (None, "best", "specific"))
@pytest.mark.parametrize("save_top_k", (-1, 0, 1, 2))
@pytest.mark.parametrize("fn", ("validate", "test", "predict"))
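Taken together, the change lets `ckpt_path="last"` resolve even when `last_model_path` was never set in the current process, such as after a restart. A hypothetical end-to-end sketch, not part of the PR; it assumes both runs share the same working directory so the default `checkpoints/` directory lines up:

```python
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

class TinyModel(LightningModule):
    """Minimal stand-in module, just enough to produce a checkpoint."""
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def train_dataloader(self):
        return DataLoader(torch.randn(64, 32), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

# First run writes <default_root_dir>/checkpoints/last.ckpt.
trainer = Trainer(max_epochs=1, logger=False, callbacks=[ModelCheckpoint(save_last=True)])
trainer.fit(TinyModel())

# A later run (possibly a new process) resumes from it without knowing the path.
Trainer(max_epochs=2, logger=False).fit(TinyModel(), ckpt_path="last")
```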