Skip to content
24 changes: 13 additions & 11 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
import time
from copy import deepcopy
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional
from weakref import proxy

import numpy as np
Expand All @@ -37,7 +36,7 @@
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _METRIC, STEP_OUTPUT
from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -203,7 +202,7 @@ class ModelCheckpoint(Callback):

def __init__(
self,
dirpath: Optional[Union[str, Path]] = None,
dirpath: Optional[_PATH] = None,
filename: Optional[str] = None,
monitor: Optional[str] = None,
verbose: bool = False,
Expand Down Expand Up @@ -267,6 +266,8 @@ def on_init_end(self, trainer: "pl.Trainer") -> None:
def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Resolve the checkpoint directory lazily at pretrain-routine start.

    The directory is computed on every rank; the non-empty-directory warning
    is emitted from the global-zero rank only.
    """
    self.__resolve_ckpt_dir(trainer)
    if not trainer.is_global_zero:
        return
    self.__warn_if_dir_not_empty(self.dirpath)

def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
self._last_time_checked = time.monotonic()
Expand Down Expand Up @@ -440,11 +441,8 @@ def __validate_init_configuration(self) -> None:
" will duplicate the last checkpoint saved."
)

def __init_ckpt_dir(self, dirpath: Optional[Union[str, Path]], filename: Optional[str]) -> None:
self._fs = get_filesystem(str(dirpath) if dirpath else "")

if self.save_top_k != 0 and dirpath is not None and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None:
self._fs = get_filesystem(dirpath if dirpath else "")

if dirpath and self._fs.protocol == "file":
dirpath = os.path.realpath(dirpath)
Expand Down Expand Up @@ -619,6 +617,10 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
if not trainer.fast_dev_run and trainer.should_rank_save_checkpoint:
self._fs.makedirs(self.dirpath, exist_ok=True)

def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
    """Warn when checkpoints are about to be written into a non-empty directory.

    Skipped entirely when ``save_top_k == 0`` (no checkpoints will be saved,
    so a populated directory is irrelevant).
    """
    if self.save_top_k == 0:
        return
    fs = self._fs
    if fs.isdir(dirpath) and len(fs.ls(dirpath)) > 0:
        rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

def _validate_monitor_key(self, trainer: "pl.Trainer") -> None:
metrics = trainer.callback_metrics

Expand Down Expand Up @@ -735,7 +737,7 @@ def _update_best_and_save(
if del_filepath is not None and filepath != del_filepath:
trainer.training_type_plugin.remove_checkpoint(del_filepath)

def to_yaml(self, filepath: Optional[Union[str, Path]] = None) -> None:
def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
"""Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
file."""
best_k = {k: v.item() for k, v in self.best_k_models.items()}
Expand All @@ -744,7 +746,7 @@ def to_yaml(self, filepath: Optional[Union[str, Path]] = None) -> None:
with self._fs.open(filepath, "w") as fp:
yaml.dump(best_k, fp)

def file_exists(self, filepath: Union[str, Path], trainer: "pl.Trainer") -> bool:
def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
"""Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
state to diverge between ranks."""
exists = self._fs.exists(filepath)
Expand Down