From 5f0f29109aca877f595db7b9a889e3cdf9d76b19 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 8 Aug 2022 21:54:08 +0200 Subject: [PATCH 1/3] Update type definition for map_location --- src/pytorch_lightning/core/saving.py | 11 ++++++++--- src/pytorch_lightning/utilities/cloud_io.py | 4 ++-- src/pytorch_lightning/utilities/types.py | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py index ffdc0988a1a6e..1271e39680bcf 100644 --- a/src/pytorch_lightning/core/saving.py +++ b/src/pytorch_lightning/core/saving.py @@ -23,6 +23,7 @@ from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union from warnings import warn +import torch import yaml import pytorch_lightning as pl @@ -57,7 +58,7 @@ class ModelIO: def load_from_checkpoint( cls, checkpoint_path: Union[str, IO], - map_location: _MAP_LOCATION_TYPE = None, + map_location: Optional[_MAP_LOCATION_TYPE] = None, hparams_file: Optional[str] = None, strict: bool = True, **kwargs: Any, @@ -172,13 +173,13 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None: def _load_from_checkpoint( cls: Union[Type["ModelIO"], Type["pl.LightningModule"], Type["pl.LightningDataModule"]], checkpoint_path: Union[str, IO], - map_location: _MAP_LOCATION_TYPE = None, + map_location: Optional[_MAP_LOCATION_TYPE] = None, hparams_file: Optional[str] = None, strict: bool = True, **kwargs: Any, ) -> Union["pl.LightningModule", "pl.LightningDataModule"]: if map_location is None: - map_location = cast(_MAP_LOCATION_TYPE, lambda storage, loc: storage) + map_location = _default_map_location with pl_legacy_patch(): checkpoint = pl_load(checkpoint_path, map_location=map_location) @@ -444,3 +445,7 @@ def convert(val: str) -> Union[int, float, bool, str]: except (ValueError, SyntaxError) as err: log.debug(err) return val + + +def _default_map_location(storage: torch.StorageBase, _: str) -> torch.StorageBase: + return storage diff --git a/src/pytorch_lightning/utilities/cloud_io.py b/src/pytorch_lightning/utilities/cloud_io.py index 99629bcda8980..63b49323396c5 100644 --- a/src/pytorch_lightning/utilities/cloud_io.py +++ b/src/pytorch_lightning/utilities/cloud_io.py @@ -15,7 +15,7 @@ import io from pathlib import Path -from typing import Any, Dict, IO, Union +from typing import Any, Dict, IO, Union, Optional import fsspec import torch @@ -27,7 +27,7 @@ def load( path_or_url: Union[IO, _PATH], - map_location: _MAP_LOCATION_TYPE = None, + map_location: Optional[_MAP_LOCATION_TYPE] = None, ) -> Any: """Loads a checkpoint. diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index 18e2db6feb6c6..8d34bbf258cc1 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -49,7 +49,7 @@ ] EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] _DEVICE = Union[torch.device, str, int] -_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]] +_MAP_LOCATION_TYPE = Union[_DEVICE, Callable[[torch.StorageBase, str], torch.StorageBase], Dict[_DEVICE, _DEVICE]] @runtime_checkable From 6d8433397877579ca300914a86edf62ff96b3457 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Aug 2022 20:30:16 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/utilities/cloud_io.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/cloud_io.py b/src/pytorch_lightning/utilities/cloud_io.py index 63b49323396c5..ddcb054b35c50 100644 --- a/src/pytorch_lightning/utilities/cloud_io.py +++ b/src/pytorch_lightning/utilities/cloud_io.py @@ -15,7 +15,7 @@ import io from pathlib import Path -from typing import Any, Dict, IO, Union, Optional +from typing import Any, Dict, IO, Optional, Union import fsspec import torch From ba681b78c97b16c7a3396a0244346b83ad9a6f0a Mon Sep 17 00:00:00 2001 From: awaelchli Date: Mon, 8 Aug 2022 22:39:53 +0200 Subject: [PATCH 3/3] update --- src/pytorch_lightning/core/saving.py | 4 ++-- src/pytorch_lightning/utilities/types.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/core/saving.py b/src/pytorch_lightning/core/saving.py index 1271e39680bcf..0b43bd51d6d2f 100644 --- a/src/pytorch_lightning/core/saving.py +++ b/src/pytorch_lightning/core/saving.py @@ -20,7 +20,7 @@ from argparse import Namespace from copy import deepcopy from enum import Enum -from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union +from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Type, Union from warnings import warn import torch @@ -447,5 +447,5 @@ def convert(val: str) -> Union[int, float, bool, str]: return val -def _default_map_location(storage: torch.StorageBase, _: str) -> torch.StorageBase: +def _default_map_location(storage: torch._StorageBase, _: str) -> torch._StorageBase: return storage diff --git a/src/pytorch_lightning/utilities/types.py b/src/pytorch_lightning/utilities/types.py index 8d34bbf258cc1..9de416c4a8045 100644 --- a/src/pytorch_lightning/utilities/types.py +++ b/src/pytorch_lightning/utilities/types.py @@ -49,7 +49,7 @@ ] EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]] _DEVICE = Union[torch.device, str, int] -_MAP_LOCATION_TYPE = Union[_DEVICE, Callable[[torch.StorageBase, str], torch.StorageBase], Dict[_DEVICE, _DEVICE]] +_MAP_LOCATION_TYPE = Union[_DEVICE, Callable[[torch._StorageBase, str], torch._StorageBase], Dict[_DEVICE, _DEVICE]] @runtime_checkable