Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/pytorch_lightning/core/saving.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
from argparse import Namespace
from copy import deepcopy
from enum import Enum
from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union
from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Type, Union
from warnings import warn

import torch
import yaml

import pytorch_lightning as pl
Expand Down Expand Up @@ -57,7 +58,7 @@ class ModelIO:
def load_from_checkpoint(
cls,
checkpoint_path: Union[str, IO],
map_location: _MAP_LOCATION_TYPE = None,
map_location: Optional[_MAP_LOCATION_TYPE] = None,
hparams_file: Optional[str] = None,
strict: bool = True,
**kwargs: Any,
Expand Down Expand Up @@ -172,13 +173,13 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
def _load_from_checkpoint(
cls: Union[Type["ModelIO"], Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
checkpoint_path: Union[str, IO],
map_location: _MAP_LOCATION_TYPE = None,
map_location: Optional[_MAP_LOCATION_TYPE] = None,
hparams_file: Optional[str] = None,
strict: bool = True,
**kwargs: Any,
) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
if map_location is None:
map_location = cast(_MAP_LOCATION_TYPE, lambda storage, loc: storage)
map_location = _default_map_location
with pl_legacy_patch():
checkpoint = pl_load(checkpoint_path, map_location=map_location)

Expand Down Expand Up @@ -444,3 +445,7 @@ def convert(val: str) -> Union[int, float, bool, str]:
except (ValueError, SyntaxError) as err:
log.debug(err)
return val


def _default_map_location(storage: torch._StorageBase, _: str) -> torch._StorageBase:
return storage
4 changes: 2 additions & 2 deletions src/pytorch_lightning/utilities/cloud_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import io
from pathlib import Path
from typing import Any, Dict, IO, Union
from typing import Any, Dict, IO, Optional, Union

import fsspec
import torch
Expand All @@ -27,7 +27,7 @@

def load(
path_or_url: Union[IO, _PATH],
map_location: _MAP_LOCATION_TYPE = None,
map_location: Optional[_MAP_LOCATION_TYPE] = None,
) -> Any:
"""Loads a checkpoint.

Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_lightning/utilities/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
]
EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
_DEVICE = Union[torch.device, str, int]
_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]]
_MAP_LOCATION_TYPE = Union[_DEVICE, Callable[[torch._StorageBase, str], torch._StorageBase], Dict[_DEVICE, _DEVICE]]


@runtime_checkable
Expand Down