From 50fe3acd5707c7313eb79bbfa8f0642e6b9b005e Mon Sep 17 00:00:00 2001
From: Carlos Mocholí
Date: Fri, 5 Aug 2022 14:33:55 +0200
Subject: [PATCH] Run mypy with PyTorch 1.12

---
 .github/workflows/code-checks.yml                 |  2 +-
 pyproject.toml                                    |  1 -
 .../plugins/precision/fully_sharded_native_amp.py |  2 +-
 .../strategies/fully_sharded_native.py            |  2 +-
 .../strategies/launchers/multiprocessing.py       |  2 +-
 src/pytorch_lightning/utilities/cloud_io.py       | 11 ++++++-----
 6 files changed, 10 insertions(+), 10 deletions(-)

diff --git a/.github/workflows/code-checks.yml b/.github/workflows/code-checks.yml
index ed9cd46adbe44..d1948daedf1ae 100644
--- a/.github/workflows/code-checks.yml
+++ b/.github/workflows/code-checks.yml
@@ -32,7 +32,7 @@ jobs:
       - name: Install dependencies
         run: |
-          pip install torch==1.11 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
+          pip install torch==1.12 --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
           python ./requirements/pytorch/adjust-versions.py requirements/pytorch/extra.txt
           pip install -r requirements/pytorch/devel.txt --find-links https://download.pytorch.org/whl/cpu/torch_stable.html
           pip list
 
diff --git a/pyproject.toml b/pyproject.toml
index 5473e73c52e19..9b8400ba27577 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -52,7 +52,6 @@ module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.callbacks.quantization",
     "pytorch_lightning.core.datamodule",
-    "pytorch_lightning.core.decorators",
     "pytorch_lightning.core.module",
     "pytorch_lightning.core.saving",
     "pytorch_lightning.demos.boring_classes",
diff --git a/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py b/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
index 8c693f2975bbd..60e53b880c84d 100644
--- a/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
+++ b/src/pytorch_lightning/plugins/precision/fully_sharded_native_amp.py
@@ -23,7 +23,7 @@
 if _TORCH_GREATER_EQUAL_1_12:
     from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
 else:
-    MixedPrecision = None
+    MixedPrecision = None  # type: ignore[misc,assignment]
 
 
 class FullyShardedNativeMixedPrecisionPlugin(ShardedNativeMixedPrecisionPlugin):
diff --git a/src/pytorch_lightning/strategies/fully_sharded_native.py b/src/pytorch_lightning/strategies/fully_sharded_native.py
index 4c351f26fa3b9..d92931fb5cdb2 100644
--- a/src/pytorch_lightning/strategies/fully_sharded_native.py
+++ b/src/pytorch_lightning/strategies/fully_sharded_native.py
@@ -51,7 +51,7 @@
     )
     from torch.distributed.fsdp.wrap import enable_wrap
 else:
-    MixedPrecision = None
+    MixedPrecision = None  # type: ignore[misc,assignment]
     BackwardPrefetch = None  # type: ignore[misc,assignment]
     CPUOffload = None  # type: ignore[misc,assignment]
 
diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
index 39bba092e9c60..2617e5fe27b10 100644
--- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py
+++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py
@@ -144,7 +144,7 @@ def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", train
         # load last weights
         if worker_output.weights_path is not None:
             ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path)
-            trainer.lightning_module.load_state_dict(ckpt)  # type: ignore[arg-type]
+            trainer.lightning_module.load_state_dict(ckpt)
             self._strategy.checkpoint_io.remove_checkpoint(worker_output.weights_path)
 
         trainer.state = worker_output.trainer_state
diff --git a/src/pytorch_lightning/utilities/cloud_io.py b/src/pytorch_lightning/utilities/cloud_io.py
index 81482a8ab24f9..ee3358be59541 100644
--- a/src/pytorch_lightning/utilities/cloud_io.py
+++ b/src/pytorch_lightning/utilities/cloud_io.py
@@ -22,14 +22,12 @@
 from fsspec.core import url_to_fs
 from fsspec.implementations.local import AbstractFileSystem
 
-from pytorch_lightning.utilities.types import _PATH
+from pytorch_lightning.utilities.types import _DEVICE, _PATH
 
 
 def load(
     path_or_url: Union[IO, _PATH],
-    map_location: Optional[
-        Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]]
-    ] = None,
+    map_location: Optional[Union[_DEVICE, Callable[[_DEVICE], _DEVICE], Dict[_DEVICE, _DEVICE]]] = None,
 ) -> Any:
     """Loads a checkpoint.
 
@@ -41,7 +39,10 @@ def load(
         # any sort of BytesIO or similar
         return torch.load(path_or_url, map_location=map_location)
     if str(path_or_url).startswith("http"):
-        return torch.hub.load_state_dict_from_url(str(path_or_url), map_location=map_location)
+        return torch.hub.load_state_dict_from_url(
+            str(path_or_url),
+            map_location=map_location,  # type: ignore[arg-type]  # upstream annotation is not correct
+        )
     fs = get_filesystem(path_or_url)
     with fs.open(path_or_url, "rb") as f:
         return torch.load(f, map_location=map_location)
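
Usage note: a minimal sketch, not part of the patch, of the map_location forms
that the retyped load() in src/pytorch_lightning/utilities/cloud_io.py now
advertises; the checkpoint path below is hypothetical.

    import torch
    from pytorch_lightning.utilities.cloud_io import load

    # A plain device string and a torch.device both satisfy _DEVICE.
    ckpt = load("checkpoints/model.ckpt", map_location="cpu")
    ckpt = load("checkpoints/model.ckpt", map_location=torch.device("cpu"))

    # Dict[_DEVICE, _DEVICE] remaps storages, e.g. a GPU checkpoint onto CPU.
    ckpt = load("checkpoints/model.ckpt", map_location={"cuda:0": "cpu"})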