From deb65737a60360ae3bc82f3b51ce81f8e9a275ea Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Sat, 28 Jan 2023 19:52:59 +0100 Subject: [PATCH] Preserve parent CallContext when inferring nested functions --- ChangeLog | 5 ++++- astroid/brain/brain_typing.py | 34 ------------------------------- astroid/context.py | 6 +++++- astroid/inference.py | 5 ++++- astroid/protocols.py | 2 +- tests/unittest_brain.py | 26 +++++++++++++++++++---- tests/unittest_inference_calls.py | 5 ++--- 7 files changed, 38 insertions(+), 45 deletions(-) diff --git a/ChangeLog b/ChangeLog index d7c430e130..5b1fc4dfc0 100644 --- a/ChangeLog +++ b/ChangeLog @@ -14,9 +14,12 @@ Release date: TBA * Fix issues with ``typing_extensions.TypeVar``. - * Fix ``ClassDef.fromlino`` for PyPy 3.8 (v7.3.11) if class is wrapped by a decorator. +* Preserve parent CallContext when inferring nested functions. + + Closes PyCQA/pylint#8074 + What's New in astroid 2.13.3? ============================= diff --git a/astroid/brain/brain_typing.py b/astroid/brain/brain_typing.py index b11bfa1965..6a13407222 100644 --- a/astroid/brain/brain_typing.py +++ b/astroid/brain/brain_typing.py @@ -28,7 +28,6 @@ Const, JoinedStr, Name, - NodeNG, Subscript, Tuple, ) @@ -380,36 +379,6 @@ def infer_special_alias( return iter([class_def]) -def _looks_like_typing_cast(node: Call) -> bool: - return isinstance(node, Call) and ( - isinstance(node.func, Name) - and node.func.name == "cast" - or isinstance(node.func, Attribute) - and node.func.attrname == "cast" - ) - - -def infer_typing_cast( - node: Call, ctx: context.InferenceContext | None = None -) -> Iterator[NodeNG]: - """Infer call to cast() returning same type as casted-from var.""" - if not isinstance(node.func, (Name, Attribute)): - raise UseInferenceDefault - - try: - func = next(node.func.infer(context=ctx)) - except (InferenceError, StopIteration) as exc: - raise UseInferenceDefault from exc - if ( - not isinstance(func, FunctionDef) - or func.qname() != "typing.cast" - or len(node.args) != 2 - ): - raise UseInferenceDefault - - return node.args[1].infer(context=ctx) - - AstroidManager().register_transform( Call, inference_tip(infer_typing_typevar_or_newtype), @@ -418,9 +387,6 @@ def infer_typing_cast( AstroidManager().register_transform( Subscript, inference_tip(infer_typing_attr), _looks_like_typing_subscript ) -AstroidManager().register_transform( - Call, inference_tip(infer_typing_cast), _looks_like_typing_cast -) if PY39_PLUS: AstroidManager().register_transform( diff --git a/astroid/context.py b/astroid/context.py index b469964805..81b02f11c4 100644 --- a/astroid/context.py +++ b/astroid/context.py @@ -161,13 +161,14 @@ def __str__(self) -> str: class CallContext: """Holds information for a call site.""" - __slots__ = ("args", "keywords", "callee") + __slots__ = ("args", "keywords", "callee", "parent_call_context") def __init__( self, args: list[NodeNG], keywords: list[Keyword] | None = None, callee: NodeNG | None = None, + parent_call_context: CallContext | None = None, ): self.args = args # Call positional arguments if keywords: @@ -176,6 +177,9 @@ def __init__( arg_value_pairs = [] self.keywords = arg_value_pairs # Call keyword arguments self.callee = callee # Function being called + self.parent_call_context = ( + parent_call_context # Parent CallContext for nested calls + ) def copy_context(context: InferenceContext | None) -> InferenceContext: diff --git a/astroid/inference.py b/astroid/inference.py index e8fec289fa..59bc4eca56 100644 --- a/astroid/inference.py +++ b/astroid/inference.py @@ -273,7 +273,10 @@ def infer_call( try: if hasattr(callee, "infer_call_result"): callcontext.callcontext = CallContext( - args=self.args, keywords=self.keywords, callee=callee + args=self.args, + keywords=self.keywords, + callee=callee, + parent_call_context=callcontext.callcontext, ) yield from callee.infer_call_result(caller=self, context=callcontext) except InferenceError: diff --git a/astroid/protocols.py b/astroid/protocols.py index 72549b7952..48f0cd0f09 100644 --- a/astroid/protocols.py +++ b/astroid/protocols.py @@ -470,7 +470,7 @@ def arguments_assigned_stmts( # reset call context/name callcontext = context.callcontext context = copy_context(context) - context.callcontext = None + context.callcontext = callcontext.parent_call_context args = arguments.CallSite(callcontext, context=context) return args.infer_argument(self.parent, node_name, context) return _arguments_infer_argname(self, node_name, context) diff --git a/tests/unittest_brain.py b/tests/unittest_brain.py index 3374556bcf..0ffccb3542 100644 --- a/tests/unittest_brain.py +++ b/tests/unittest_brain.py @@ -2132,8 +2132,7 @@ class A: pass b = 42 - a = cast(A, b) - a + cast(A, b) """ ) inferred = next(node.infer()) @@ -2148,14 +2147,33 @@ class A: pass b = 42 - a = typing.cast(A, b) - a + typing.cast(A, b) """ ) inferred = next(node.infer()) assert isinstance(inferred, nodes.Const) assert inferred.value == 42 + def test_typing_cast_multiple_inference_calls(self) -> None: + ast_nodes = builder.extract_node( + """ + from typing import TypeVar, cast + T = TypeVar("T") + def ident(var: T) -> T: + return cast(T, var) + + ident(2) #@ + ident("Hello") #@ + """ + ) + i0 = next(ast_nodes[0].infer()) + assert isinstance(i0, nodes.Const) + assert i0.value == 2 + + i1 = next(ast_nodes[1].infer()) + assert isinstance(i1, nodes.Const) + assert i1.value == "Hello" + @pytest.mark.skipif( not HAS_TYPING_EXTENSIONS, diff --git a/tests/unittest_inference_calls.py b/tests/unittest_inference_calls.py index 72afb9898c..84a611d3a4 100644 --- a/tests/unittest_inference_calls.py +++ b/tests/unittest_inference_calls.py @@ -146,8 +146,6 @@ def g(y): def test_inner_call_with_dynamic_argument() -> None: """Test function where return value is the result of a separate function call, with a dynamic value passed to the inner function. - - Currently, this is Uninferable. """ node = builder.extract_node( """ @@ -163,7 +161,8 @@ def g(y): assert isinstance(node, nodes.NodeNG) inferred = node.inferred() assert len(inferred) == 1 - assert inferred[0] is Uninferable + assert isinstance(inferred[0], nodes.Const) + assert inferred[0].value == 3 def test_method_const_instance_attr() -> None: