Skip to content

Commit 8995503

Browse files
committed
clean
1 parent 907db26 commit 8995503

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

src/pytorch_lightning/core/saving.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
log = logging.getLogger(__name__)
4040
PRIMITIVE_TYPES = (bool, int, float, str)
4141
ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
42+
MAP_LOCATION_TYPE = Optional[
43+
Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
44+
]
4245

4346
if _OMEGACONF_AVAILABLE:
4447
from omegaconf import OmegaConf
@@ -175,9 +178,7 @@ def on_hpc_load(self, checkpoint: Dict[str, Any]) -> None:
175178
def _load_from_checkpoint(
176179
cls: Union[Type["ModelIO"], Type["pl.LightningModule"], Type["pl.LightningDataModule"]],
177180
checkpoint_path: Union[str, IO],
178-
map_location: Optional[
179-
Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
180-
] = None,
181+
map_location: MAP_LOCATION_TYPE = None,
181182
hparams_file: Optional[str] = None,
182183
strict: Optional[bool] = None,
183184
**kwargs: Any,

0 commit comments

Comments
 (0)