diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py
index ffdc0988a1a6e..0b43bd51d6d2f 100644
--- a/src/pytorch_lightning/core/saving.py
+++ b/src/pytorch_lightning/core/saving.py
@@ -20,9 +20,10 @@
 from argparse import Namespace
 from copy import deepcopy
 from enum import Enum
-from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union
+from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Type, Union
 from warnings import warn
 
+import torch
 import yaml
 
 import pytorch_lightning as pl
@@ -57,7 +58,7 @@ class ModelIO:
     def load_from_checkpoint(
         cls,
         checkpoint_path: Union[str, IO],
-        map_location: _MAP_LOCATION_TYPE = None,
+        map_location: Optional[_MAP_LOCATION_TYPE] = None,
         hparams_file: Optional[str] = None,
         strict: bool = True,
         **kwargs: Any,
@@ -172,13 +173,13 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
 def _load_from_checkpoint(
     cls: Union[Type["ModelIO"], Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
     checkpoint_path: Union[str, IO],
-    map_location: _MAP_LOCATION_TYPE = None,
+    map_location: Optional[_MAP_LOCATION_TYPE] = None,
     hparams_file: Optional[str] = None,
     strict: bool = True,
     **kwargs: Any,
 ) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
     if map_location is None:
-        map_location = cast(_MAP_LOCATION_TYPE, lambda storage, loc: storage)
+        map_location = _default_map_location
 
     with pl_legacy_patch():
         checkpoint = pl_load(checkpoint_path, map_location=map_location)
@@ -444,3 +445,7 @@ def convert(val: str) -> Union[int, float, bool, str]:
         except (ValueError, SyntaxError) as err:
             log.debug(err)
             return val
+
+
+def _default_map_location(storage: torch._StorageBase, _: str) -> torch._StorageBase:
+    return storage
diff --git a/src/pytorch_lightning/utilities/cloud_io.py b/src/pytorch_lightning/utilities/cloud_io.py
index 99629bcda8980..ddcb054b35c50 100644
--- a/src/pytorch_lightning/utilities/cloud_io.py
+++ b/src/pytorch_lightning/utilities/cloud_io.py
@@ -15,7 +15,7 @@
 
 import io
 from pathlib import Path
-from typing import Any, Dict, IO, Union
+from typing import Any, Dict, IO, Optional, Union
 
 import fsspec
 import torch
@@ -27,7 +27,7 @@
 
 def load(
     path_or_url: Union[IO, _PATH],
-    map_location: _MAP_LOCATION_TYPE = None,
+    map_location: Optional[_MAP_LOCATION_TYPE] = None,
 ) -> Any:
     """Loads a checkpoint.
 
diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py
index 18e2db6feb6c6..9de416c4a8045 100644
--- a/src/pytorch_lightning/utilities/types.py
+++ b/src/pytorch_lightning/utilities/types.py
@@ -49,7 +49,7 @@
 ]
 EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
 _DEVICE = Union[torch.device, str, int]
-_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]]
+_MAP_LOCATION_TYPE = Union[_DEVICE, Callable[[torch._StorageBase, str], torch._StorageBase], Dict[_DEVICE, _DEVICE]]
 
 
 @runtime_checkable
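
For context, a minimal sketch (not part of the diff) of the callable shape the reworked _MAP_LOCATION_TYPE describes: torch.load calls a map_location callable once per serialized storage with the CPU-deserialized storage and its original location string, and uses whatever storage the callable returns. Returning the storage unchanged, as _default_map_location does above, keeps everything on CPU. The function name keep_on_cpu and the file name example.ckpt are placeholders for this illustration only.

import torch


def keep_on_cpu(storage, location):
    # Same behaviour as _default_map_location in the diff: return the storage
    # unchanged, so tensors stay on CPU regardless of where they were saved.
    return storage


torch.save({"weight": torch.ones(2, 2)}, "example.ckpt")
checkpoint = torch.load("example.ckpt", map_location=keep_on_cpu)
print(checkpoint["weight"].device)  # cpu

Typing the callable as (torch._StorageBase, str) -> torch._StorageBase, rather than the old (_DEVICE) -> _DEVICE, matches this torch.load contract, and replacing the cast lambda with the module-level _default_map_location gives the default a real, checkable signature.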