From 404e244a5057b15d56943e0a282a69b1898f53b7 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 7 Aug 2021 10:54:17 +0200 Subject: [PATCH 1/2] Fix mypy typing for . --- pyproject.toml | 1 + pytorch_lightning/utilities/warnings.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d07f19ef10986..65cdf5be8a20c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ module = [ "pytorch_lightning.utilities.distributed", "pytorch_lightning.utilities.memory", "pytorch_lightning.utilities.parsing", + "pytorch_lightning.utilities.warnings", "pytorch_lightning.utilities.xla_device", ] ignore_errors = "False" diff --git a/pytorch_lightning/utilities/warnings.py b/pytorch_lightning/utilities/warnings.py index 1949f7ec3e378..700b8c05e6219 100644 --- a/pytorch_lightning/utilities/warnings.py +++ b/pytorch_lightning/utilities/warnings.py @@ -14,17 +14,20 @@ """Warning-related utilities""" import warnings from functools import partial +from typing import Any, Sequence, Tuple, Type, Union from pytorch_lightning.utilities.distributed import rank_zero_only -def _warn(*args, stacklevel: int = 2, **kwargs): - warnings.warn(*args, stacklevel=stacklevel, **kwargs) +def _warn(m: Union[str, Warning], category: Union[Type[Warning], Any], stacklevel: int = 2, **kwargs: Any) -> None: + warnings.warn(m, category, stacklevel=stacklevel, **kwargs) @rank_zero_only -def rank_zero_warn(*args, stacklevel: int = 4, **kwargs): - _warn(*args, stacklevel=stacklevel, **kwargs) +def rank_zero_warn( + m: Union[str, Warning], category: Union[Type[Warning], Any], stacklevel: int = 4, **kwargs: Any +) -> None: + _warn(m, category, stacklevel=stacklevel, **kwargs) class LightningDeprecationWarning(DeprecationWarning): @@ -38,12 +41,12 @@ class LightningDeprecationWarning(DeprecationWarning): class WarningCache(set): - def warn(self, m, *args, stacklevel: int = 5, **kwargs): + def warn(self, m: str, *args: Any, stacklevel: int = 5, **kwargs: Any) -> None: if m not in self: self.add(m) rank_zero_warn(m, *args, stacklevel=stacklevel, **kwargs) - def deprecation(self, m, *args, stacklevel: int = 5, **kwargs): + def deprecation(self, m: str, *args: Any, stacklevel: int = 5, **kwargs: Any) -> None: if m not in self: self.add(m) rank_zero_deprecation(m, *args, stacklevel=stacklevel, **kwargs) From b064501ec925e15667bc772a6d95945bdbc5ae75 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 7 Aug 2021 10:57:31 +0200 Subject: [PATCH 2/2] Clean import --- pytorch_lightning/utilities/warnings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/warnings.py b/pytorch_lightning/utilities/warnings.py index 700b8c05e6219..4c3b50142eaae 100644 --- a/pytorch_lightning/utilities/warnings.py +++ b/pytorch_lightning/utilities/warnings.py @@ -14,7 +14,7 @@ """Warning-related utilities""" import warnings from functools import partial -from typing import Any, Sequence, Tuple, Type, Union +from typing import Any, Type, Union from pytorch_lightning.utilities.distributed import rank_zero_only