From 6cb1200c7e243b31e626c56f9e2a6d96804d380d Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Fri, 25 Jun 2021 13:39:56 +0200 Subject: [PATCH 1/5] Fix mypy for utilities.parsing --- pytorch_lightning/utilities/parsing.py | 89 ++++++++++++++++---------- 1 file changed, 55 insertions(+), 34 deletions(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index d498849ac1b1c..014aca4ec7f49 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -17,8 +17,11 @@ import types from argparse import Namespace from dataclasses import fields, is_dataclass -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Type, Union +from typing_extensions import Literal + +import pytorch_lightning as pl from pytorch_lightning.utilities.warnings import rank_zero_warn @@ -51,10 +54,10 @@ def str_to_bool(val: str) -> bool: >>> str_to_bool('FALSE') False """ - val = str_to_bool_or_str(val) - if isinstance(val, bool): - return val - raise ValueError(f'invalid truth value {val}') + val_converted = str_to_bool_or_str(val) + if isinstance(val_converted, bool): + return val_converted + raise ValueError(f'invalid truth value {val_converted}') def str_to_bool_or_int(val: str) -> Union[bool, int, str]: @@ -69,13 +72,13 @@ def str_to_bool_or_int(val: str) -> Union[bool, int, str]: >>> str_to_bool_or_int("abc") 'abc' """ - val = str_to_bool_or_str(val) - if isinstance(val, bool): - return val + val_converted = str_to_bool_or_str(val) + if isinstance(val_converted, bool): + return val_converted try: - return int(val) + return int(val_converted) except ValueError: - return val + return val_converted def is_picklable(obj: object) -> bool: @@ -88,7 +91,7 @@ def is_picklable(obj: object) -> bool: return False -def clean_namespace(hparams): +def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None: """Removes all unpicklable entries from hparams""" hparams_dict = hparams @@ -102,7 +105,7 @@ def clean_namespace(hparams): del hparams_dict[k] -def parse_class_init_keys(cls) -> Tuple[str, str, str]: +def parse_class_init_keys(cls: Type['pl.LightningModule']) -> Tuple[str, Optional[str], Optional[str]]: """Parse key words for standard self, *args and **kwargs >>> class Model(): @@ -118,10 +121,14 @@ def parse_class_init_keys(cls) -> Tuple[str, str, str]: # self is always first n_self = init_params[0].name - def _get_first_if_any(params, param_type): + def _get_first_if_any( + params: List[inspect.Parameter], + param_type: Literal[inspect._ParameterKind.VAR_POSITIONAL, inspect._ParameterKind.VAR_KEYWORD], + ) -> Optional[str]: for p in params: if p.kind == param_type: return p.name + return None n_args = _get_first_if_any(init_params, inspect.Parameter.VAR_POSITIONAL) n_kwargs = _get_first_if_any(init_params, inspect.Parameter.VAR_KEYWORD) @@ -129,7 +136,7 @@ def _get_first_if_any(params, param_type): return n_self, n_args, n_kwargs -def get_init_args(frame) -> dict: +def get_init_args(frame: types.FrameType) -> Dict[str, Any]: _, _, _, local_vars = inspect.getargvalues(frame) if '__class__' not in local_vars: return {} @@ -140,12 +147,17 @@ def get_init_args(frame) -> dict: exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args') # only collect variables that appear in the signature local_args = {k: local_vars[k] for k in init_parameters.keys()} - local_args.update(local_args.get(kwargs_var, {})) + if kwargs_var: + local_args.update(local_args.get(kwargs_var, {})) local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames} return local_args -def collect_init_args(frame, path_args: list, inside: bool = False) -> list: +def collect_init_args( + frame: types.FrameType, + path_args: List[Dict[str, Any]], + inside: bool = False, +) -> List[Dict[str, Any]]: """ Recursively collects the arguments passed to the child constructors in the inheritance tree. @@ -160,18 +172,22 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list: most specific class in the hierarchy. """ _, _, _, local_vars = inspect.getargvalues(frame) - if '__class__' in local_vars: - local_args = get_init_args(frame) - # recursive update - path_args.append(local_args) - return collect_init_args(frame.f_back, path_args, inside=True) - elif not inside: - return collect_init_args(frame.f_back, path_args, inside) + # frame.f_back must be of a type types.FrameType for get_init_args/collect_init_args due to mypy + if isinstance(frame.f_back, types.FrameType): + if '__class__' in local_vars: + local_args = get_init_args(frame) + # recursive update + path_args.append(local_args) + return collect_init_args(frame.f_back, path_args, inside=True) + elif not inside: + return collect_init_args(frame.f_back, path_args, inside) + else: + return path_args else: return path_args -def flatten_dict(source, result=None): +def flatten_dict(source: Dict[str, Any], result: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: if result is None: result = {} @@ -186,7 +202,7 @@ def flatten_dict(source, result=None): def save_hyperparameters( obj: Any, - *args, + *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None ) -> None: @@ -197,7 +213,12 @@ def save_hyperparameters( return if not frame: - frame = inspect.currentframe().f_back + current_frame = inspect.currentframe() + # inspect.currentframe() return type is Optional[types.FrameType]: current_frame.f_back called only if available + if current_frame: + frame = current_frame.f_back + if not isinstance(frame, types.FrameType): + raise AttributeError("There is no `frame` available while being required.") if is_dataclass(obj): init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} @@ -250,13 +271,13 @@ class AttributeDict(Dict): "my-key": 3.14 """ - def __getattr__(self, key): + def __getattr__(self, key: str) -> Optional[Any]: try: return self[key] except KeyError as exp: raise AttributeError(f'Missing attribute "{key}"') from exp - def __setattr__(self, key, val): + def __setattr__(self, key: str, val: Any) -> None: self[key] = val def __repr__(self): @@ -269,14 +290,14 @@ def __repr__(self): return out -def _lightning_get_all_attr_holders(model, attribute): +def _lightning_get_all_attr_holders(model: 'pl.LightningModule', attribute: str) -> List[Any]: """ Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """ trainer = getattr(model, 'trainer', None) - holders = [] + holders: List[Any] = [] # Check if attribute in model if hasattr(model, attribute): @@ -294,7 +315,7 @@ def _lightning_get_all_attr_holders(model, attribute): return holders -def _lightning_get_first_attr_holder(model, attribute): +def _lightning_get_first_attr_holder(model: 'pl.LightningModule', attribute: str) -> Optional[Any]: """ Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule, @@ -307,7 +328,7 @@ def _lightning_get_first_attr_holder(model, attribute): return holders[-1] -def lightning_hasattr(model, attribute): +def lightning_hasattr(model: 'pl.LightningModule', attribute: str) -> bool: """ Special hasattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. @@ -315,7 +336,7 @@ def lightning_hasattr(model, attribute): return _lightning_get_first_attr_holder(model, attribute) is not None -def lightning_getattr(model, attribute): +def lightning_getattr(model: 'pl.LightningModule', attribute: str) -> Optional[Any]: """ Special getattr for Lightning. Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. @@ -337,7 +358,7 @@ def lightning_getattr(model, attribute): return getattr(holder, attribute) -def lightning_setattr(model, attribute, value): +def lightning_setattr(model: 'pl.LightningModule', attribute: str, value: Any) -> None: """ Special setattr for Lightning. Checks for attribute in model namespace and the old hparams namespace/dict. From 1e3e54c50b1b3f7a40e048cdf9002ee7a8b474d8 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Fri, 25 Jun 2021 17:01:33 +0200 Subject: [PATCH 2/5] Add an informative comment --- pytorch_lightning/utilities/parsing.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 014aca4ec7f49..03f3bb1fc45c5 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -147,6 +147,7 @@ def get_init_args(frame: types.FrameType) -> Dict[str, Any]: exclude_argnames = (*filtered_vars, '__class__', 'frame', 'frame_args') # only collect variables that appear in the signature local_args = {k: local_vars[k] for k in init_parameters.keys()} + # kwargs_var might be None => raised an error by mypy if kwargs_var: local_args.update(local_args.get(kwargs_var, {})) local_args = {k: v for k, v in local_args.items() if k not in exclude_argnames} From 90c6157c006c4a9a2779681b8dd2f791fdf7986f Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Fri, 25 Jun 2021 17:16:07 +0200 Subject: [PATCH 3/5] Disable ignoring mypy errors --- setup.cfg | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup.cfg b/setup.cfg index 74e02d932dc3c..78b68b74a90fd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -183,6 +183,8 @@ ignore_errors = True ignore_errors = True [mypy-pytorch_lightning.utilities.cli] ignore_errors = False +[mypy-pytorch_lightning.utilities.parsing] +ignore_errors = False # todo: add proper typing to this module... [mypy-pl_examples.*] From 1cb3287453e9cce3c4e40b7c296114ad88e8944f Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Fri, 25 Jun 2021 17:36:44 +0200 Subject: [PATCH 4/5] Add one missing annotation --- pytorch_lightning/utilities/parsing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 03f3bb1fc45c5..7a63aa8256379 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -281,7 +281,7 @@ def __getattr__(self, key: str) -> Optional[Any]: def __setattr__(self, key: str, val: Any) -> None: self[key] = val - def __repr__(self): + def __repr__(self) -> str: if not len(self): return "" max_key_length = max([len(str(k)) for k in self]) From 22892fb7b7e1b16431ae4aba9b45fd02cf413efb Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Mon, 28 Jun 2021 11:13:05 +0200 Subject: [PATCH 5/5] Incorporate awaelchli's suggestion --- pytorch_lightning/utilities/parsing.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 8d8c741f9ef2c..38f56078bfd02 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -177,16 +177,16 @@ def collect_init_args( """ _, _, _, local_vars = inspect.getargvalues(frame) # frame.f_back must be of a type types.FrameType for get_init_args/collect_init_args due to mypy - if isinstance(frame.f_back, types.FrameType): - if '__class__' in local_vars: - local_args = get_init_args(frame) - # recursive update - path_args.append(local_args) - return collect_init_args(frame.f_back, path_args, inside=True) - elif not inside: - return collect_init_args(frame.f_back, path_args, inside) - else: - return path_args + if not isinstance(frame.f_back, types.FrameType): + return path_args + + if '__class__' in local_vars: + local_args = get_init_args(frame) + # recursive update + path_args.append(local_args) + return collect_init_args(frame.f_back, path_args, inside=True) + elif not inside: + return collect_init_args(frame.f_back, path_args, inside) else: return path_args