Skip to content

Commit fc4d1e5

Browse files
committed
Fix assert rewriting with assignment expressions
1 parent 0a06db0 commit fc4d1e5

File tree

4 files changed

+44
-11
lines changed

4 files changed

+44
-11
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ Maho
235235
Maik Figura
236236
Mandeep Bhutani
237237
Manuel Krebber
238+
Marc Mueller
238239
Marc Schlaich
239240
Marcelo Duarte Trevisani
240241
Marcin Bachry

changelog/11239.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed ``:=`` in asserts impacting unrelated test cases.

src/_pytest/assertion/rewrite.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import sys
1414
import tokenize
1515
import types
16+
from collections import defaultdict
1617
from pathlib import Path
1718
from pathlib import PurePath
1819
from typing import Callable
@@ -52,6 +53,8 @@
5253
PYC_EXT = ".py" + (__debug__ and "c" or "o")
5354
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
5455

56+
_SCOPE_END_MARKER = object()
57+
5558

5659
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
5760
"""PEP302/PEP451 import hook which rewrites asserts."""
@@ -634,6 +637,8 @@ class AssertionRewriter(ast.NodeVisitor):
634637
.push_format_context() and .pop_format_context() which allows
635638
to build another %-formatted string while already building one.
636639
640+
:scope: A tuple containing the current scope used for variables_overwrite.
641+
637642
:variables_overwrite: A dict filled with references to variables
638643
that change value within an assert. This happens when a variable is
639644
reassigned with the walrus operator
@@ -655,7 +660,8 @@ def __init__(
655660
else:
656661
self.enable_assertion_pass_hook = False
657662
self.source = source
658-
self.variables_overwrite: Dict[str, str] = {}
663+
self.scope: tuple[ast.AST, ...] = ()
664+
self.variables_overwrite: defaultdict[tuple[ast.AST, ...], Dict[str, str]] = defaultdict(dict)
659665

660666
def run(self, mod: ast.Module) -> None:
661667
"""Find all assert statements in *mod* and rewrite them."""
@@ -719,9 +725,17 @@ def run(self, mod: ast.Module) -> None:
719725
mod.body[pos:pos] = imports
720726

721727
# Collect asserts.
722-
nodes: List[ast.AST] = [mod]
728+
self.scope = (mod,)
729+
nodes: List[Union[ast.AST, object]] = [mod]
723730
while nodes:
724731
node = nodes.pop()
732+
if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
733+
self.scope = tuple((*self.scope, node))
734+
nodes.append(_SCOPE_END_MARKER)
735+
if node == _SCOPE_END_MARKER:
736+
self.scope = self.scope[:-1]
737+
continue
738+
assert isinstance(node, ast.AST)
725739
for name, field in ast.iter_fields(node):
726740
if isinstance(field, list):
727741
new: List[ast.AST] = []
@@ -992,7 +1006,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
9921006
]
9931007
):
9941008
pytest_temp = self.variable()
995-
self.variables_overwrite[
1009+
self.variables_overwrite[self.scope][
9961010
v.left.target.id
9971011
] = v.left # type:ignore[assignment]
9981012
v.left.target.id = pytest_temp
@@ -1035,17 +1049,17 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
10351049
new_args = []
10361050
new_kwargs = []
10371051
for arg in call.args:
1038-
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite:
1039-
arg = self.variables_overwrite[arg.id] # type:ignore[assignment]
1052+
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get(self.scope, {}):
1053+
arg = self.variables_overwrite[self.scope][arg.id] # type:ignore[assignment]
10401054
res, expl = self.visit(arg)
10411055
arg_expls.append(expl)
10421056
new_args.append(res)
10431057
for keyword in call.keywords:
10441058
if (
10451059
isinstance(keyword.value, ast.Name)
1046-
and keyword.value.id in self.variables_overwrite
1060+
and keyword.value.id in self.variables_overwrite.get(self.scope, {})
10471061
):
1048-
keyword.value = self.variables_overwrite[
1062+
keyword.value = self.variables_overwrite[self.scope][
10491063
keyword.value.id
10501064
] # type:ignore[assignment]
10511065
res, expl = self.visit(keyword.value)
@@ -1081,12 +1095,12 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
10811095
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
10821096
self.push_format_context()
10831097
# We first check if we have overwritten a variable in the previous assert
1084-
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
1085-
comp.left = self.variables_overwrite[
1098+
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
1099+
comp.left = self.variables_overwrite[self.scope][
10861100
comp.left.id
10871101
] # type:ignore[assignment]
10881102
if isinstance(comp.left, ast.NamedExpr):
1089-
self.variables_overwrite[
1103+
self.variables_overwrite[self.scope][
10901104
comp.left.target.id
10911105
] = comp.left # type:ignore[assignment]
10921106
left_res, left_expl = self.visit(comp.left)
@@ -1106,7 +1120,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
11061120
and next_operand.target.id == left_res.id
11071121
):
11081122
next_operand.target.id = self.variable()
1109-
self.variables_overwrite[
1123+
self.variables_overwrite[self.scope][
11101124
left_res.id
11111125
] = next_operand # type:ignore[assignment]
11121126
next_res, next_expl = self.visit(next_operand)

testing/test_assertrewrite.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,6 +1543,23 @@ def test_gt():
15431543
result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"])
15441544

15451545

1546+
class TestIssue11239:
1547+
def test_assertion_walrus_different_test_cases(self, pytester: Pytester) -> None:
1548+
pytester.makepyfile(
1549+
"""
1550+
def test_1():
1551+
state = {"x": 2}.get("x")
1552+
assert state is not None
1553+
1554+
def test_2():
1555+
db = {"x": 2}
1556+
assert (state := db.get("x")) is not None
1557+
"""
1558+
)
1559+
result = pytester.runpytest()
1560+
assert result.ret == 0
1561+
1562+
15461563
@pytest.mark.skipif(
15471564
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
15481565
)

0 commit comments

Comments
 (0)