From a9c6ac104d69fe388094dffbb76865915f2f4750 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Fri, 8 Sep 2023 12:35:40 +0200 Subject: [PATCH 1/5] Fix assert rewriting with assignment expressions --- AUTHORS | 1 + changelog/11239.bugfix.rst | 1 + src/_pytest/assertion/rewrite.py | 49 +++++++++++++++++++++++--------- testing/test_assertrewrite.py | 17 +++++++++++ 4 files changed, 54 insertions(+), 14 deletions(-) create mode 100644 changelog/11239.bugfix.rst diff --git a/AUTHORS b/AUTHORS index 466779f6d11..e9e033c73f0 100644 --- a/AUTHORS +++ b/AUTHORS @@ -235,6 +235,7 @@ Maho Maik Figura Mandeep Bhutani Manuel Krebber +Marc Mueller Marc Schlaich Marcelo Duarte Trevisani Marcin Bachry diff --git a/changelog/11239.bugfix.rst b/changelog/11239.bugfix.rst new file mode 100644 index 00000000000..a486224cdda --- /dev/null +++ b/changelog/11239.bugfix.rst @@ -0,0 +1 @@ +Fixed ``:=`` in asserts impacting unrelated test cases. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 9bf79f1e107..32e96253608 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -13,6 +13,7 @@ import sys import tokenize import types +from collections import defaultdict from pathlib import Path from pathlib import PurePath from typing import Callable @@ -52,6 +53,8 @@ PYC_EXT = ".py" + (__debug__ and "c" or "o") PYC_TAIL = "." + PYTEST_TAG + PYC_EXT +_SCOPE_END_MARKER = object() + class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): """PEP302/PEP451 import hook which rewrites asserts.""" @@ -634,6 +637,8 @@ class AssertionRewriter(ast.NodeVisitor): .push_format_context() and .pop_format_context() which allows to build another %-formatted string while already building one. + :scope: A tuple containing the current scope used for variables_overwrite. + :variables_overwrite: A dict filled with references to variables that change value within an assert. This happens when a variable is reassigned with the walrus operator @@ -655,7 +660,10 @@ def __init__( else: self.enable_assertion_pass_hook = False self.source = source - self.variables_overwrite: Dict[str, str] = {} + self.scope: tuple[ast.AST, ...] = () + self.variables_overwrite: defaultdict[ + tuple[ast.AST, ...], Dict[str, str] + ] = defaultdict(dict) def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" @@ -719,9 +727,17 @@ def run(self, mod: ast.Module) -> None: mod.body[pos:pos] = imports # Collect asserts. - nodes: List[ast.AST] = [mod] + self.scope = (mod,) + nodes: List[Union[ast.AST, object]] = [mod] while nodes: node = nodes.pop() + if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): + self.scope = tuple((*self.scope, node)) + nodes.append(_SCOPE_END_MARKER) + if node == _SCOPE_END_MARKER: + self.scope = self.scope[:-1] + continue + assert isinstance(node, ast.AST) for name, field in ast.iter_fields(node): if isinstance(field, list): new: List[ast.AST] = [] @@ -992,7 +1008,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: ] ): pytest_temp = self.variable() - self.variables_overwrite[ + self.variables_overwrite[self.scope][ v.left.target.id ] = v.left # type:ignore[assignment] v.left.target.id = pytest_temp @@ -1035,17 +1051,20 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]: new_args = [] new_kwargs = [] for arg in call.args: - if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite: - arg = self.variables_overwrite[arg.id] # type:ignore[assignment] + if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get( + self.scope, {} + ): + arg = self.variables_overwrite[self.scope][ + arg.id + ] # type:ignore[assignment] res, expl = self.visit(arg) arg_expls.append(expl) new_args.append(res) for keyword in call.keywords: - if ( - isinstance(keyword.value, ast.Name) - and keyword.value.id in self.variables_overwrite - ): - keyword.value = self.variables_overwrite[ + if isinstance( + keyword.value, ast.Name + ) and keyword.value.id in self.variables_overwrite.get(self.scope, {}): + keyword.value = self.variables_overwrite[self.scope][ keyword.value.id ] # type:ignore[assignment] res, expl = self.visit(keyword.value) @@ -1081,12 +1100,14 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]: def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: self.push_format_context() # We first check if we have overwritten a variable in the previous assert - if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite: - comp.left = self.variables_overwrite[ + if isinstance( + comp.left, ast.Name + ) and comp.left.id in self.variables_overwrite.get(self.scope, {}): + comp.left = self.variables_overwrite[self.scope][ comp.left.id ] # type:ignore[assignment] if isinstance(comp.left, ast.NamedExpr): - self.variables_overwrite[ + self.variables_overwrite[self.scope][ comp.left.target.id ] = comp.left # type:ignore[assignment] left_res, left_expl = self.visit(comp.left) @@ -1106,7 +1127,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: and next_operand.target.id == left_res.id ): next_operand.target.id = self.variable() - self.variables_overwrite[ + self.variables_overwrite[self.scope][ left_res.id ] = next_operand # type:ignore[assignment] next_res, next_expl = self.visit(next_operand) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 08813c4dcf0..d69ca538c21 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1543,6 +1543,23 @@ def test_gt(): result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"]) +class TestIssue11239: + def test_assertion_walrus_different_test_cases(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + def test_1(): + state = {"x": 2}.get("x") + assert state is not None + + def test_2(): + db = {"x": 2} + assert (state := db.get("x")) is not None + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" ) From 3e95a370d0b3b5df949c176aae991ec2e822b579 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Fri, 8 Sep 2023 14:49:22 +0200 Subject: [PATCH 2/5] Update test docstring Co-authored-by: Bruno Oliveira --- testing/test_assertrewrite.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index d69ca538c21..b3fd0c2f2e7 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1545,6 +1545,10 @@ def test_gt(): class TestIssue11239: def test_assertion_walrus_different_test_cases(self, pytester: Pytester) -> None: + """Regression for (#11239) + + Walrus operator rewriting would leak to separate test cases if they used the same variables. + """ pytester.makepyfile( """ def test_1(): From 35771f643d6e26d66202c3be3202a058af8125f9 Mon Sep 17 00:00:00 2001 From: Bruno Oliveira Date: Fri, 8 Sep 2023 18:44:12 -0300 Subject: [PATCH 3/5] Use an enum for sentinel --- src/_pytest/assertion/rewrite.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 32e96253608..f894f2be3a9 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -14,6 +14,7 @@ import tokenize import types from collections import defaultdict +from enum import Enum from pathlib import Path from pathlib import PurePath from typing import Callable @@ -53,7 +54,11 @@ PYC_EXT = ".py" + (__debug__ and "c" or "o") PYC_TAIL = "." + PYTEST_TAG + PYC_EXT -_SCOPE_END_MARKER = object() + +class ScopeEndMarkerType(Enum): + """Special marker that denotes we have just left a function or class definition.""" + + ScopeEndMarker = 1 class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): @@ -728,13 +733,13 @@ def run(self, mod: ast.Module) -> None: # Collect asserts. self.scope = (mod,) - nodes: List[Union[ast.AST, object]] = [mod] + nodes: List[Union[ast.AST, ScopeEndMarkerType]] = [mod] while nodes: node = nodes.pop() if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): self.scope = tuple((*self.scope, node)) - nodes.append(_SCOPE_END_MARKER) - if node == _SCOPE_END_MARKER: + nodes.append(ScopeEndMarkerType.ScopeEndMarker) + if node is ScopeEndMarkerType.ScopeEndMarker: self.scope = self.scope[:-1] continue assert isinstance(node, ast.AST) From 1da54b45469f17ae360a49d8ae3a504642a4b9d9 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Sat, 9 Sep 2023 00:32:07 +0200 Subject: [PATCH 4/5] Revert "Use an enum for sentinel" This reverts commit 35771f643d6e26d66202c3be3202a058af8125f9. --- src/_pytest/assertion/rewrite.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index f894f2be3a9..32e96253608 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -14,7 +14,6 @@ import tokenize import types from collections import defaultdict -from enum import Enum from pathlib import Path from pathlib import PurePath from typing import Callable @@ -54,11 +53,7 @@ PYC_EXT = ".py" + (__debug__ and "c" or "o") PYC_TAIL = "." + PYTEST_TAG + PYC_EXT - -class ScopeEndMarkerType(Enum): - """Special marker that denotes we have just left a function or class definition.""" - - ScopeEndMarker = 1 +_SCOPE_END_MARKER = object() class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): @@ -733,13 +728,13 @@ def run(self, mod: ast.Module) -> None: # Collect asserts. self.scope = (mod,) - nodes: List[Union[ast.AST, ScopeEndMarkerType]] = [mod] + nodes: List[Union[ast.AST, object]] = [mod] while nodes: node = nodes.pop() if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): self.scope = tuple((*self.scope, node)) - nodes.append(ScopeEndMarkerType.ScopeEndMarker) - if node is ScopeEndMarkerType.ScopeEndMarker: + nodes.append(_SCOPE_END_MARKER) + if node == _SCOPE_END_MARKER: self.scope = self.scope[:-1] continue assert isinstance(node, ast.AST) From d4f641a9a79927282ed7c546af6c8235a6809126 Mon Sep 17 00:00:00 2001 From: Marc Mueller <30130371+cdce8p@users.noreply.github.com> Date: Sat, 9 Sep 2023 00:39:53 +0200 Subject: [PATCH 5/5] Use sentinel value --- src/_pytest/assertion/rewrite.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 32e96253608..258ed9f9ab0 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -46,6 +46,10 @@ from _pytest.assertion import AssertionState +class Sentinel: + pass + + assertstate_key = StashKey["AssertionState"]() # pytest caches rewritten pycs in pycache dirs @@ -53,7 +57,8 @@ PYC_EXT = ".py" + (__debug__ and "c" or "o") PYC_TAIL = "." + PYTEST_TAG + PYC_EXT -_SCOPE_END_MARKER = object() +# Special marker that denotes we have just left a scope definition +_SCOPE_END_MARKER = Sentinel() class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): @@ -728,7 +733,7 @@ def run(self, mod: ast.Module) -> None: # Collect asserts. self.scope = (mod,) - nodes: List[Union[ast.AST, object]] = [mod] + nodes: List[Union[ast.AST, Sentinel]] = [mod] while nodes: node = nodes.pop() if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):