Merged
Changes from all 33 commits:
63ed39a  remove from pyproject.toml (jxtngx, Jul 29, 2022)
e113c6a  Merge branch 'master' into codeq/core-saving (jxtngx, Aug 1, 2022)
5d5bbda  update (jxtngx, Aug 1, 2022)
907db26  Merge branch 'master' into codeq/core-saving (jxtngx, Aug 1, 2022)
8995503  clean (jxtngx, Aug 1, 2022)
28e01a0  attempt to resolve circular import (jxtngx, Aug 1, 2022)
69831a3  undo 28e01a0 (jxtngx, Aug 1, 2022)
cf45c80  Merge branch 'master' into codeq/core-saving (jxtngx, Aug 1, 2022)
62b004d  update (jxtngx, Aug 1, 2022)
ba716c2  Merge remote-tracking branch 'origin/codeq/core-saving' into codeq/co… (jxtngx, Aug 1, 2022)
5fc09be  update for CI/CD pytest (jxtngx, Aug 1, 2022)
3e248db  update (jxtngx, Aug 1, 2022)
ea0d812  update for pytest (jxtngx, Aug 1, 2022)
b584dde  Merge branch 'master' into codeq/core-saving (jxtngx, Aug 1, 2022)
9c9a86f  Merge branch 'master' into codeq/core-saving (jxtngx, Aug 2, 2022)
3ca0ca4  Merge branch 'master' into codeq/core-saving (jxtngx, Aug 3, 2022)
55bb80c  Update src/pytorch_lightning/core/saving.py (jxtngx, Aug 4, 2022)
99643a0  update (jxtngx, Aug 4, 2022)
1371b04  Merge branch 'master' into codeq/core-saving (jxtngx, Aug 4, 2022)
9735d21  update (jxtngx, Aug 4, 2022)
7cd5da9  Merge remote-tracking branch 'origin/codeq/core-saving' into codeq/co… (jxtngx, Aug 4, 2022)
855679f  update (jxtngx, Aug 4, 2022)
02292bc  Merge branch 'master' into codeq/core-saving (jxtngx, Aug 5, 2022)
0502a4e  Merge branch 'master' into codeq/core-saving (otaj, Aug 8, 2022)
82fc452  Merge branch 'master' into codeq/core-saving (jxtngx, Aug 8, 2022)
ad469c9  update (jxtngx, Aug 8, 2022)
d912ac1  explicit Dict annotation (jxtngx, Aug 8, 2022)
ca47708  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 8, 2022)
f105d7f  add _MAP_LOCATION_TYPES to types (jxtngx, Aug 8, 2022)
aebb404  commit suggestion (jxtngx, Aug 8, 2022)
fec1733  reverting last commit (jxtngx, Aug 8, 2022)
148bb26  update (jxtngx, Aug 8, 2022)
b4cc603  commit suggestion to check against ci/cd tests (jxtngx, Aug 8, 2022)
1 change: 0 additions & 1 deletion pyproject.toml
@@ -53,7 +53,6 @@ module = [
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.core.datamodule",
     "pytorch_lightning.core.module",
-    "pytorch_lightning.core.saving",
     "pytorch_lightning.demos.boring_classes",
     "pytorch_lightning.demos.mnist_datamodule",
     "pytorch_lightning.profilers.base",
39 changes: 21 additions & 18 deletions src/pytorch_lightning/core/saving.py
@@ -20,10 +20,9 @@
 from argparse import Namespace
 from copy import deepcopy
 from enum import Enum
-from typing import Any, Callable, Dict, IO, MutableMapping, Optional, Union
+from typing import Any, Callable, cast, Dict, IO, MutableMapping, Optional, Type, Union
 from warnings import warn

-import torch
 import yaml

 import pytorch_lightning as pl
@@ -34,7 +33,7 @@
 from pytorch_lightning.utilities.migration import pl_legacy_patch
 from pytorch_lightning.utilities.parsing import parse_class_init_keys
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
-from pytorch_lightning.utilities.types import _PATH
+from pytorch_lightning.utilities.types import _MAP_LOCATION_TYPE, _PATH

 log = logging.getLogger(__name__)
 PRIMITIVE_TYPES = (bool, int, float, str)
@@ -58,11 +57,11 @@ class ModelIO:
     def load_from_checkpoint(
         cls,
         checkpoint_path: Union[str, IO],
-        map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
+        map_location: _MAP_LOCATION_TYPE = None,
         hparams_file: Optional[str] = None,
         strict: bool = True,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
         r"""
         Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
         it stores the arguments passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``.
@@ -171,15 +170,15 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:


 def _load_from_checkpoint(
-    cls: Union["pl.LightningModule", "pl.LightningDataModule"],
+    cls: Union[Type["ModelIO"], Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
     checkpoint_path: Union[str, IO],
-    map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
+    map_location: _MAP_LOCATION_TYPE = None,
     hparams_file: Optional[str] = None,
-    strict: Optional[bool] = None,
+    strict: bool = True,
     **kwargs: Any,
-) -> Any:
+) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
     if map_location is None:
-        map_location = lambda storage, loc: storage
+        map_location = cast(_MAP_LOCATION_TYPE, lambda storage, loc: storage)
     with pl_legacy_patch():
         checkpoint = pl_load(checkpoint_path, map_location=map_location)

@@ -202,15 +201,18 @@ def _load_from_checkpoint(

     if issubclass(cls, pl.LightningDataModule):
         return _load_state(cls, checkpoint, **kwargs)
-    return _load_state(cls, checkpoint, strict=strict, **kwargs)
+    # allow cls to be evaluated as subclassed LightningModule or,
+    # as LightningModule for internal tests
+    if issubclass(cls, pl.LightningModule):
+        return _load_state(cls, checkpoint, strict=strict, **kwargs)


 def _load_state(
-    cls: Union["pl.LightningModule", "pl.LightningDataModule"],
+    cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
     checkpoint: Dict[str, Any],
-    strict: Optional[bool] = None,
+    strict: bool = True,
     **cls_kwargs_new: Any,
-) -> Any:
+) -> Union["pl.LightningModule", "pl.LightningDataModule"]:
     cls_spec = inspect.getfullargspec(cls.__init__)
     cls_init_args_name = inspect.signature(cls.__init__).parameters.keys()

@@ -228,8 +230,7 @@ def _load_state(
     cls_kwargs_loaded.update(checkpoint.get(_old_hparam_key, {}))

     # 2. Try to restore model hparams from checkpoint using the new key
-    _new_hparam_key = cls.CHECKPOINT_HYPER_PARAMS_KEY
-    cls_kwargs_loaded.update(checkpoint.get(_new_hparam_key))
+    cls_kwargs_loaded.update(checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_KEY, {}))

     # 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace
     cls_kwargs_loaded = _convert_loaded_hparams(cls_kwargs_loaded, checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_TYPE))
@@ -271,7 +272,9 @@ def _load_state(
     return obj


-def _convert_loaded_hparams(model_args: dict, hparams_type: Optional[Union[Callable, str]] = None) -> object:
+def _convert_loaded_hparams(
+    model_args: Dict[str, Any], hparams_type: Optional[Union[Callable, str]] = None
+) -> Dict[str, Any]:
     """Convert hparams according given type in callable or string (past) format."""
     # if not hparams type define
     if not hparams_type:
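Note on the saving.py changes above: narrowing map_location to the shared _MAP_LOCATION_TYPE alias does not change what callers may pass to load_from_checkpoint. A minimal usage sketch follows; the LitClassifier subclass and checkpoint file are hypothetical and not part of this PR.

import torch
import pytorch_lightning as pl


class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim: int = 128, learning_rate: float = 1e-3):
        super().__init__()
        # stored in the checkpoint under "hyper_parameters" and restored on load
        self.save_hyperparameters()


# all of these remain valid map_location values under _MAP_LOCATION_TYPE
model = LitClassifier.load_from_checkpoint("example.ckpt", map_location="cpu")
model = LitClassifier.load_from_checkpoint("example.ckpt", map_location=torch.device("cpu"))
model = LitClassifier.load_from_checkpoint("example.ckpt", map_location={"cuda:1": "cpu"})
model = LitClassifier.load_from_checkpoint("example.ckpt", map_location=lambda storage, loc: storage)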
6 changes: 3 additions & 3 deletions src/pytorch_lightning/utilities/cloud_io.py
@@ -15,19 +15,19 @@

 import io
 from pathlib import Path
-from typing import Any, Callable, Dict, IO, Optional, Union
+from typing import Any, Dict, IO, Union

 import fsspec
 import torch
 from fsspec.core import url_to_fs
 from fsspec.implementations.local import AbstractFileSystem

-from pytorch_lightning.utilities.types import _DEVICE, _PATH
+from pytorch_lightning.utilities.types import _MAP_LOCATION_TYPE, _PATH


 def load(
     path_or_url: Union[IO, _PATH],
-    map_location: Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]] = None,
+    map_location: _MAP_LOCATION_TYPE = None,
 ) -> Any:
     """Loads a checkpoint.

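The low-level loader now shares the same alias. A short sketch of how pytorch_lightning.utilities.cloud_io.load is typically called; the paths below are hypothetical, and the function accepts local paths as well as fsspec-style URLs.

from pytorch_lightning.utilities.cloud_io import load

# local checkpoint, remapped to CPU
checkpoint = load("lightning_logs/version_0/checkpoints/last.ckpt", map_location="cpu")

# remote object store resolved through fsspec, using a callable remap
checkpoint = load("s3://my-bucket/checkpoints/last.ckpt", map_location=lambda storage, loc: storage)

# the loaded object is whatever was saved, e.g. a Lightning checkpoint dict
print(sorted(checkpoint))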
4 changes: 3 additions & 1 deletion src/pytorch_lightning/utilities/parsing.py
@@ -108,7 +108,9 @@ def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
         del hparams_dict[k]


-def parse_class_init_keys(cls: Type["pl.LightningModule"]) -> Tuple[str, Optional[str], Optional[str]]:
+def parse_class_init_keys(
+    cls: Union[Type["pl.LightningModule"], Type["pl.LightningDataModule"]]
+) -> Tuple[str, Optional[str], Optional[str]]:
     """Parse key words for standard ``self``, ``*args`` and ``**kwargs``.

     Examples:
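A small sketch of what the widened parse_class_init_keys signature accepts; the datamodule below is hypothetical, and the expected output reflects the function's documented purpose of returning the names used for self, *args and **kwargs.

import pytorch_lightning as pl
from pytorch_lightning.utilities.parsing import parse_class_init_keys


class ExampleDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = ".", *dm_args, **dm_kwargs):
        super().__init__()


# with the Union annotation, a LightningDataModule subclass now type-checks here as well
print(parse_class_init_keys(ExampleDataModule))  # expected: ('self', 'dm_args', 'dm_kwargs')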
3 changes: 2 additions & 1 deletion src/pytorch_lightning/utilities/types.py
@@ -19,7 +19,7 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, Generator, Iterator, List, Mapping, Optional, Sequence, Type, Union

 import torch
 from torch import Tensor
@@ -49,6 +49,7 @@
 ]
 EVAL_DATALOADERS = Union[DataLoader, Sequence[DataLoader]]
 _DEVICE = Union[torch.device, str, int]
+_MAP_LOCATION_TYPE = Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]]


 @runtime_checkable
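As a rough illustration of the new alias (the helper below is hypothetical and not part of the PR), _MAP_LOCATION_TYPE covers the values torch.load itself accepts for map_location.

from typing import Any

import torch

from pytorch_lightning.utilities.types import _MAP_LOCATION_TYPE


def load_weights(path: str, map_location: _MAP_LOCATION_TYPE = None) -> Any:
    # mirrors how the alias is used by cloud_io.load and core/saving.py
    return torch.load(path, map_location=map_location)


# a device string, a torch.device, a remap dict, or a callable all satisfy the annotation
state = load_weights("weights.pt", map_location="cpu")
state = load_weights("weights.pt", map_location=torch.device("cpu"))
state = load_weights("weights.pt", map_location={"cuda:0": "cpu"})
state = load_weights("weights.pt", map_location=lambda storage, loc: storage)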