diff --git a/pyproject.toml b/pyproject.toml
index 761c7be04cc0e..8db782df357d8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -53,7 +53,6 @@ module = [
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.core.datamodule",
     "pytorch_lightning.core.module",
-    "pytorch_lightning.core.saving",
     "pytorch_lightning.demos.boring_classes",
     "pytorch_lightning.demos.mnist_datamodule",
     "pytorch_lightning.profilers.base",
diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py
index da81e4c212560..ffdc0988a1a6e 100644
--- a/src/pytorch_lightning/core/saving.py
+++ b/src/pytorch_lightning/core/saving.py
@@ -20,10 +20,9 @@
 from argparse import Namespace
 from copy import deepcopy
 from enum import Enum
-from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union
+from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union
 from warnings import warn
 
-import torch
 import yaml
 
 import pytorch_lightning as pl
@@ -34,7 +33,7 @@
 from pytorch_lightning.utilities.migration import pl_legacy_patch
 from pytorch_lightning.utilities.parsing import parse_class_init_keys
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
-from pytorch_lightning.utilities.types import _PATH
+from pytorch_lightning.utilities.types import _MAP_LOCATION_TYPE, _PATH
 
 log = logging.getLogger(__name__)
 PRIMITIVE_TYPES = (bool, int, float, str)
@@ -58,11 +57,11 @@ class ModelIO:
     def load_from_checkpoint(
         cls,
         checkpoint_path: Union[str, IO],
-        map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
+        map_location: _MAP_LOCATION_TYPE = None,
         hparams_file: Optional[str] = None,
         strict: bool = True,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
         r"""
         Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint it stores the arguments
         passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``.
@@ -171,15 +170,15 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
 def _load_from_checkpoint(
-    cls: Union["pl.LightningModule", "pl.LightningDataModule"],
+    cls: Union[Type["ModelIO"], Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
     checkpoint_path: Union[str, IO],
-    map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
+    map_location: _MAP_LOCATION_TYPE = None,
     hparams_file: Optional[str] = None,
-    strict: Optional[bool] = None,
+    strict: bool = True,
     **kwargs: Any,
-) -> Any:
+) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
     if map_location is None:
-        map_location = lambda storage, loc: storage
+        map_location = cast(_MAP_LOCATION_TYPE, lambda storage, loc: storage)
     with pl_legacy_patch():
         checkpoint = pl_load(checkpoint_path, map_location=map_location)
@@ -202,15 +201,18 @@ def _load_from_checkpoint(
     if issubclass(cls, pl.LightningDataModule):
         return _load_state(cls, checkpoint, **kwargs)
 
-    return _load_state(cls, checkpoint, strict=strict, **kwargs)
+    # allow cls to be evaluated as subclassed LightningModule or,
+    # as LightningModule for internal tests
+    if issubclass(cls, pl.LightningModule):
+        return _load_state(cls, checkpoint, strict=strict, **kwargs)
 
 
 def _load_state(
-    cls: Union["pl.LightningModule", "pl.LightningDataModule"],
+    cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
     checkpoint: Dict[str, Any],
-    strict: Optional[bool] = None,
+    strict: bool = True,
     **cls_kwargs_new: Any,
-) -> Any:
+) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
     cls_spec = inspect.getfullargspec(cls.__init__)
     cls_init_args_name = inspect.signature(cls.__init__).parameters.keys()
@@ -228,8 +230,7 @@ def _load_state(
         cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))
 
     # 2. Try to restore model hparams from checkpoint using the new key
-    _new_hparam_key = cls.CHECKPOINT_HYPER_PARAMS_KEY
-    cls_kwargs_loaded.update(checkpoint.get(_new_hparam_key))
+    cls_kwargs_loaded.update(checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_KEY, {}))
 
     # 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace
     cls_kwargs_loaded = _convert_loaded_hparams(cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))
@@ -271,7 +272,9 @@ def _load_state(
     return obj
 
 
-def _convert_loaded_hparams(model_args: dict, hparams_type: Optional[Union[Callable, str]] = None) -> object:
+def _convert_loaded_hparams(
+    model_args: Dict[str, Any], hparams_type: Optional[Union[Callable, str]] = None
+) -> Dict[str, Any]:
     """Convert hparams according given type in callable or string (past) format."""
     # if not hparams type define
     if not hparams_type:
diff --git a/src/pytorch_lightning/utilities/cloud_io.py b/src/pytorch_lightning/utilities/cloud_io.py
index ee3358be59541..99629bcda8980 100644
--- a/src/pytorch_lightning/utilities/cloud_io.py
+++ b/src/pytorch_lightning/utilities/cloud_io.py
@@ -15,19 +15,19 @@
 
 import io
 from pathlib import Path
-from typing import Any, Callable, Dict, IO, Optional, Union
+from typing import Any, Dict, IO, Union
 
 import fsspec
 import torch
 from fsspec.core import url_to_fs
 from fsspec.implementations.local import AbstractFileSystem
 
-from pytorch_lightning.utilities.types import _DEVICE, _PATH
+from pytorch_lightning.utilities.types import _MAP_LOCATION_TYPE, _PATH
 
 
 def load(
     path_or_url: Union[IO, _PATH],
-    map_location: Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]] = None,
+    map_location: _MAP_LOCATION_TYPE = None,
 ) -> Any:
     """Loads a checkpoint.
diff --git a/src/pytorch_lightning/utilities/parsing.py b/src/pytorch_lightning/utilities/parsing.py
index 9f5fe2d6b6841..81877f1dffba7 100644
--- a/src/pytorch_lightning/utilities/parsing.py
+++ b/src/pytorch_lightning/utilities/parsing.py
@@ -108,7 +108,9 @@ def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
             del hparams_dict[k]
 
 
-def parse_class_init_keys(cls: Type["pl.LightningModule"]) -> Tuple[str, Optional[str], Optional[str]]:
+def parse_class_init_keys(
+    cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]]
+) -> Tuple[str, Optional[str], Optional[str]]:
     """Parse key words for standard ``self``, ``*args`` and ``**kwargs``.
 
     Examples:
diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py
index f6c14d366805f..18e2db6feb6c6 100644
--- a/src/pytorch_lightning/utilities/types.py
+++ b/src/pytorch_lightning/utilities/types.py
@@ -19,7 +19,7 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union
 
 import torch
 from torch import Tensor
@@ -49,6 +49,7 @@
 ]
 EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
 _DEVICE = Union[torch.device, str, int]
+_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]]
 
 
 @runtime_checkable
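
For illustration, not part of the patch: a minimal sketch of the `map_location` forms admitted by the new `_MAP_LOCATION_TYPE` alias when calling `load_from_checkpoint`. `DemoModel` and `"example.ckpt"` are hypothetical placeholders. Note that the alias only types single-argument callables (`Callable[[_DEVICE], _DEVICE]`), while `torch.load` also accepts the two-argument `(storage, location)` form, which is why `_load_from_checkpoint` has to `cast` its default identity lambda.

import torch

import pytorch_lightning as pl


class DemoModel(pl.LightningModule):  # hypothetical module, for demonstration only
    def __init__(self, hidden_dim: int = 32):
        super().__init__()
        self.save_hyperparameters()  # stored under "hyper_parameters" in the checkpoint
        self.layer = torch.nn.Linear(hidden_dim, 1)


# Each value below is valid under _MAP_LOCATION_TYPE:
m = DemoModel.load_from_checkpoint("example.ckpt", map_location="cpu")  # str (_DEVICE)
m = DemoModel.load_from_checkpoint("example.ckpt", map_location=torch.device("cpu"))  # torch.device (_DEVICE)
m = DemoModel.load_from_checkpoint("example.ckpt", map_location={"cuda:1": "cuda:0"})  # Dict[_DEVICE, _DEVICE]
m = DemoModel.load_from_checkpoint("example.ckpt", map_location=None)  # Optional[...]; falls back to the cast identity lambda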