2 changes: 1 addition & 1 deletion .github/workflows/code-checks.yml
@@ -32,7 +32,7 @@ jobs:

       - name: Install dependencies
         run: |
-          pip install torch==1.11 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
+          pip install torch==1.12 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
           python ./requirements/pytorch/adjust-versions.py requirements/pytorch/extra.txt
           pip install -r requirements/pytorch/devel.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
           pip list
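Note: the CI bump to torch 1.12 is what lets the `_TORCH_GREATER_EQUAL_1_12` guards in the FSDP files below take the import branch. For context, a flag like this is typically derived from the installed torch version; the snippet is a minimal sketch of that idea, not Lightning's actual helper:

```python
# Minimal sketch of a torch version guard (assumption: illustrative only;
# Lightning computes its _TORCH_GREATER_EQUAL_1_12 flag via its own utilities).
import torch
from packaging.version import Version

# Strip any local build tag such as "+cpu" before comparing.
_TORCH_GREATER_EQUAL_1_12 = Version(torch.__version__.split("+")[0]) >= Version("1.12.0")
```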
1 change: 0 additions & 1 deletion pyproject.toml
@@ -52,7 +52,6 @@ module = [
"pytorch_lightning.callbacks.progress.rich_progress",
"pytorch_lightning.callbacks.quantization",
"pytorch_lightning.core.datamodule",
"pytorch_lightning.core.decorators",
"pytorch_lightning.core.module",
"pytorch_lightning.core.saving",
"pytorch_lightning.demos.boring_classes",
2 changes: 1 addition & 1 deletion src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
@@ -23,7 +23,7 @@
 if _TORCH_GREATER_EQUAL_1_12:
     from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
 else:
-    MixedPrecision = None
+    MixedPrecision = None  # type: ignore[misc,assignment]
 
 
 class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
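Why the ignore is needed, in a simplified sketch (an assumption-level illustration, not the module's exact code): mypy binds `MixedPrecision` to the class from the import branch, so re-binding the same name to `None` in the fallback triggers `misc` and `assignment` errors under strict checking. The same fix appears in `fully_sharded_native.py` below.

```python
# Simplified sketch of the guarded-import fallback pattern.
HAS_FSDP = False  # stand-in for the torch >= 1.12 check

if HAS_FSDP:
    from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
else:
    # mypy still type-checks this branch: `MixedPrecision` is a class in the
    # branch above, so assigning None to it needs the explicit ignore.
    MixedPrecision = None  # type: ignore[misc,assignment]
```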
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -51,7 +51,7 @@
     )
     from torch.distributed.fsdp.wrap import enable_wrap
 else:
-    MixedPrecision = None
+    MixedPrecision = None  # type: ignore[misc,assignment]
     BackwardPrefetch = None  # type: ignore[misc,assignment]
     CPUOffload = None  # type: ignore[misc,assignment]
 
2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/launchers/multiprocessing.py
@@ -144,7 +144,7 @@ def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", train
         # load last weights
         if worker_output.weights_path is not None:
             ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path)
-            trainer.lightning_module.load_state_dict(ckpt)  # type: ignore[arg-type]
+            trainer.lightning_module.load_state_dict(ckpt)
             self._strategy.checkpoint_io.remove_checkpoint(worker_output.weights_path)
 
         trainer.state = worker_output.trainer_state
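The `type: ignore[arg-type]` can be dropped here, presumably because the checkpoint returned by `checkpoint_io.load_checkpoint` is now typed as a plain state dict that `load_state_dict` accepts. A minimal round-trip sketch (not part of the PR), assuming `TorchCheckpointIO` as the strategy's `checkpoint_io`, the `BoringModel` demo class, and "weights.ckpt" as an illustrative path:

```python
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.plugins.io import TorchCheckpointIO

model = BoringModel()
io = TorchCheckpointIO()

# Save the raw state dict, then load it back and restore it without a cast.
io.save_checkpoint(model.state_dict(), "weights.ckpt")
ckpt = io.load_checkpoint("weights.ckpt")
model.load_state_dict(ckpt)  # no type: ignore required
io.remove_checkpoint("weights.ckpt")
```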
11 changes: 6 additions & 5 deletions src/pytorch_lightning/utilities/cloud_io.py
@@ -22,14 +22,12 @@
 from fsspec.core import url_to_fs
 from fsspec.implementations.local import AbstractFileSystem
 
-from pytorch_lightning.utilities.types import _PATH
+from pytorch_lightning.utilities.types import _DEVICE, _PATH
 
 
 def load(
     path_or_url: Union[IO, _PATH],
-    map_location: Optional[
-        Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
-    ] = None,
+    map_location: Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]] = None,
 ) -> Any:
     """Loads a checkpoint.
 
@@ -41,7 +39,10 @@ def load(
         # any sort of BytesIO or similar
         return torch.load(path_or_url, map_location=map_location)
     if str(path_or_url).startswith("http"):
-        return torch.hub.load_state_dict_from_url(str(path_or_url), map_location=map_location)
+        return torch.hub.load_state_dict_from_url(
+            str(path_or_url),
+            map_location=map_location,  # type: ignore[arg-type] # upstream annotation is not correct
+        )
     fs = get_filesystem(path_or_url)
     with fs.open(path_or_url, "rb") as f:
         return torch.load(f, map_location=map_location)
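The tightened annotation spells out the accepted `map_location` forms: a `_DEVICE` (device string, `torch.device`, or device index), a device-mapping callable, or a device-to-device dict. A hedged usage sketch, assuming "weights.ckpt" is an existing checkpoint file:

```python
import torch
from pytorch_lightning.utilities.cloud_io import load

# Each form below matches one member of the new map_location Union.
ckpt = load("weights.ckpt", map_location="cpu")                # device string
ckpt = load("weights.ckpt", map_location=torch.device("cpu"))  # torch.device
ckpt = load("weights.ckpt", map_location={"cuda:0": "cpu"})    # remap dict
```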