Commit 74ab878

Jungwon-Lee authored and carmocca committed
Fix mypy typing errors in pytorch_lightning/callbacks/model_checkpoint.py (#13617)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 3129d97 commit 74ab878

File tree

9 files changed (+35, -34 lines)


pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -47,7 +47,6 @@ warn_no_return = "False"
 # the list can be generated with:
 # mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
 module = [
-    "pytorch_lightning.callbacks.model_checkpoint",
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.callbacks.stochastic_weight_avg",

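Removing "pytorch_lightning.callbacks.model_checkpoint" from this list means the module is no longer exempt from mypy and must now type-check cleanly. As a rough way to re-check the file locally, here is a sketch using mypy's Python API (assuming mypy is installed and the command is run from the repository root; only the file path comes from this commit):

# Hypothetical local check of the newly un-excluded module via mypy's programmatic API.
from mypy import api

report, errors, exit_status = api.run(["src/pytorch_lightning/callbacks/model_checkpoint.py"])
print(report)
assert exit_status == 0, "the module is expected to type-check cleanly after this commit"
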
src/pytorch_lightning/callbacks/early_stopping.py

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: O
         # validation, then we run after validation instead of on train epoch end
         self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
 
-    def _validate_condition_metric(self, logs: Dict[str, float]) -> bool:
+    def _validate_condition_metric(self, logs: Dict[str, Tensor]) -> bool:
         monitor_val = logs.get(self.monitor)
 
         error_msg = (

src/pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 24 additions & 18 deletions
@@ -39,7 +39,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.logger import _name, _version
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
-from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache
 
 log = logging.getLogger(__name__)
@@ -231,13 +231,14 @@ def __init__(
         self._save_on_train_epoch_end = save_on_train_epoch_end
         self._last_global_step_saved = 0  # no need to save when no steps were taken
         self._last_time_checked: Optional[float] = None
-        self.current_score = None
-        self.best_k_models = {}
+        self.current_score: Optional[Tensor] = None
+        self.best_k_models: Dict[str, Tensor] = {}
         self.kth_best_model_path = ""
-        self.best_model_score = None
+        self.best_model_score: Optional[Tensor] = None
         self.best_model_path = ""
         self.last_model_path = ""
 
+        self.kth_value: Tensor
         self.__init_monitor_mode(mode)
         self.__init_ckpt_dir(dirpath, filename)
         self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
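
The new annotations follow two common mypy patterns; a minimal standalone sketch (class and attribute names are chosen for illustration, not taken from Lightning):

from typing import Dict, Optional

from torch import Tensor


class _ScoreTracker:
    def __init__(self) -> None:
        # annotating at first assignment gives mypy a precise type
        # instead of letting it infer `None` or `Dict[Any, Any]`
        self.current_score: Optional[Tensor] = None
        self.best_k_models: Dict[str, Tensor] = {}
        # annotation without assignment only declares the type;
        # the actual value is assigned later, before first use
        self.kth_value: Tensor
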
@@ -256,6 +257,7 @@ def state_key(self) -> str:
 
     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
         self.__resolve_ckpt_dir(trainer)
+        assert self.dirpath is not None
         if trainer.is_global_zero and stage == "fit":
             self.__warn_if_dir_not_empty(self.dirpath)
 
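The added assert exists for the type checker: `dirpath` is an Optional attribute, and `assert self.dirpath is not None` narrows it to a non-None type before it is passed to a function expecting a concrete path. A minimal illustration with hypothetical names:

from typing import Optional


def _warn_if_dir_not_empty(dirpath: str) -> None:
    print(f"checking {dirpath}")


def setup(dirpath: Optional[str]) -> None:
    # without the assert, mypy reports passing Optional[str] where str is expected
    assert dirpath is not None
    _warn_if_dir_not_empty(dirpath)
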
@@ -362,7 +364,7 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None:  # pragma: no-cover
         self._save_topk_checkpoint(trainer, monitor_candidates)
         self._save_last_checkpoint(trainer, monitor_candidates)
 
-    def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+    def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
         if self.save_top_k == 0:
             return
 
@@ -395,7 +397,7 @@ def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
         from pytorch_lightning.trainer.states import TrainerFn
 
         return (
-            trainer.fast_dev_run  # disable checkpointing with fast_dev_run
+            bool(trainer.fast_dev_run)  # disable checkpointing with fast_dev_run
             or trainer.state.fn != TrainerFn.FITTING  # don't save anything during non-fit
             or trainer.sanity_checking  # don't save anything during sanity check
             or self._last_global_step_saved == trainer.global_step  # already saved at the last step
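
The `bool(...)` wrapper is presumably needed because `fast_dev_run` also accepts an integer run count, so its type is `Union[int, bool]` and the bare attribute would widen the return type of this `or` chain. A small sketch with hypothetical names:

from typing import Union


def _should_skip(fast_dev_run: Union[int, bool], sanity_checking: bool) -> bool:
    # bool() keeps the whole expression typed as `bool` for mypy
    return bool(fast_dev_run) or sanity_checking
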
@@ -493,15 +495,15 @@ def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] =
         should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
 
         # If using multiple devices, make sure all processes are unanimous on the decision.
-        should_update_best_and_save = trainer.strategy.reduce_boolean_decision(should_update_best_and_save)
+        should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))
 
         return should_update_best_and_save
 
     @classmethod
     def _format_checkpoint_name(
         cls,
         filename: Optional[str],
-        metrics: Dict[str, _METRIC],
+        metrics: Dict[str, Tensor],
         prefix: str = "",
         auto_insert_metric_name: bool = True,
     ) -> str:
@@ -522,7 +524,7 @@ def _format_checkpoint_name(
                 filename = filename.replace(group, f"{{0[{name}]")
 
                 if name not in metrics:
-                    metrics[name] = 0
+                    metrics[name] = torch.tensor(0)
             filename = filename.format(metrics)
 
         if prefix:
@@ -531,7 +533,7 @@
         return filename
 
     def format_checkpoint_name(
-        self, metrics: Dict[str, _METRIC], filename: Optional[str] = None, ver: Optional[int] = None
+        self, metrics: Dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None
     ) -> str:
         """Generate a filename according to the defined template.
 
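
With `metrics` typed as `Dict[str, Tensor]`, the fallback for unknown placeholders becomes `torch.tensor(0)` rather than `0`, and callers are expected to pass tensors. A rough usage sketch of the public method (the resulting filename is shown as an assumption based on the default `auto_insert_metric_name=True` behaviour):

import torch

from pytorch_lightning.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(filename="{epoch}-{val_loss:.2f}")
name = ckpt.format_checkpoint_name({"epoch": torch.tensor(3), "val_loss": torch.tensor(0.25)})
print(name)  # expected: 'epoch=3-val_loss=0.25.ckpt'; a placeholder missing from the dict would format as 0
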
@@ -591,6 +593,7 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
             ckpt_path = os.path.join(trainer._weights_save_path_internal, "checkpoints")
         elif trainer.loggers:
             if len(trainer.loggers) == 1:
+                assert trainer.logger is not None
                 save_dir = trainer.logger.save_dir or trainer.default_root_dir
             else:
                 save_dir = trainer.default_root_dir
@@ -613,7 +616,7 @@ def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
             rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
 
     def _get_metric_interpolated_filepath_name(
-        self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None
+        self, monitor_candidates: Dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None
     ) -> str:
         filepath = self.format_checkpoint_name(monitor_candidates)
 
@@ -624,7 +627,7 @@ def _get_metric_interpolated_filepath_name(
 
         return filepath
 
-    def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
+    def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]:
         monitor_candidates = deepcopy(trainer.callback_metrics)
         # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
         # or does not exist we overwrite it as it's likely an error
@@ -634,7 +637,7 @@ def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
         monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step)
         return monitor_candidates
 
-    def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+    def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
         if not self.save_last:
             return
 
@@ -651,16 +654,18 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
         if previous and previous != filepath:
             trainer.strategy.remove_checkpoint(previous)
 
-    def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+    def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
+        assert self.monitor
         current = monitor_candidates.get(self.monitor)
         if self.check_monitor_top_k(trainer, current):
+            assert current is not None
             self._update_best_and_save(current, trainer, monitor_candidates)
         elif self.verbose:
             epoch = monitor_candidates["epoch"]
             step = monitor_candidates["step"]
             rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")
 
-    def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
+    def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
         filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
         # set the best model path before saving because it will be part of the state.
         previous, self.best_model_path = self.best_model_path, filepath
@@ -669,7 +674,7 @@ def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidate
             trainer.strategy.remove_checkpoint(previous)
 
     def _update_best_and_save(
-        self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
+        self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
     ) -> None:
         k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
 
@@ -691,11 +696,11 @@ def _update_best_and_save(
         if len(self.best_k_models) == k:
             # monitor dict has reached k elements
             _op = max if self.mode == "min" else min
-            self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
+            self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
             self.kth_value = self.best_k_models[self.kth_best_model_path]
 
         _op = min if self.mode == "min" else max
-        self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
+        self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
         self.best_model_score = self.best_k_models[self.best_model_path]
 
         if self.verbose:
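
The `# type: ignore[arg-type]` comments are needed because `dict.get` returns `Optional[Tensor]`, which mypy rejects as a `key=` callable for `min`/`max`, even though every key looked up here is known to be present at runtime. A minimal reproduction of the pattern:

from typing import Dict

import torch
from torch import Tensor

best_k_models: Dict[str, Tensor] = {"a.ckpt": torch.tensor(0.10), "b.ckpt": torch.tensor(0.30)}
# dict.get -> Optional[Tensor], so mypy reports [arg-type] here; runtime is fine because all keys exist
worst_path = max(best_k_models, key=best_k_models.get)  # type: ignore[arg-type]
# an ignore-free alternative would be key=best_k_models.__getitem__
print(worst_path)
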
@@ -715,6 +720,7 @@ def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
         file."""
         best_k = {k: v.item() for k, v in self.best_k_models.items()}
         if filepath is None:
+            assert self.dirpath
             filepath = os.path.join(self.dirpath, "best_k_models.yaml")
         with self._fs.open(filepath, "w") as fp:
             yaml.dump(best_k, fp)

src/pytorch_lightning/core/module.py

Lines changed: 1 addition & 1 deletion
@@ -532,7 +532,7 @@ def __to_tensor(self, value: numbers.Number) -> Tensor:
         return torch.tensor(value, device=self.device)
 
     @staticmethod
-    def __check_numel_1(value: torch.Tensor, name: str) -> None:
+    def __check_numel_1(value: Tensor, name: str) -> None:
         if not torch.numel(value) == 1:
             raise ValueError(
                 f"`self.log({name}, {value})` was called, but the tensor must have a single element."

src/pytorch_lightning/strategies/strategy.py

Lines changed: 1 addition & 1 deletion
@@ -285,7 +285,7 @@ def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bo
         """
 
     def reduce_boolean_decision(self, decision: bool) -> bool:
-        """Reduce the early stopping decision across all processes."""
+        """Reduce a boolean decision across all processes."""
         return decision
 
     def pre_backward(self, closure_loss: Tensor) -> None:

src/pytorch_lightning/strategies/tpu_spawn.py

Lines changed: 3 additions & 9 deletions
@@ -169,19 +169,13 @@ def broadcast(self, obj: object, src: int = 0) -> object:
             obj = torch.load(buffer)
         return obj
 
-    def reduce_boolean_decision(self, decision: bool) -> bool:
-        decision = torch.tensor(int(decision), device=self.root_device)
-        decision = self.reduce(decision, reduce_op="sum")
-        decision = bool(decision == self.world_size)
-        return decision
-
     def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
         if not isinstance(output, Tensor):
             output = torch.tensor(output, device=self.root_device)
 
-        _invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
-        _invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
-        if _invalid_reduce_op or _invalid_reduce_op_str:
+        invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
+        invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
+        if invalid_reduce_op or invalid_reduce_op_str:
             raise MisconfigurationException(
                 "Currently, TPUSpawn Strategy only support `sum`, `mean`, `avg` reduce operation."
             )

src/pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 1 addition & 1 deletion
@@ -529,7 +529,7 @@ def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[Tensor]:
                 result_metric.meta.sync.should = should
             cache = result_metric._computed
         if cache is not None:
-            if not isinstance(cache, torch.Tensor):
+            if not isinstance(cache, Tensor):
                 raise ValueError(
                     f"The `.compute()` return of the metric logged as {result_metric.meta.name!r} must be a tensor."
                     f" Found {cache}"

src/pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 1 deletion
@@ -2705,7 +2705,9 @@ def loggers(self, loggers: Optional[List[Logger]]) -> None:
         self._loggers = loggers if loggers else []
 
     @property
-    def callback_metrics(self) -> dict:
+    def callback_metrics(self) -> Dict[str, Tensor]:
+        # TODO: the true typing return can include dictionaries as defined in
+        # `pytorch_lightning.trainer.connectors.logger_connector.result._OUT_DICT`
         return self._logger_connector.callback_metrics
 
     @property
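
The tighter annotation documents the contract that callbacks such as ModelCheckpoint rely on: every value in `callback_metrics` is a tensor (the TODO notes that nested dictionaries are technically possible too). A sketch of that contract, with placeholder values standing in for what a real trainer would expose after logging:

from typing import Dict

import torch
from torch import Tensor

callback_metrics: Dict[str, Tensor] = {"val_loss": torch.tensor(0.25), "epoch": torch.tensor(3)}
assert all(isinstance(v, Tensor) for v in callback_metrics.values())
val_loss = callback_metrics["val_loss"].item()  # tensors expose .item() for plain Python scalars
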

src/pytorch_lightning/utilities/distributed.py

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
     return gathered_result
 
 
-def _simple_gather_all_tensors(result: torch.Tensor, group: Any, world_size: int) -> List[torch.Tensor]:
+def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
     gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
     torch.distributed.all_gather(gathered_result, result, group)
     return gathered_result
