Skip to content

Commit 85c34fd

Browse files
[7.2.x] Fix different behavior with unittest when warlus operator (#10803)
Co-authored-by: Alessio Izzo <[email protected]>
1 parent ec25744 commit 85c34fd

File tree

4 files changed

+223
-4
lines changed

4 files changed

+223
-4
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Adam Uhlir
1212
Ahn Ki-Wook
1313
Akiomi Kamakura
1414
Alan Velasco
15+
Alessio Izzo
1516
Alexander Johnson
1617
Alexander King
1718
Alexei Kozlenok

changelog/10743.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
The assertion rewriting mechanism now works correctly when assertion expressions contain the walrus operator.

src/_pytest/assertion/rewrite.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,13 @@
4444
if TYPE_CHECKING:
4545
from _pytest.assertion import AssertionState
4646

47+
if sys.version_info >= (3, 8):
48+
namedExpr = ast.NamedExpr
49+
else:
50+
namedExpr = ast.Expr
4751

48-
assertstate_key = StashKey["AssertionState"]()
4952

53+
assertstate_key = StashKey["AssertionState"]()
5054

5155
# pytest caches rewritten pycs in pycache dirs
5256
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
@@ -636,8 +640,12 @@ class AssertionRewriter(ast.NodeVisitor):
636640
.push_format_context() and .pop_format_context() which allows
637641
to build another %-formatted string while already building one.
638642
639-
This state is reset on every new assert statement visited and used
640-
by the other visitors.
643+
:variables_overwrite: A dict filled with references to variables
644+
that change value within an assert. This happens when a variable is
645+
reassigned with the walrus operator
646+
647+
This state, except the variables_overwrite, is reset on every new assert
648+
statement visited and used by the other visitors.
641649
"""
642650

643651
def __init__(
@@ -653,6 +661,7 @@ def __init__(
653661
else:
654662
self.enable_assertion_pass_hook = False
655663
self.source = source
664+
self.variables_overwrite: Dict[str, str] = {}
656665

657666
def run(self, mod: ast.Module) -> None:
658667
"""Find all assert statements in *mod* and rewrite them."""
@@ -667,7 +676,7 @@ def run(self, mod: ast.Module) -> None:
667676
if doc is not None and self.is_rewrite_disabled(doc):
668677
return
669678
pos = 0
670-
lineno = 1
679+
item = None
671680
for item in mod.body:
672681
if (
673682
expect_docstring
@@ -938,6 +947,18 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
938947
ast.copy_location(node, assert_)
939948
return self.statements
940949

950+
def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]:
951+
# This method handles the 'walrus operator' repr of the target
952+
# name if it's a local variable or _should_repr_global_name()
953+
# thinks it's acceptable.
954+
locs = ast.Call(self.builtin("locals"), [], [])
955+
target_id = name.target.id # type: ignore[attr-defined]
956+
inlocs = ast.Compare(ast.Str(target_id), [ast.In()], [locs])
957+
dorepr = self.helper("_should_repr_global_name", name)
958+
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
959+
expr = ast.IfExp(test, self.display(name), ast.Str(target_id))
960+
return name, self.explanation_param(expr)
961+
941962
def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]:
942963
# Display the repr of the name if it's a local variable or
943964
# _should_repr_global_name() thinks it's acceptable.
@@ -964,6 +985,20 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
964985
# cond is set in a prior loop iteration below
965986
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
966987
self.expl_stmts = fail_inner
988+
# Check if the left operand is a namedExpr and the value has already been visited
989+
if (
990+
isinstance(v, ast.Compare)
991+
and isinstance(v.left, namedExpr)
992+
and v.left.target.id
993+
in [
994+
ast_expr.id
995+
for ast_expr in boolop.values[:i]
996+
if hasattr(ast_expr, "id")
997+
]
998+
):
999+
pytest_temp = self.variable()
1000+
self.variables_overwrite[v.left.target.id] = pytest_temp
1001+
v.left.target.id = pytest_temp
9671002
self.push_format_context()
9681003
res, expl = self.visit(v)
9691004
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
@@ -1039,6 +1074,9 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
10391074

10401075
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
10411076
self.push_format_context()
1077+
# We first check if we have overwritten a variable in the previous assert
1078+
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
1079+
comp.left.id = self.variables_overwrite[comp.left.id]
10421080
left_res, left_expl = self.visit(comp.left)
10431081
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
10441082
left_expl = f"({left_expl})"
@@ -1050,6 +1088,13 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
10501088
syms = []
10511089
results = [left_res]
10521090
for i, op, next_operand in it:
1091+
if (
1092+
isinstance(next_operand, namedExpr)
1093+
and isinstance(left_res, ast.Name)
1094+
and next_operand.target.id == left_res.id
1095+
):
1096+
next_operand.target.id = self.variable()
1097+
self.variables_overwrite[left_res.id] = next_operand.target.id
10531098
next_res, next_expl = self.visit(next_operand)
10541099
if isinstance(next_operand, (ast.Compare, ast.BoolOp)):
10551100
next_expl = f"({next_expl})"
@@ -1073,6 +1118,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
10731118
res: ast.expr = ast.BoolOp(ast.And(), load_names)
10741119
else:
10751120
res = load_names[0]
1121+
10761122
return res, self.explanation_param(self.pop_format_context(expl_call))
10771123

10781124

testing/test_assertrewrite.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1265,6 +1265,177 @@ def test_simple_failure():
12651265
result.stdout.fnmatch_lines(["*E*assert (1 + 1) == 3"])
12661266

12671267

1268+
@pytest.mark.skipif(
1269+
sys.version_info < (3, 8), reason="walrus operator not available in py<38"
1270+
)
1271+
class TestIssue10743:
1272+
def test_assertion_walrus_operator(self, pytester: Pytester) -> None:
1273+
pytester.makepyfile(
1274+
"""
1275+
def my_func(before, after):
1276+
return before == after
1277+
1278+
def change_value(value):
1279+
return value.lower()
1280+
1281+
def test_walrus_conversion():
1282+
a = "Hello"
1283+
assert not my_func(a, a := change_value(a))
1284+
assert a == "hello"
1285+
"""
1286+
)
1287+
result = pytester.runpytest()
1288+
assert result.ret == 0
1289+
1290+
def test_assertion_walrus_operator_dont_rewrite(self, pytester: Pytester) -> None:
1291+
pytester.makepyfile(
1292+
"""
1293+
'PYTEST_DONT_REWRITE'
1294+
def my_func(before, after):
1295+
return before == after
1296+
1297+
def change_value(value):
1298+
return value.lower()
1299+
1300+
def test_walrus_conversion_dont_rewrite():
1301+
a = "Hello"
1302+
assert not my_func(a, a := change_value(a))
1303+
assert a == "hello"
1304+
"""
1305+
)
1306+
result = pytester.runpytest()
1307+
assert result.ret == 0
1308+
1309+
def test_assertion_inline_walrus_operator(self, pytester: Pytester) -> None:
1310+
pytester.makepyfile(
1311+
"""
1312+
def my_func(before, after):
1313+
return before == after
1314+
1315+
def test_walrus_conversion_inline():
1316+
a = "Hello"
1317+
assert not my_func(a, a := a.lower())
1318+
assert a == "hello"
1319+
"""
1320+
)
1321+
result = pytester.runpytest()
1322+
assert result.ret == 0
1323+
1324+
def test_assertion_inline_walrus_operator_reverse(self, pytester: Pytester) -> None:
1325+
pytester.makepyfile(
1326+
"""
1327+
def my_func(before, after):
1328+
return before == after
1329+
1330+
def test_walrus_conversion_reverse():
1331+
a = "Hello"
1332+
assert my_func(a := a.lower(), a)
1333+
assert a == 'hello'
1334+
"""
1335+
)
1336+
result = pytester.runpytest()
1337+
assert result.ret == 0
1338+
1339+
def test_assertion_walrus_no_variable_name_conflict(
1340+
self, pytester: Pytester
1341+
) -> None:
1342+
pytester.makepyfile(
1343+
"""
1344+
def test_walrus_conversion_no_conflict():
1345+
a = "Hello"
1346+
assert a == (b := a.lower())
1347+
"""
1348+
)
1349+
result = pytester.runpytest()
1350+
assert result.ret == 1
1351+
result.stdout.fnmatch_lines(["*AssertionError: assert 'Hello' == 'hello'"])
1352+
1353+
def test_assertion_walrus_operator_true_assertion_and_changes_variable_value(
1354+
self, pytester: Pytester
1355+
) -> None:
1356+
pytester.makepyfile(
1357+
"""
1358+
def test_walrus_conversion_succeed():
1359+
a = "Hello"
1360+
assert a != (a := a.lower())
1361+
assert a == 'hello'
1362+
"""
1363+
)
1364+
result = pytester.runpytest()
1365+
assert result.ret == 0
1366+
1367+
def test_assertion_walrus_operator_fail_assertion(self, pytester: Pytester) -> None:
1368+
pytester.makepyfile(
1369+
"""
1370+
def test_walrus_conversion_fails():
1371+
a = "Hello"
1372+
assert a == (a := a.lower())
1373+
"""
1374+
)
1375+
result = pytester.runpytest()
1376+
assert result.ret == 1
1377+
result.stdout.fnmatch_lines(["*AssertionError: assert 'Hello' == 'hello'"])
1378+
1379+
def test_assertion_walrus_operator_boolean_composite(
1380+
self, pytester: Pytester
1381+
) -> None:
1382+
pytester.makepyfile(
1383+
"""
1384+
def test_walrus_operator_change_boolean_value():
1385+
a = True
1386+
assert a and True and ((a := False) is False) and (a is False) and ((a := None) is None)
1387+
assert a is None
1388+
"""
1389+
)
1390+
result = pytester.runpytest()
1391+
assert result.ret == 0
1392+
1393+
def test_assertion_walrus_operator_compare_boolean_fails(
1394+
self, pytester: Pytester
1395+
) -> None:
1396+
pytester.makepyfile(
1397+
"""
1398+
def test_walrus_operator_change_boolean_value():
1399+
a = True
1400+
assert not (a and ((a := False) is False))
1401+
"""
1402+
)
1403+
result = pytester.runpytest()
1404+
assert result.ret == 1
1405+
result.stdout.fnmatch_lines(["*assert not (True and False is False)"])
1406+
1407+
def test_assertion_walrus_operator_boolean_none_fails(
1408+
self, pytester: Pytester
1409+
) -> None:
1410+
pytester.makepyfile(
1411+
"""
1412+
def test_walrus_operator_change_boolean_value():
1413+
a = True
1414+
assert not (a and ((a := None) is None))
1415+
"""
1416+
)
1417+
result = pytester.runpytest()
1418+
assert result.ret == 1
1419+
result.stdout.fnmatch_lines(["*assert not (True and None is None)"])
1420+
1421+
def test_assertion_walrus_operator_value_changes_cleared_after_each_test(
1422+
self, pytester: Pytester
1423+
) -> None:
1424+
pytester.makepyfile(
1425+
"""
1426+
def test_walrus_operator_change_value():
1427+
a = True
1428+
assert (a := None) is None
1429+
1430+
def test_walrus_operator_not_override_value():
1431+
a = True
1432+
assert a is True
1433+
"""
1434+
)
1435+
result = pytester.runpytest()
1436+
assert result.ret == 0
1437+
1438+
12681439
@pytest.mark.skipif(
12691440
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
12701441
)

0 commit comments

Comments
 (0)