Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ module = [
"pytorch_lightning.trainer.connectors.logger_connector",
"pytorch_lightning.utilities.argparse",
"pytorch_lightning.utilities.cli",
"pytorch_lightning.utilities.cloud_io",
"pytorch_lightning.utilities.device_dtype_mixin",
"pytorch_lightning.utilities.device_parser",
"pytorch_lightning.utilities.parsing",
Expand Down
15 changes: 10 additions & 5 deletions pytorch_lightning/utilities/cloud_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,20 @@

import io
from pathlib import Path
from typing import IO, Union
from typing import Any, Callable, Dict, IO, Optional, Union

import fsspec
import torch
from fsspec.implementations.local import LocalFileSystem
from fsspec.implementations.local import AbstractFileSystem, LocalFileSystem
from packaging.version import Version


def load(path_or_url: Union[str, IO, Path], map_location=None):
def load(
path_or_url: Union[str, IO, Path],
map_location: Optional[
Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
] = None,
) -> Any:
if not isinstance(path_or_url, (str, Path)):
# any sort of BytesIO or similiar
return torch.load(path_or_url, map_location=map_location)
Expand All @@ -33,7 +38,7 @@ def load(path_or_url: Union[str, IO, Path], map_location=None):
return torch.load(f, map_location=map_location)


def get_filesystem(path: Union[str, Path]):
def get_filesystem(path: Union[str, Path]) -> AbstractFileSystem:
path = str(path)
if "://" in path:
# use the fileystem from the protocol specified
Expand All @@ -42,7 +47,7 @@ def get_filesystem(path: Union[str, Path]):
return LocalFileSystem()


def atomic_save(checkpoint, filepath: str):
def atomic_save(checkpoint: Dict[str, Any], filepath: Union[str, Path]) -> None:
"""Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints.

Args:
Expand Down