Skip to content

Commit 91c64e9

Browse files
Load checkpoint from Bytes (#4314)
* load directly from fs * if not str or path * pep8 * type annotation Co-authored-by: Sean Naren <[email protected]>
1 parent 3abfec8 commit 91c64e9

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

pytorch_lightning/core/saving.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import inspect
1818
import os
1919
from argparse import Namespace
20-
from typing import Union, Dict, Any, Optional, Callable, MutableMapping
20+
from typing import Union, Dict, Any, Optional, Callable, MutableMapping, IO
2121
from warnings import warn
2222

2323
import fsspec
@@ -52,7 +52,7 @@ class ModelIO(object):
5252
@classmethod
5353
def load_from_checkpoint(
5454
cls,
55-
checkpoint_path: str,
55+
checkpoint_path: Union[str, IO],
5656
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
5757
hparams_file: Optional[str] = None,
5858
strict: bool = True,
@@ -65,7 +65,7 @@ def load_from_checkpoint(
6565
Any arguments specified through \*args and \*\*kwargs will override args stored in `hparams`.
6666
6767
Args:
68-
checkpoint_path: Path to checkpoint. This can also be a URL.
68+
checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object
6969
map_location:
7070
If your checkpoint saved a GPU model and you now load on CPUs
7171
or a different number of GPUs, use this to map to the new setup.

pytorch_lightning/utilities/cloud_io.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414

1515
import io
1616
from distutils.version import LooseVersion
17-
from typing import Union
17+
from typing import Union, IO
1818
from pathlib import Path
1919
from urllib.parse import urlparse
2020
import torch
2121
import fsspec
2222

2323

24-
def load(path_or_url: str, map_location=None):
24+
def load(path_or_url: Union[str, IO, Path], map_location=None):
25+
if not isinstance(path_or_url, (str, Path)):
26+
# any sort of BytesIO or similiar
27+
return torch.load(path_or_url, map_location=map_location)
2528
if path_or_url.startswith("http"):
2629
return torch.hub.load_state_dict_from_url(path_or_url, map_location=map_location)
2730
fs = get_filesystem(path_or_url)

0 commit comments

Comments
 (0)