Skip to content

Commit 82d2d1d

Browse files
jxtngxawaelchliotajrohitgr7
authored
Fix mypy errors attributed to pytorch_lightning.core.saving (#13932)
Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: otaj <[email protected]> Co-authored-by: Rohit Gupta <[email protected]>
1 parent 34afde7 commit 82d2d1d

File tree

5 files changed

+29
-24
lines changed

5 files changed

+29
-24
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ module = [
5353
"pytorch_lightning.callbacks.quantization",
5454
"pytorch_lightning.core.datamodule",
5555
"pytorch_lightning.core.module",
56-
"pytorch_lightning.core.saving",
5756
"pytorch_lightning.demos.boring_classes",
5857
"pytorch_lightning.demos.mnist_datamodule",
5958
"pytorch_lightning.profilers.base",

src/pytorch_lightning/core/saving.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,9 @@
2020
from argparse import Namespace
2121
from copy import deepcopy
2222
from enum import Enum
23-
from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union
23+
from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union
2424
from warnings import warn
2525

26-
import torch
2726
import yaml
2827

2928
import pytorch_lightning as pl
@@ -34,7 +33,7 @@
3433
from pytorch_lightning.utilities.migration import pl_legacy_patch
3534
from pytorch_lightning.utilities.parsing import parse_class_init_keys
3635
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
37-
from pytorch_lightning.utilities.types import _PATH
36+
from pytorch_lightning.utilities.types import _MAP_LOCATION_TYPE, _PATH
3837

3938
log = logging.getLogger(__name__)
4039
PRIMITIVE_TYPES = (bool, int, float, str)
@@ -58,11 +57,11 @@ class ModelIO:
5857
def load_from_checkpoint(
5958
cls,
6059
checkpoint_path: Union[str, IO],
61-
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
60+
map_location: _MAP_LOCATION_TYPE = None,
6261
hparams_file: Optional[str] = None,
6362
strict: bool = True,
64-
**kwargs,
65-
):
63+
**kwargs: Any,
64+
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
6665
r"""
6766
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
6867
it stores the arguments passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``.
@@ -171,15 +170,15 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
171170

172171

173172
def _load_from_checkpoint(
174-
cls: Union["pl.LightningModule", "pl.LightningDataModule"],
173+
cls: Union[Type["ModelIO"], Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
175174
checkpoint_path: Union[str, IO],
176-
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
175+
map_location: _MAP_LOCATION_TYPE = None,
177176
hparams_file: Optional[str] = None,
178-
strict: Optional[bool] = None,
177+
strict: bool = True,
179178
**kwargs: Any,
180-
) -> Any:
179+
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
181180
if map_location is None:
182-
map_location = lambda storage, loc: storage
181+
map_location = cast(_MAP_LOCATION_TYPE, lambda storage, loc: storage)
183182
with pl_legacy_patch():
184183
checkpoint = pl_load(checkpoint_path, map_location=map_location)
185184

@@ -202,15 +201,18 @@ def _load_from_checkpoint(
202201

203202
if issubclass(cls, pl.LightningDataModule):
204203
return _load_state(cls, checkpoint, **kwargs)
205-
return _load_state(cls, checkpoint, strict=strict, **kwargs)
204+
# allow cls to be evaluated as subclassed LightningModule or,
205+
# as LightningModule for internal tests
206+
if issubclass(cls, pl.LightningModule):
207+
return _load_state(cls, checkpoint, strict=strict, **kwargs)
206208

207209

208210
def _load_state(
209-
cls: Union["pl.LightningModule", "pl.LightningDataModule"],
211+
cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
210212
checkpoint: Dict[str, Any],
211-
strict: Optional[bool] = None,
213+
strict: bool = True,
212214
**cls_kwargs_new: Any,
213-
) -> Any:
215+
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
214216
cls_spec = inspect.getfullargspec(cls.__init__)
215217
cls_init_args_name = inspect.signature(cls.__init__).parameters.keys()
216218

@@ -228,8 +230,7 @@ def _load_state(
228230
cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))
229231

230232
# 2. Try to restore model hparams from checkpoint using the new key
231-
_new_hparam_key = cls.CHECKPOINT_HYPER_PARAMS_KEY
232-
cls_kwargs_loaded.update(checkpoint.get(_new_hparam_key))
233+
cls_kwargs_loaded.update(checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_KEY, {}))
233234

234235
# 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace
235236
cls_kwargs_loaded = _convert_loaded_hparams(cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))
@@ -271,7 +272,9 @@ def _load_state(
271272
return obj
272273

273274

274-
def _convert_loaded_hparams(model_args: dict, hparams_type: Optional[Union[Callable, str]] = None) -> object:
275+
def _convert_loaded_hparams(
276+
model_args: Dict[str, Any], hparams_type: Optional[Union[Callable, str]] = None
277+
) -> Dict[str, Any]:
275278
"""Convert hparams according given type in callable or string (past) format."""
276279
# if not hparams type define
277280
if not hparams_type:

src/pytorch_lightning/utilities/cloud_io.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,19 @@
1515

1616
import io
1717
from pathlib import Path
18-
from typing import Any, Callable, Dict, IO, Optional, Union
18+
from typing import Any, Dict, IO, Union
1919

2020
import fsspec
2121
import torch
2222
from fsspec.core import url_to_fs
2323
from fsspec.implementations.local import AbstractFileSystem
2424

25-
from pytorch_lightning.utilities.types import _DEVICE, _PATH
25+
from pytorch_lightning.utilities.types import _MAP_LOCATION_TYPE, _PATH
2626

2727

2828
def load(
2929
path_or_url: Union[IO, _PATH],
30-
map_location: Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]] = None,
30+
map_location: _MAP_LOCATION_TYPE = None,
3131
) -> Any:
3232
"""Loads a checkpoint.
3333

src/pytorch_lightning/utilities/parsing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,9 @@ def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
108108
del hparams_dict[k]
109109

110110

111-
def parse_class_init_keys(cls: Type["pl.LightningModule"]) -> Tuple[str, Optional[str], Optional[str]]:
111+
def parse_class_init_keys(
112+
cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]]
113+
) -> Tuple[str, Optional[str], Optional[str]]:
112114
"""Parse key words for standard ``self``, ``*args`` and ``**kwargs``.
113115
114116
Examples:

src/pytorch_lightning/utilities/types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from contextlib import contextmanager
2020
from dataclasses import dataclass
2121
from pathlib import Path
22-
from typing import Any, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union
22+
from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union
2323

2424
import torch
2525
from torch import Tensor
@@ -49,6 +49,7 @@
4949
]
5050
EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
5151
_DEVICE = Union[torch.device, str, int]
52+
_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]]
5253

5354

5455
@runtime_checkable

0 commit comments

Comments
 (0)