
Commit b6b071f

fix mypy error3
1 parent ffaa590 commit b6b071f

3 files changed: +27 / -30 lines


src/pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 1 addition & 1 deletion
@@ -439,7 +439,7 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat
         return dataloader
 
     @staticmethod
-    def _resolve_overfit_batches(dataloaders: Collection[DataLoader], mode: RunningStage) -> Collection[DataLoader]:
+    def _resolve_overfit_batches(dataloaders: Union[Collection[DataLoader], DataLoader], mode: RunningStage) -> Collection[DataLoader]:
         all_have_sequential_sampler = True
 
         def resolve_has_no_sequential_sampler(dataloader: DataLoader):
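
The widened `dataloaders` parameter matters because call sites can hand this helper either a single `DataLoader` or a collection of them. A minimal sketch of that idea, with a hypothetical helper name (not the actual Lightning implementation):

    from typing import Collection, List, Union

    from torch.utils.data import DataLoader

    def _as_dataloader_list(dataloaders: Union[Collection[DataLoader], DataLoader]) -> List[DataLoader]:
        # Accept either form and normalize to a list so callers can iterate uniformly.
        if isinstance(dataloaders, DataLoader):
            return [dataloaders]
        return list(dataloaders)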

src/pytorch_lightning/trainer/trainer.py

Lines changed: 22 additions & 19 deletions
@@ -25,7 +25,7 @@
 from datetime import timedelta
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type, Union
 from weakref import proxy
 
 import torch
@@ -77,7 +77,7 @@
 from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
 from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
-from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
+from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
 from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
 from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
 from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
@@ -545,6 +545,7 @@ def __init__(
         self._logger_connector.on_trainer_init(logger, log_every_n_steps, move_metrics_to_cpu)
 
         # init debugging flags
+        self.val_check_batch: Union[int, float]
         self.val_check_interval: Union[int, float]
         self.num_sanity_val_steps: Union[int, float]
         self.limit_train_batches: Union[int, float]
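
The new `self.val_check_batch: Union[int, float]` line is a declaration without an assignment: it tells mypy which type the attribute will eventually hold, since the actual value is set later by a connector. A small sketch of the pattern with illustrative names:

    from typing import Union

    class TrainerLike:
        def __init__(self) -> None:
            # Annotation only: mypy records the attribute type, no value is bound yet.
            self.val_check_batch: Union[int, float]
            self._configure()

        def _configure(self) -> None:
            self.val_check_batch = 1.0  # later assignment checks against the declaration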
@@ -741,7 +742,7 @@ def _fit_impl(
         # TODO: ckpt_path only in v2.0
         ckpt_path = ckpt_path or self.resume_from_checkpoint
         self._ckpt_path = self.__set_ckpt_path(
-            ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
+            ckpt_path, model_provided=True, model_connected=self.lightning_module is not None  # type: ignore
         )
         results = self._run(model, ckpt_path=self.ckpt_path)
 
@@ -985,7 +986,7 @@ def _predict_impl(
         self.state.status = TrainerStatus.RUNNING
         self.predicting = True
 
-        self.predict_loop.return_predictions = return_predictions
+        self.predict_loop.return_predictions = return_predictions  # type: ignore
 
         # if a datamodule comes in as the second arg, then fix it for the user
         if isinstance(dataloaders, LightningDataModule):
@@ -1395,7 +1396,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
 
         if model_provided and ckpt_path is None:
             # use passed model to function without loading weights
-            return
+            return None
 
         if model_connected and ckpt_path is None:
             ckpt_path = "best"
@@ -1449,8 +1450,8 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
                     f'.{fn}(ckpt_path="last") is set, but there is no fault tolerant'
                     " or last checkpoint available. No checkpoint will be loaded."
                 )
-                return
-            ckpt_path = max(candidates_ts.keys(), key=partial(operator.getitem, candidates_ts))
+                return None
+            ckpt_path = max(candidates_ts.keys(), key=partial(operator.getitem, candidates_ts))  # type: ignore
 
         if not ckpt_path:
             raise MisconfigurationException(
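
For reference, `max(candidates_ts.keys(), key=partial(operator.getitem, candidates_ts))` picks the checkpoint whose timestamp value is largest; the added `# type: ignore` only silences the assignment, not the logic. A small standalone illustration (the data is made up):

    import operator
    from functools import partial

    # Made-up timestamps: checkpoint path -> modification time.
    candidates_ts = {"epoch=1.ckpt": 100.0, "epoch=2.ckpt": 250.0}

    # partial(operator.getitem, candidates_ts) maps each key to its value,
    # so max() returns the key with the newest timestamp.
    newest = max(candidates_ts.keys(), key=partial(operator.getitem, candidates_ts))
    print(newest)  # epoch=2.ckpt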
@@ -1664,7 +1665,7 @@ def _call_callbacks_on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None
         prev_fx_name = pl_module._current_fx_name
         pl_module._current_fx_name = "on_load_checkpoint"
 
-        callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
+        callback_states: Optional[Dict[Union[Type, str], Dict]] = checkpoint.get("callbacks")
 
         if callback_states is None:
             return
@@ -1692,7 +1693,7 @@ def _call_callbacks_on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None
 
     def _call_callbacks_load_state_dict(self, checkpoint: Dict[str, Any]) -> None:
         """Called when loading a model checkpoint, calls every callback's `load_state_dict`."""
-        callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
+        callback_states: Optional[Dict[Union[Type, str], Dict]] = checkpoint.get("callbacks")
 
         if callback_states is None:
             return
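
Both hunks make the annotation match what `dict.get` can actually return: `None` when the key is missing. With the `Optional[...]` annotation, mypy then narrows the type after the `is None` early return. A minimal sketch of that narrowing (names are illustrative):

    from typing import Any, Dict, Optional

    def restore_callbacks(checkpoint: Dict[str, Any]) -> None:
        # .get() may return None, so Optional[...] is the honest annotation.
        callback_states: Optional[Dict[str, Dict]] = checkpoint.get("callbacks")
        if callback_states is None:
            return
        # Past the early return, mypy treats callback_states as Dict[str, Dict].
        for name, state in callback_states.items():
            print(name, state)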
@@ -1745,6 +1746,7 @@ def __init_profiler(self, profiler: Optional[Union[Profiler, str]]) -> None:
             )
             profiler_class = PROFILERS[profiler]
             profiler = profiler_class()
+            assert isinstance(profiler, Profiler)
         self.profiler: Profiler = profiler or PassThroughProfiler()
 
     def __setup_profiler(self) -> None:
@@ -2126,8 +2128,9 @@ def data_parallel_device_ids(self) -> Optional[List[int]]:
         return self.device_ids if isinstance(self.accelerator, CUDAAccelerator) else None
 
     @property
-    def lightning_module(self) -> "pl.LightningModule":
+    def lightning_module(self) -> "pl.LightningModule":  # type: ignore
         # TODO: this is actually an optional return
+        assert self.strategy.lightning_module is not None
         return self.strategy.lightning_module
 
     @property
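
The added `assert` is the usual mypy narrowing trick: `strategy.lightning_module` is typed as optional, so asserting it is not `None` lets the property promise a non-optional return. A toy sketch of the pattern, not the real classes:

    from typing import Optional

    class Strategy:
        lightning_module: Optional[str] = None  # stand-in for the real module type

    class Owner:
        def __init__(self, strategy: Strategy) -> None:
            self.strategy = strategy

        @property
        def lightning_module(self) -> str:
            # The assert narrows Optional[str] to str for mypy and fails loudly
            # at runtime instead of silently returning None.
            assert self.strategy.lightning_module is not None
            return self.strategy.lightning_module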
@@ -2219,12 +2222,12 @@ def model(self, model: torch.nn.Module) -> None:
 
     @property
     def log_dir(self) -> Optional[str]:
-        assert self.logger is not None
         if len(self.loggers) == 1:
-            if isinstance(self.logger, TensorBoardLogger):
-                dirpath = self.logger.log_dir
-            else:
+            assert self.logger is not None
+            if not isinstance(self.logger, TensorBoardLogger):
                 dirpath = self.logger.save_dir
+            else:
+                dirpath = self.logger.log_dir
         else:
             dirpath = self.default_root_dir
 
@@ -2709,7 +2712,7 @@ def logger(self, logger: Optional[Logger]) -> None:
         if not logger:
             self.loggers = []
         elif isinstance(logger, LoggerCollection):
-            self.loggers = list(logger)
+            self.loggers = [x for x in logger]
         else:
             self.loggers = [logger]
 
@@ -2722,17 +2725,17 @@ def loggers(self, loggers: Optional[List[Logger]]) -> None:
         self._loggers = loggers if loggers else []
 
     @property
-    def callback_metrics(self) -> Dict[str, Tensor]:
+    def callback_metrics(self) -> Dict:
         # TODO: the true typing return can include dictionaries as defined in
         # `pytorch_lightning.trainer.connectors.logger_connector.result._OUT_DICT`
         return self._logger_connector.callback_metrics
 
     @property
-    def logged_metrics(self) -> dict:
+    def logged_metrics(self) -> _OUT_DICT:
         return self._logger_connector.logged_metrics
 
     @property
-    def progress_bar_metrics(self) -> dict:
+    def progress_bar_metrics(self) -> Dict:
         return self._logger_connector.progress_bar_metrics
 
     @property
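
`_OUT_DICT` is the alias imported at the top of the file, so `logged_metrics` now reuses the logger connector's own return type instead of a bare `dict`. As a rough sketch of why a shared alias helps, assuming it is essentially a str-keyed metrics mapping (the exact definition lives in `logger_connector/result.py`):

    from typing import Dict, Union

    import torch

    # Assumed shape for illustration only; the real alias is defined in result.py.
    _OUT_DICT = Dict[str, Union[torch.Tensor, float]]

    def logged_metrics() -> _OUT_DICT:
        # Every consumer of the alias stays in sync if the definition changes.
        return {"train_loss": torch.tensor(0.25)}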
@@ -2748,7 +2751,7 @@ def _exit_gracefully_on_signal(self) -> None:
 
     def _should_terminate_gracefully(self) -> bool:
         value = torch.tensor(int(self._terminate_gracefully), device=self.strategy.root_device)
-        return self.strategy.reduce(value, reduce_op="sum") > 0
+        return bool(self.strategy.reduce(value, reduce_op="sum") > 0)
 
     """
     Other
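
The `bool(...)` wrapper is needed because comparing a tensor with `>` yields another tensor, which mypy will not accept from a function annotated `-> bool`. A tiny illustration:

    import torch

    def should_terminate(flag: int) -> bool:
        value = torch.tensor(flag)
        # (value > 0) is a torch.Tensor, not a bool; bool(...) makes the
        # conversion explicit so the -> bool annotation type-checks.
        return bool(value > 0)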

src/pytorch_lightning/utilities/parsing.py

Lines changed: 4 additions & 10 deletions
@@ -17,9 +17,8 @@
 import inspect
 import pickle
 import types
-from argparse import Namespace
 from dataclasses import fields, is_dataclass
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import Any, Dict, List, MutableMapping, Optional, Sequence, Tuple, Type, Union
 
 from torch import nn
 from typing_extensions import Literal
@@ -94,18 +93,13 @@ def is_picklable(obj: object) -> bool:
     return False
 
 
-def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
+def clean_namespace(hparams: MutableMapping) -> None:
     """Removes all unpicklable entries from hparams."""
-
-    hparams_dict = hparams
-    if isinstance(hparams, Namespace):
-        hparams_dict = hparams.__dict__
-
-    del_attrs = [k for k, v in hparams_dict.items() if not is_picklable(v)]
+    del_attrs = [k for k, v in hparams.items() if not is_picklable(v)]
 
     for k in del_attrs:
         rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
-        del hparams_dict[k]
+        del hparams[k]
 
 
 def parse_class_init_keys(
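
With the `MutableMapping` signature, `clean_namespace` no longer special-cases `argparse.Namespace`; a caller holding a `Namespace` would pass its mapping view instead. A hypothetical usage sketch (the call pattern is illustrative, not taken from this commit):

    from argparse import Namespace

    from pytorch_lightning.utilities.parsing import clean_namespace

    args = Namespace(lr=1e-3, not_picklable=lambda x: x)
    # vars() exposes the underlying __dict__, so deleting keys there
    # mutates the Namespace in place, matching the old behaviour.
    clean_namespace(vars(args))
    print(args)  # Namespace(lr=0.001) -- the unpicklable entry was removed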
