2020from argparse import Namespace
2121from copy import deepcopy
2222from enum import Enum
23- from typing import Any , Callable , Dict , IO , MutableMapping , Optional , Union
23+ from typing import Any , Callable , cast , Dict , IO , MutableMapping , Optional , Type , Union
2424from warnings import warn
2525
26- import torch
2726import yaml
2827
2928import pytorch_lightning as pl
3433from pytorch_lightning .utilities .migration import pl_legacy_patch
3534from pytorch_lightning .utilities .parsing import parse_class_init_keys
3635from pytorch_lightning .utilities .rank_zero import rank_zero_warn
37- from pytorch_lightning .utilities .types import _PATH
36+ from pytorch_lightning .utilities .types import _MAP_LOCATION_TYPE , _PATH
3837
3938log = logging .getLogger (__name__ )
4039PRIMITIVE_TYPES = (bool , int , float , str )
@@ -58,11 +57,11 @@ class ModelIO:
5857 def load_from_checkpoint (
5958 cls ,
6059 checkpoint_path : Union [str , IO ],
61- map_location : Optional [ Union [ Dict [ str , str ], str , torch . device , int , Callable ]] = None ,
60+ map_location : _MAP_LOCATION_TYPE = None ,
6261 hparams_file : Optional [str ] = None ,
6362 strict : bool = True ,
64- ** kwargs ,
65- ):
63+ ** kwargs : Any ,
64+ ) -> Union [ "pl.LightningModule" , "pl.LightningDataModule" ] :
6665 r"""
6766 Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint
6867 it stores the arguments passed to ``__init__`` in the checkpoint under ``"hyper_parameters"``.
@@ -171,15 +170,15 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
171170
172171
173172def _load_from_checkpoint (
174- cls : Union [" pl.LightningModule", "pl.LightningDataModule" ],
173+ cls : Union [Type [ "ModelIO" ], Type [ " pl.LightningModule"], Type [ "pl.LightningDataModule" ] ],
175174 checkpoint_path : Union [str , IO ],
176- map_location : Optional [ Union [ Dict [ str , str ], str , torch . device , int , Callable ]] = None ,
175+ map_location : _MAP_LOCATION_TYPE = None ,
177176 hparams_file : Optional [str ] = None ,
178- strict : Optional [ bool ] = None ,
177+ strict : bool = True ,
179178 ** kwargs : Any ,
180- ) -> Any :
179+ ) -> Union [ "pl.LightningModule" , "pl.LightningDataModule" ] :
181180 if map_location is None :
182- map_location = lambda storage , loc : storage
181+ map_location = cast ( _MAP_LOCATION_TYPE , lambda storage , loc : storage )
183182 with pl_legacy_patch ():
184183 checkpoint = pl_load (checkpoint_path , map_location = map_location )
185184
@@ -202,15 +201,18 @@ def _load_from_checkpoint(
202201
203202 if issubclass (cls , pl .LightningDataModule ):
204203 return _load_state (cls , checkpoint , ** kwargs )
205- return _load_state (cls , checkpoint , strict = strict , ** kwargs )
204+ # allow cls to be evaluated as subclassed LightningModule or,
205+ # as LightningModule for internal tests
206+ if issubclass (cls , pl .LightningModule ):
207+ return _load_state (cls , checkpoint , strict = strict , ** kwargs )
206208
207209
208210def _load_state (
209- cls : Union ["pl.LightningModule" , "pl.LightningDataModule" ],
211+ cls : Union [Type [ "pl.LightningModule" ], Type [ "pl.LightningDataModule" ] ],
210212 checkpoint : Dict [str , Any ],
211- strict : Optional [ bool ] = None ,
213+ strict : bool = True ,
212214 ** cls_kwargs_new : Any ,
213- ) -> Any :
215+ ) -> Union [ "pl.LightningModule" , "pl.LightningDataModule" ] :
214216 cls_spec = inspect .getfullargspec (cls .__init__ )
215217 cls_init_args_name = inspect .signature (cls .__init__ ).parameters .keys ()
216218
@@ -228,8 +230,7 @@ def _load_state(
228230 cls_kwargs_loaded .update (checkpoint .get (_old_hparam_key , {}))
229231
230232 # 2. Try to restore model hparams from checkpoint using the new key
231- _new_hparam_key = cls .CHECKPOINT_HYPER_PARAMS_KEY
232- cls_kwargs_loaded .update (checkpoint .get (_new_hparam_key ))
233+ cls_kwargs_loaded .update (checkpoint .get (cls .CHECKPOINT_HYPER_PARAMS_KEY , {}))
233234
234235 # 3. Ensure that `cls_kwargs_old` has the right type, back compatibility between dict and Namespace
235236 cls_kwargs_loaded = _convert_loaded_hparams (cls_kwargs_loaded , checkpoint .get (cls .CHECKPOINT_HYPER_PARAMS_TYPE ))
@@ -271,7 +272,9 @@ def _load_state(
271272 return obj
272273
273274
274- def _convert_loaded_hparams (model_args : dict , hparams_type : Optional [Union [Callable , str ]] = None ) -> object :
275+ def _convert_loaded_hparams (
276+ model_args : Dict [str , Any ], hparams_type : Optional [Union [Callable , str ]] = None
277+ ) -> Dict [str , Any ]:
275278 """Convert hparams according given type in callable or string (past) format."""
276279 # if not hparams type define
277280 if not hparams_type :
0 commit comments