
Commit 444b21d

jjenniferdai, pre-commit-ci[bot], and ananthsub authored
Optimize non-empty directory warning check in model checkpoint callback (#9615)
* pt1 dir empty check
* clean imports
* bring back resolve mkdir?
* original doc
* warningcache
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* cp callback after resolve
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* move global_zero check outside warn fn
* move global_zero check outside warn fn 2
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: ananthsub <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a3def9d commit 444b21d
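In short: the "checkpoint directory exists and is not empty" warning used to run inside `__init_ckpt_dir`, at callback construction time; after this commit it runs in `on_pretrain_routine_start`, once `__resolve_ckpt_dir` has settled on the final directory, and only on global rank zero. A minimal standalone sketch of the resulting pattern, using `fsspec` directly in place of Lightning's `get_filesystem` (the function name and arguments here are illustrative, not Lightning's API):

import warnings
import fsspec

def warn_if_dir_not_empty(dirpath: str, save_top_k: int, is_global_zero: bool) -> None:
    # Only rank zero warns, so a multi-process run prints the message once.
    if not is_global_zero:
        return
    fs = fsspec.filesystem("file")  # local filesystem; Lightning infers it from the path
    # save_top_k == 0 disables checkpoint saving, so clobbering is not a concern then.
    if save_top_k != 0 and fs.isdir(dirpath) and len(fs.ls(dirpath)) > 0:
        warnings.warn(f"Checkpoint directory {dirpath} exists and is not empty.")

# Called only after the checkpoint directory has been resolved:
warn_if_dir_not_empty("lightning_logs/version_0/checkpoints", save_top_k=1, is_global_zero=True)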

File tree

1 file changed: +13 additions, -11 deletions


pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 13 additions & 11 deletions
@@ -24,8 +24,7 @@
 import time
 from copy import deepcopy
 from datetime import timedelta
-from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional
 from weakref import proxy
 
 import numpy as np
@@ -37,7 +36,7 @@
 from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.types import _METRIC, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache
 
 log = logging.getLogger(__name__)
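For reference, `_PATH` is Lightning's path alias from `pytorch_lightning.utilities.types`; assuming the usual definition, it is roughly:

from pathlib import Path
from typing import Union

_PATH = Union[str, Path]  # accepted anywhere a filesystem path is expected

This is why the `pathlib.Path` and `Union` imports above become unnecessary in this module.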
@@ -203,7 +202,7 @@ class ModelCheckpoint(Callback):
 
     def __init__(
         self,
-        dirpath: Optional[Union[str, Path]] = None,
+        dirpath: Optional[_PATH] = None,
         filename: Optional[str] = None,
         monitor: Optional[str] = None,
         verbose: bool = False,
@@ -267,6 +266,8 @@ def on_init_end(self, trainer: "pl.Trainer") -> None:
     def on_pretrain_routine_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         """When pretrain routine starts we build the ckpt dir on the fly."""
         self.__resolve_ckpt_dir(trainer)
+        if trainer.is_global_zero:
+            self.__warn_if_dir_not_empty(self.dirpath)
 
     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._last_time_checked = time.monotonic()
@@ -440,11 +441,8 @@ def __validate_init_configuration(self) -> None:
                 " will duplicate the last checkpoint saved."
             )
 
-    def __init_ckpt_dir(self, dirpath: Optional[Union[str, Path]], filename: Optional[str]) -> None:
-        self._fs = get_filesystem(str(dirpath) if dirpath else "")
-
-        if self.save_top_k != 0 and dirpath is not None and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
-            rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
+    def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None:
+        self._fs = get_filesystem(dirpath if dirpath else "")
 
         if dirpath and self._fs.protocol == "file":
             dirpath = os.path.realpath(dirpath)
@@ -619,6 +617,10 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
         if not trainer.fast_dev_run and trainer.should_rank_save_checkpoint:
             self._fs.makedirs(self.dirpath, exist_ok=True)
 
+    def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
+        if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
+            rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
+
     def _validate_monitor_key(self, trainer: "pl.Trainer") -> None:
         metrics = trainer.callback_metrics
 
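To see the relocated check in action, a hypothetical repro (the directory and file names here are illustrative): pre-populate `dirpath`, attach the callback, and rank zero should warn once the pretrain routine resolves the directory.

import os
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

os.makedirs("checkpoints", exist_ok=True)
open("checkpoints/old.ckpt", "a").close()  # make the directory non-empty

ckpt_cb = ModelCheckpoint(dirpath="checkpoints", save_top_k=1)
trainer = Trainer(callbacks=[ckpt_cb])
# During trainer.fit(model), on_pretrain_routine_start resolves the directory,
# then rank zero warns: "Checkpoint directory checkpoints exists and is not empty."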
@@ -735,7 +737,7 @@ def _update_best_and_save(
         if del_filepath is not None and filepath != del_filepath:
             trainer.training_type_plugin.remove_checkpoint(del_filepath)
 
-    def to_yaml(self, filepath: Optional[Union[str, Path]] = None) -> None:
+    def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
         """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
         file."""
         best_k = {k: v.item() for k, v in self.best_k_models.items()}
@@ -744,7 +746,7 @@ def to_yaml(self, filepath: Optional[Union[str, Path]] = None) -> None:
         with self._fs.open(filepath, "w") as fp:
             yaml.dump(best_k, fp)
 
-    def file_exists(self, filepath: Union[str, Path], trainer: "pl.Trainer") -> bool:
+    def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
         """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
         state to diverge between ranks."""
         exists = self._fs.exists(filepath)
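Per its docstring, `file_exists` checks the path on every rank but then broadcasts rank 0's answer so all ranks agree (the broadcast itself falls just past the end of this hunk). A sketch of that agreement pattern with raw `torch.distributed`, assuming a process group is already initialized; the function name is illustrative:

import torch.distributed as dist

def exists_on_rank_zero(exists_locally: bool) -> bool:
    # Wrap the rank-local result and overwrite it with rank 0's value,
    # so every rank returns the same answer and their states cannot diverge.
    obj = [exists_locally]
    dist.broadcast_object_list(obj, src=0)
    return obj[0]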
