Commit 76836a3

Run mypy with PyTorch 1.12 (#14044)

1 parent 5c05719

File tree

6 files changed (+10, -10 lines)

.github/workflows/code-checks.yml
Lines changed: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ jobs:
 
       - name: Install dependencies
        run: |
-          pip install torch==1.11 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
+          pip install torch==1.12 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
          python ./requirements/pytorch/adjust-versions.py requirements/pytorch/extra.txt
          # todo: adjust requirements for both code-bases
          pip install -r requirements/pytorch/devel.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
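The per-line ignores in the files below are written against the type stubs that ship with torch 1.12, which is why the mypy job pins that exact version. The FSDP modules gate their imports behind a _TORCH_GREATER_EQUAL_1_12 flag; as a rough, hypothetical sketch (Lightning's actual implementation lives in its utilities and may differ), such a guard can be derived from the installed version:

```python
# Hypothetical sketch of a torch version guard; not Lightning's exact code.
import torch
from packaging.version import Version

_TORCH_GREATER_EQUAL_1_12 = Version(torch.__version__) >= Version("1.12.0")
```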

pyproject.toml
Lines changed: 0 additions & 1 deletion

@@ -52,7 +52,6 @@ module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.core.datamodule",
-    "pytorch_lightning.core.decorators",
     "pytorch_lightning.core.module",
     "pytorch_lightning.core.saving",
     "pytorch_lightning.demos.boring_classes",

src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
Lines changed: 1 addition & 1 deletion

@@ -23,7 +23,7 @@
 if _TORCH_GREATER_EQUAL_1_12:
     from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
 else:
-    MixedPrecision = None
+    MixedPrecision = None  # type: ignore[misc,assignment]
 
 
 class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
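Rebinding an imported class to None in the fallback branch is what mypy objects to: it reports "Cannot assign to a type" ([misc]) and an incompatible-assignment error ([assignment]), since None does not match the class type. A minimal self-contained sketch of the pattern (the module and class names here are hypothetical):

```python
# Fallback-import pattern; "somelib" and "SomeClass" are made-up names.
try:
    from somelib import SomeClass
except ImportError:
    # Without the trailing comment, mypy reports:
    #   error: Cannot assign to a type  [misc]
    #   error: Incompatible types in assignment  [assignment]
    SomeClass = None  # type: ignore[misc,assignment]
```

The same two error codes were already silenced on the BackwardPrefetch and CPUOffload fallbacks in fully_sharded_native.py below; this commit brings the two MixedPrecision fallbacks in line with them.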

src/pytorch_lightning/strategies/fully_sharded_native.py
Lines changed: 1 addition & 1 deletion

@@ -51,7 +51,7 @@
     )
     from torch.distributed.fsdp.wrap import enable_wrap
 else:
-    MixedPrecision = None
+    MixedPrecision = None  # type: ignore[misc,assignment]
     BackwardPrefetch = None  # type: ignore[misc,assignment]
     CPUOffload = None  # type: ignore[misc,assignment]
 
src/pytorch_lightning/strategies/launchers/multiprocessing.py
Lines changed: 1 addition & 1 deletion

@@ -144,7 +144,7 @@ def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", train
         # load last weights
         if worker_output.weights_path is not None:
             ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path)
-            trainer.lightning_module.load_state_dict(ckpt)  # type: ignore[arg-type]
+            trainer.lightning_module.load_state_dict(ckpt)
             self._strategy.checkpoint_io.remove_checkpoint(worker_output.weights_path)
 
         trainer.state = worker_output.trainer_state
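The ignore dropped here was presumably made redundant by torch 1.12 widening the stub for Module.load_state_dict from OrderedDict[str, Tensor] to Mapping[str, Any]; with warn_unused_ignores enabled, a stale ignore is itself an error, so it has to go. A rough reproduction of the situation, under that assumption:

```python
# Sketch: a Dict[str, Any] checkpoint passed to load_state_dict.
# On torch >= 1.12 stubs this type-checks as-is; on 1.11 it needed
# a `# type: ignore[arg-type]` because the parameter was annotated
# as OrderedDict[str, Tensor].
from typing import Any, Dict

import torch

model = torch.nn.Linear(2, 2)
ckpt: Dict[str, Any] = dict(model.state_dict())
model.load_state_dict(ckpt)
```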

src/pytorch_lightning/utilities/cloud_io.py
Lines changed: 6 additions & 5 deletions

@@ -22,14 +22,12 @@
 from fsspec.core import url_to_fs
 from fsspec.implementations.local import AbstractFileSystem
 
-from pytorch_lightning.utilities.types import _PATH
+from pytorch_lightning.utilities.types import _DEVICE, _PATH
 
 
 def load(
     path_or_url: Union[IO, _PATH],
-    map_location: Optional[
-        Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
-    ] = None,
+    map_location: Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]] = None,
 ) -> Any:
     """Loads a checkpoint.
 
@@ -41,7 +39,10 @@ def load(
         # any sort of BytesIO or similar
         return torch.load(path_or_url, map_location=map_location)
     if str(path_or_url).startswith("http"):
-        return torch.hub.load_state_dict_from_url(str(path_or_url), map_location=map_location)
+        return torch.hub.load_state_dict_from_url(
+            str(path_or_url),
+            map_location=map_location,  # type: ignore[arg-type] # upstream annotation is not correct
+        )
     fs = get_filesystem(path_or_url)
     with fs.open(path_or_url, "rb") as f:
         return torch.load(f, map_location=map_location)
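The tightened signature relies on a _DEVICE alias now imported next to _PATH. Its definition is not part of this diff; assuming it mirrors the forms torch.device itself accepts, it plausibly looks like:

```python
# Presumed shape of the alias (check pytorch_lightning/utilities/types.py):
from typing import Union

import torch

_DEVICE = Union[torch.device, str, int]
```

Compared to the old nested Union, the alias shortens the annotation and replaces the bare Callable with an explicit Callable[[_DEVICE], _DEVICE]. The remaining mismatch with torch.hub.load_state_dict_from_url is silenced with a per-argument ignore placed on the map_location line, so only that argument's arg-type check is suppressed and the reason is documented inline.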
