From 262d68fa088ada32a5ddce96b8aab13af76de652 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 25 Nov 2021 00:42:52 +0100 Subject: [PATCH 1/8] improve typing in pytorch_lightning/lite --- pytorch_lightning/lite/lite.py | 16 ++++++++++++++-- pytorch_lightning/lite/wrappers.py | 1 + .../training_type/training_type_plugin.py | 2 +- 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index b2adeeac4bd5b..45592c3056a8c 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -16,7 +16,7 @@ from contextlib import contextmanager from functools import partial from pathlib import Path -from typing import Any, Callable, cast, Dict, Generator, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, cast, Dict, Generator, List, Optional, Sequence, Tuple, Union, overload import torch import torch.nn as nn @@ -201,7 +201,7 @@ def setup_dataloaders( for dataloader in dataloaders ] dataloaders = dataloaders[0] if len(dataloaders) == 1 else dataloaders - return dataloaders + return dataloaders # type: ignore[return-value] def _setup_dataloader( self, dataloader: DataLoader, replace_sampler: bool = True, move_to_device: bool = True @@ -284,6 +284,18 @@ def autocast(self) -> Generator[None, None, None]: with self._precision_plugin.forward_context(): yield + @overload + def to_device(self, obj: nn.Module) -> nn.Module: + ... + + @overload + def to_device(self, obj: Tensor) -> Tensor: + ... + + @overload + def to_device(self, obj: Any) -> Any: + ... + def to_device(self, obj: Union[nn.Module, Tensor, Any]) -> Union[nn.Module, Tensor, Any]: """Move a :class:`torch.nn.Module` or a collection of tensors to the current device, if it is not already on that device. diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 908ba06bdb84d..cfc224347e13c 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -131,6 +131,7 @@ def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: iterator = iter(self._dataloader) if self._device is None: yield from iterator + return None for item in iterator: yield move_data_to_device(item, self._device) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 7010c0e878dc9..786c692a38924 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -241,7 +241,7 @@ def validation_step_end(self, output): def test_step_end(self, output): return output - def process_dataloader(self, dataloader: Union[Iterable, DataLoader]) -> Union[Iterable, DataLoader]: + def process_dataloader(self, dataloader: DataLoader) -> DataLoader: """Wraps the dataloader if necessary. Args: From e8e82fc3630e28e94e23c77ac4d09fafc27c6147 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 24 Nov 2021 23:46:49 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/lite/lite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/lite.py b/pytorch_lightning/lite/lite.py index 45592c3056a8c..9efa01267ef11 100644 --- a/pytorch_lightning/lite/lite.py +++ b/pytorch_lightning/lite/lite.py @@ -16,7 +16,7 @@ from contextlib import contextmanager from functools import partial from pathlib import Path -from typing import Any, Callable, cast, Dict, Generator, List, Optional, Sequence, Tuple, Union, overload +from typing import Any, Callable, cast, Dict, Generator, List, Optional, overload, Sequence, Tuple, Union import torch import torch.nn as nn From 6aaba91c9ec58875fe6920ff2ef4801acae76835 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 25 Nov 2021 01:30:22 +0100 Subject: [PATCH 3/8] include lite again --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index e3c373aee5aeb..60d92d683e7b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ module = [ "pytorch_lightning.callbacks.pruning", "pytorch_lightning.callbacks.rich_model_summary", "pytorch_lightning.core.optimizer", + "pytorch_lightning.lite.*", "pytorch_lightning.loops.optimization.closure.py", "pytorch_lightning.loops.optimization.manual_loop.py", "pytorch_lightning.loops.evaluation_loop", From 84aad52b35a996277deb21b915a8ab68e4b77736 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 25 Nov 2021 01:36:15 +0100 Subject: [PATCH 4/8] unused import --- pytorch_lightning/plugins/training_type/training_type_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 786c692a38924..be51cc9f929a4 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union +from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union import torch from torch import Tensor From b897cc134e82d547ec2c3ba52696e5a1796598c9 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Thu, 25 Nov 2021 11:06:43 +0100 Subject: [PATCH 5/8] Update training_type_plugin.py --- pytorch_lightning/plugins/training_type/training_type_plugin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py index 786c692a38924..be51cc9f929a4 100644 --- a/pytorch_lightning/plugins/training_type/training_type_plugin.py +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -13,7 +13,7 @@ # limitations under the License. import contextlib from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple, Union +from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple, Union import torch from torch import Tensor From 3c937f2a62762669878bc8afbdb136248e59ece9 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Thu, 25 Nov 2021 11:06:54 +0100 Subject: [PATCH 6/8] Update pytorch_lightning/lite/wrappers.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/lite/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index cfc224347e13c..202404ef7162a 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -131,7 +131,7 @@ def __iter__(self) -> Union[Iterator[Any], Generator[Any, None, None]]: iterator = iter(self._dataloader) if self._device is None: yield from iterator - return None + return for item in iterator: yield move_data_to_device(item, self._device) From 33c96c9bc9ff650db7fb535a2e50b72d77ddb77d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 25 Nov 2021 14:49:59 +0100 Subject: [PATCH 7/8] lite toml update --- pyproject.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c266e0684e974..af5e4d88b6b0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,8 +63,6 @@ module = [ "pytorch_lightning.core.mixins.hparams_mixin", "pytorch_lightning.core.saving", "pytorch_lightning.distributed.dist", - "pytorch_lightning.lite.lite", - "pytorch_lightning.lite.wrappers", "pytorch_lightning.loggers.base", "pytorch_lightning.loggers.comet", "pytorch_lightning.loggers.csv_logs", From 20ef548cd17f55ae40d1f89e7985c929ad931436 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 26 Nov 2021 20:53:28 +0100 Subject: [PATCH 8/8] Fix comment --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index aa44d94bb3868..168e60e1e2e81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ disable_error_code = "attr-defined" # style choices warn_no_return = "False" -# Changes mypy default to ignore all errors +# Ignore mypy errors for these files # TODO: the goal is for this to be empty [[tool.mypy.overrides]] # the list can be generated with: