From 948f5e83eaaf491211e18cf5dbcb8594d20ed1e2 Mon Sep 17 00:00:00 2001 From: Alessio Izzo Date: Tue, 21 Feb 2023 00:02:15 +0100 Subject: [PATCH 01/10] add visit_NamedExpr in assertrewrite --- src/_pytest/assertion/rewrite.py | 16 +++- testing/test_assertrewrite.py | 154 +++++++++++++++++++++++++++++++ 2 files changed, 168 insertions(+), 2 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 42664add432..548a3fce3a4 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -666,7 +666,7 @@ def run(self, mod: ast.Module) -> None: if doc is not None and self.is_rewrite_disabled(doc): return pos = 0 - lineno = 1 + item = None for item in mod.body: if ( expect_docstring @@ -937,6 +937,17 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]: ast.copy_location(node, assert_) return self.statements + def visit_NamedExpr(self, name: ast.NamedExpr) -> Tuple[ast.NamedExpr, str]: + # Display the repr of the target name if it's a local variable or + # _should_repr_global_name() thinks it's acceptable. + locs = ast.Call(self.builtin("locals"), [], []) + target_id = name.target.id # type: ignore[attr-defined] + inlocs = ast.Compare(ast.Str(target_id), [ast.In()], [locs]) + dorepr = self.helper("_should_repr_global_name", name) + test = ast.BoolOp(ast.Or(), [inlocs, dorepr]) + expr = ast.IfExp(test, self.display(name), ast.Str(target_id)) + return name, self.explanation_param(expr) + def visit_Name(self, name: ast.Name) -> Tuple[ast.Name, str]: # Display the repr of the name if it's a local variable or # _should_repr_global_name() thinks it's acceptable. @@ -1050,7 +1061,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: results = [left_res] for i, op, next_operand in it: next_res, next_expl = self.visit(next_operand) - if isinstance(next_operand, (ast.Compare, ast.BoolOp)): + if isinstance(next_operand, (ast.Compare, ast.BoolOp, ast.NamedExpr)): next_expl = f"({next_expl})" results.append(next_res) sym = BINOP_MAP[op.__class__] @@ -1072,6 +1083,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: res: ast.expr = ast.BoolOp(ast.And(), load_names) else: res = load_names[0] + return res, self.explanation_param(self.pop_format_context(expl_call)) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 3c98392ed98..89f2f8f6e0e 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1265,6 +1265,160 @@ def test_simple_failure(): result.stdout.fnmatch_lines(["*E*assert (1 + 1) == 3"]) +class TestIssue10743: + def test_assertion_walrus_operator(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + def my_func(before, after): + return before == after + + def change_value(value): + return value.lower() + + def test_walrus_conversion(): + a = "Hello" + assert not my_func(a, a := change_value(a)) + assert a == "hello" + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_operator_dont_rewrite(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + 'PYTEST_DONT_REWRITE' + def my_func(before, after): + return before == after + + def change_value(value): + return value.lower() + + def test_walrus_conversion_dont_rewrite(): + a = "Hello" + assert not my_func(a, a := change_value(a)) + assert a == "hello" + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_inline_walrus_operator(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + def my_func(before, after): + return before == after + + def test_walrus_conversion_inline(): + a = "Hello" + assert not my_func(a, a := a.lower()) + assert a == "hello" + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_inline_walrus_operator_reverse(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + def my_func(before, after): + return before == after + + def test_walrus_conversion_reverse(): + a = "Hello" + assert my_func(a := a.lower(), a) + assert a == 'hello' + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_no_variable_name_conflict( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_conversion_no_conflict(): + a = "Hello" + assert a == (b := a.lower()) + """ + ) + result = pytester.runpytest() + assert result.ret == 1 + result.stdout.fnmatch_lines(["*AssertionError: assert 'Hello' == 'hello'"]) + + def test_assertion_walrus_operator_true_assertion_and_changes_variable_value( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_conversion_succeed(): + a = "Hello" + assert a != (a := a.lower()) + assert a == 'hello' + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_operator_fail_assertion(self, pytester: Pytester) -> None: + pytester.makepyfile( + """ + def test_walrus_conversion_fails(): + a = "Hello" + assert a == (a := a.lower()) + """ + ) + result = pytester.runpytest() + assert result.ret == 1 + # This is not optimal as error message but it depends on how the rewrite is structured + result.stdout.fnmatch_lines(["*AssertionError: assert 'hello' == 'hello'"]) + + def test_assertion_walrus_operator_boolean_composite( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_operator_change_boolean_value(): + a = True + assert a and True and ((a := False) is False) and (a is False) and ((a := None) is None) + + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_assertion_walrus_operator_compare_boolean_fails( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_operator_change_boolean_value(): + a = True + assert not (a and ((a := False) is False)) + """ + ) + result = pytester.runpytest() + assert result.ret == 1 + # This is not optimal as error message but it depends on how the rewrite is structured + result.stdout.fnmatch_lines(["*assert not (False)"]) + + def test_assertion_walrus_operator_boolean_none_fails( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_operator_change_boolean_value(): + a = True + assert not (a and ((a := None) is None)) + """ + ) + result = pytester.runpytest() + assert result.ret == 1 + # This is not optimal as error message but it depends on how the rewrite is structured + result.stdout.fnmatch_lines(["*assert not (None)"]) + + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" ) From 0e6d6161f6102c0c5f3e8a5e1074ce9b4eefb3da Mon Sep 17 00:00:00 2001 From: Alessio Izzo Date: Wed, 22 Feb 2023 21:58:00 +0100 Subject: [PATCH 02/10] add changelog --- AUTHORS | 1 + changelog/10743.bugfix.rst | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog/10743.bugfix.rst diff --git a/AUTHORS b/AUTHORS index 0395feceb60..448b71d3a8a 100644 --- a/AUTHORS +++ b/AUTHORS @@ -12,6 +12,7 @@ Adam Uhlir Ahn Ki-Wook Akiomi Kamakura Alan Velasco +Alessio Izzo Alexander Johnson Alexander King Alexei Kozlenok diff --git a/changelog/10743.bugfix.rst b/changelog/10743.bugfix.rst new file mode 100644 index 00000000000..943d27ee25c --- /dev/null +++ b/changelog/10743.bugfix.rst @@ -0,0 +1 @@ +Fixed different behavior from std lib unittest of asserts with expression that contains the walrus operator in it that changes the value of a variable From 3e90bf573f4415e6b34b9187f625ef265d634aeb Mon Sep 17 00:00:00 2001 From: Alessio Izzo Date: Wed, 22 Feb 2023 22:28:18 +0100 Subject: [PATCH 03/10] add version check for py<38 --- changelog/10743.bugfix.rst | 2 +- src/_pytest/assertion/rewrite.py | 9 +++++++-- testing/test_assertrewrite.py | 3 +++ 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/changelog/10743.bugfix.rst b/changelog/10743.bugfix.rst index 943d27ee25c..db8eb73b072 100644 --- a/changelog/10743.bugfix.rst +++ b/changelog/10743.bugfix.rst @@ -1 +1 @@ -Fixed different behavior from std lib unittest of asserts with expression that contains the walrus operator in it that changes the value of a variable +Fixed different behavior from std lib unittest of asserts with expression that contains the walrus operator in it that changes the value of a variable. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 548a3fce3a4..2f011bba824 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -44,6 +44,11 @@ if TYPE_CHECKING: from _pytest.assertion import AssertionState +if sys.version_info > (3, 8): + namedExpr = ast.NamedExpr +else: + namedExpr = ast.Expr + assertstate_key = StashKey["AssertionState"]() @@ -937,7 +942,7 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]: ast.copy_location(node, assert_) return self.statements - def visit_NamedExpr(self, name: ast.NamedExpr) -> Tuple[ast.NamedExpr, str]: + def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]: # Display the repr of the target name if it's a local variable or # _should_repr_global_name() thinks it's acceptable. locs = ast.Call(self.builtin("locals"), [], []) @@ -1061,7 +1066,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: results = [left_res] for i, op, next_operand in it: next_res, next_expl = self.visit(next_operand) - if isinstance(next_operand, (ast.Compare, ast.BoolOp, ast.NamedExpr)): + if isinstance(next_operand, (ast.Compare, ast.BoolOp)): next_expl = f"({next_expl})" results.append(next_res) sym = BINOP_MAP[op.__class__] diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 89f2f8f6e0e..624947d7416 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1265,6 +1265,9 @@ def test_simple_failure(): result.stdout.fnmatch_lines(["*E*assert (1 + 1) == 3"]) +@pytest.mark.skipif( + sys.version_info < (3, 8), reason="walrus operator not available in py<38" +) class TestIssue10743: def test_assertion_walrus_operator(self, pytester: Pytester) -> None: pytester.makepyfile( From 78f3a963cc063d6a3d1646af08a94b172b7ea512 Mon Sep 17 00:00:00 2001 From: Alessio Izzo Date: Wed, 1 Mar 2023 00:51:37 +0100 Subject: [PATCH 04/10] add check and replace on already visited variables that have been changing by the walrus operator --- src/_pytest/assertion/rewrite.py | 34 ++++++++++++++++++++++++++++++-- testing/test_assertrewrite.py | 11 ++++------- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 2f011bba824..61f2ebf254b 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -657,6 +657,7 @@ def __init__( else: self.enable_assertion_pass_hook = False self.source = source + self.overwrite: Dict[str, str] = {} def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" @@ -943,8 +944,8 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]: return self.statements def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]: - # Display the repr of the target name if it's a local variable or - # _should_repr_global_name() thinks it's acceptable. + # This method handles the 'walrus operator' repr of the target + # name if it's a local variable or _should_repr_global_name() thinks it's acceptable. locs = ast.Call(self.builtin("locals"), [], []) target_id = name.target.id # type: ignore[attr-defined] inlocs = ast.Compare(ast.Str(target_id), [ast.In()], [locs]) @@ -973,12 +974,32 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: levels = len(boolop.values) - 1 self.push_format_context() # Process each operand, short-circuiting if needed. + pytest_temp = None for i, v in enumerate(boolop.values): if i: fail_inner: List[ast.stmt] = [] # cond is set in a prior loop iteration below self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa self.expl_stmts = fail_inner + if isinstance(v, ast.Compare): + if isinstance(v.left, ast.NamedExpr) and ( + v.left.target.id + in [ + ast_expr.id + for ast_expr in boolop.values[:i] + if hasattr(ast_expr, "id") + ] + or v.left.target.id == pytest_temp + ): + pytest_temp = f"pytest_{v.left.target.id}_temp" + self.overwrite[v.left.target.id] = pytest_temp + v.left.target.id = pytest_temp + + elif isinstance(v.left, ast.Name) and ( + pytest_temp is not None + and v.left.id == pytest_temp.lstrip("pytest_").rstrip("_temp") + ): + v.left.id = pytest_temp self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) @@ -1054,6 +1075,8 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]: def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: self.push_format_context() + if isinstance(comp.left, ast.Name) and comp.left.id in self.overwrite: + comp.left.id = self.overwrite[comp.left.id] left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, (ast.Compare, ast.BoolOp)): left_expl = f"({left_expl})" @@ -1065,6 +1088,13 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: syms = [] results = [left_res] for i, op, next_operand in it: + if ( + isinstance(next_operand, ast.NamedExpr) + and isinstance(left_res, ast.Name) + and next_operand.target.id == left_res.id + ): + next_operand.target.id = f"pytest_{left_res.id}_temp" + self.overwrite[left_res.id] = next_operand.target.id next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, (ast.Compare, ast.BoolOp)): next_expl = f"({next_expl})" diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 624947d7416..fcede242f39 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1374,8 +1374,7 @@ def test_walrus_conversion_fails(): ) result = pytester.runpytest() assert result.ret == 1 - # This is not optimal as error message but it depends on how the rewrite is structured - result.stdout.fnmatch_lines(["*AssertionError: assert 'hello' == 'hello'"]) + result.stdout.fnmatch_lines(["*AssertionError: assert 'Hello' == 'hello'"]) def test_assertion_walrus_operator_boolean_composite( self, pytester: Pytester @@ -1385,7 +1384,7 @@ def test_assertion_walrus_operator_boolean_composite( def test_walrus_operator_change_boolean_value(): a = True assert a and True and ((a := False) is False) and (a is False) and ((a := None) is None) - + assert a is None """ ) result = pytester.runpytest() @@ -1403,8 +1402,7 @@ def test_walrus_operator_change_boolean_value(): ) result = pytester.runpytest() assert result.ret == 1 - # This is not optimal as error message but it depends on how the rewrite is structured - result.stdout.fnmatch_lines(["*assert not (False)"]) + result.stdout.fnmatch_lines(["*assert not (True and False is False)"]) def test_assertion_walrus_operator_boolean_none_fails( self, pytester: Pytester @@ -1418,8 +1416,7 @@ def test_walrus_operator_change_boolean_value(): ) result = pytester.runpytest() assert result.ret == 1 - # This is not optimal as error message but it depends on how the rewrite is structured - result.stdout.fnmatch_lines(["*assert not (None)"]) + result.stdout.fnmatch_lines(["*assert not (True and None is None)"]) @pytest.mark.skipif( From 52f818d2bec133cfdd77261e8415aeeed53811a3 Mon Sep 17 00:00:00 2001 From: Alessio Izzo Date: Wed, 1 Mar 2023 01:06:43 +0100 Subject: [PATCH 05/10] add test on variables that have been overwritten are cleared after each test --- src/_pytest/assertion/rewrite.py | 14 +++++++------- testing/test_assertrewrite.py | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 61f2ebf254b..ded658ba25f 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -657,7 +657,7 @@ def __init__( else: self.enable_assertion_pass_hook = False self.source = source - self.overwrite: Dict[str, str] = {} + self.variables_overwrite: Dict[str, str] = {} def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" @@ -982,7 +982,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa self.expl_stmts = fail_inner if isinstance(v, ast.Compare): - if isinstance(v.left, ast.NamedExpr) and ( + if isinstance(v.left, namedExpr) and ( v.left.target.id in [ ast_expr.id @@ -992,7 +992,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: or v.left.target.id == pytest_temp ): pytest_temp = f"pytest_{v.left.target.id}_temp" - self.overwrite[v.left.target.id] = pytest_temp + self.variables_overwrite[v.left.target.id] = pytest_temp v.left.target.id = pytest_temp elif isinstance(v.left, ast.Name) and ( @@ -1075,8 +1075,8 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]: def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: self.push_format_context() - if isinstance(comp.left, ast.Name) and comp.left.id in self.overwrite: - comp.left.id = self.overwrite[comp.left.id] + if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite: + comp.left.id = self.variables_overwrite[comp.left.id] left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, (ast.Compare, ast.BoolOp)): left_expl = f"({left_expl})" @@ -1089,12 +1089,12 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: results = [left_res] for i, op, next_operand in it: if ( - isinstance(next_operand, ast.NamedExpr) + isinstance(next_operand, namedExpr) and isinstance(left_res, ast.Name) and next_operand.target.id == left_res.id ): next_operand.target.id = f"pytest_{left_res.id}_temp" - self.overwrite[left_res.id] = next_operand.target.id + self.variables_overwrite[left_res.id] = next_operand.target.id next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, (ast.Compare, ast.BoolOp)): next_expl = f"({next_expl})" diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index fcede242f39..8d944140307 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1418,6 +1418,23 @@ def test_walrus_operator_change_boolean_value(): assert result.ret == 1 result.stdout.fnmatch_lines(["*assert not (True and None is None)"]) + def test_assertion_walrus_operator_value_changes_cleared_after_each_test( + self, pytester: Pytester + ) -> None: + pytester.makepyfile( + """ + def test_walrus_operator_change_value(): + a = True + assert (a := None) is None + + def test_walrus_operator_not_override_value(): + a = True + assert a is True + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" From 0bc4bdc063b3b35ea13ae24ddc43c89dbbe4fcf8 Mon Sep 17 00:00:00 2001 From: Alessio Izzo Date: Wed, 1 Mar 2023 11:27:41 +0100 Subject: [PATCH 06/10] refactor trying to clean the code and add comments where conditions on instances of walrus operator --- src/_pytest/assertion/rewrite.py | 29 ++++++++++++++--------------- src/_pytest/assertion/util.py | 4 ++++ 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index ded658ba25f..96ac0d9bff8 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -52,7 +52,6 @@ assertstate_key = StashKey["AssertionState"]() - # pytest caches rewritten pycs in pycache dirs PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}" PYC_EXT = ".py" + (__debug__ and "c" or "o") @@ -945,7 +944,8 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]: def visit_NamedExpr(self, name: namedExpr) -> Tuple[namedExpr, str]: # This method handles the 'walrus operator' repr of the target - # name if it's a local variable or _should_repr_global_name() thinks it's acceptable. + # name if it's a local variable or _should_repr_global_name() + # thinks it's acceptable. locs = ast.Call(self.builtin("locals"), [], []) target_id = name.target.id # type: ignore[attr-defined] inlocs = ast.Compare(ast.Str(target_id), [ast.In()], [locs]) @@ -981,8 +981,11 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: # cond is set in a prior loop iteration below self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa self.expl_stmts = fail_inner - if isinstance(v, ast.Compare): - if isinstance(v.left, namedExpr) and ( + # Check if the left operand is a namedExpr and the value has already been visited + if ( + isinstance(v, ast.Compare) + and isinstance(v.left, namedExpr) + and ( v.left.target.id in [ ast_expr.id @@ -990,16 +993,11 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: if hasattr(ast_expr, "id") ] or v.left.target.id == pytest_temp - ): - pytest_temp = f"pytest_{v.left.target.id}_temp" - self.variables_overwrite[v.left.target.id] = pytest_temp - v.left.target.id = pytest_temp - - elif isinstance(v.left, ast.Name) and ( - pytest_temp is not None - and v.left.id == pytest_temp.lstrip("pytest_").rstrip("_temp") - ): - v.left.id = pytest_temp + ) + ): + pytest_temp = util.compose_temp_variable(v.left.target.id) + self.variables_overwrite[v.left.target.id] = pytest_temp + v.left.target.id = pytest_temp self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) @@ -1075,6 +1073,7 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]: def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: self.push_format_context() + # We first check if we have overwritten a variable in the previous assert if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite: comp.left.id = self.variables_overwrite[comp.left.id] left_res, left_expl = self.visit(comp.left) @@ -1093,7 +1092,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: and isinstance(left_res, ast.Name) and next_operand.target.id == left_res.id ): - next_operand.target.id = f"pytest_{left_res.id}_temp" + next_operand.target.id = util.compose_temp_variable(left_res.id) self.variables_overwrite[left_res.id] = next_operand.target.id next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, (ast.Compare, ast.BoolOp)): diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py index fc5dfdbd5ba..e5bf201c7eb 100644 --- a/src/_pytest/assertion/util.py +++ b/src/_pytest/assertion/util.py @@ -520,3 +520,7 @@ def running_on_ci() -> bool: """Check if we're currently running on a CI system.""" env_vars = ["CI", "BUILD_NUMBER"] return any(var in os.environ for var in env_vars) + + +def compose_temp_variable(original_variable: str) -> str: + return f"pytest_{original_variable}_temp" From ea73f4a1d2e6db96b102f1866b89f88e83215e99 Mon Sep 17 00:00:00 2001 From: Alessio Izzo Date: Thu, 2 Mar 2023 22:06:14 +0100 Subject: [PATCH 07/10] refactor using self.variable --- src/_pytest/assertion/rewrite.py | 28 ++++++++++++++-------------- src/_pytest/assertion/util.py | 4 ---- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 96ac0d9bff8..14e6a2d5248 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -639,8 +639,12 @@ class AssertionRewriter(ast.NodeVisitor): .push_format_context() and .pop_format_context() which allows to build another %-formatted string while already building one. - This state is reset on every new assert statement visited and used - by the other visitors. + :variables_overwrite: A dict filled with references to variables + that change value within an assert. This happens when a variable is + reassigned with the walrus operator + + This state, except the variables_overwrite, is reset on every new assert + statement visited and used by the other visitors. """ def __init__( @@ -974,7 +978,6 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: levels = len(boolop.values) - 1 self.push_format_context() # Process each operand, short-circuiting if needed. - pytest_temp = None for i, v in enumerate(boolop.values): if i: fail_inner: List[ast.stmt] = [] @@ -985,17 +988,14 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]: if ( isinstance(v, ast.Compare) and isinstance(v.left, namedExpr) - and ( - v.left.target.id - in [ - ast_expr.id - for ast_expr in boolop.values[:i] - if hasattr(ast_expr, "id") - ] - or v.left.target.id == pytest_temp - ) + and v.left.target.id + in [ + ast_expr.id + for ast_expr in boolop.values[:i] + if hasattr(ast_expr, "id") + ] ): - pytest_temp = util.compose_temp_variable(v.left.target.id) + pytest_temp = self.variable() self.variables_overwrite[v.left.target.id] = pytest_temp v.left.target.id = pytest_temp self.push_format_context() @@ -1092,7 +1092,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]: and isinstance(left_res, ast.Name) and next_operand.target.id == left_res.id ): - next_operand.target.id = util.compose_temp_variable(left_res.id) + next_operand.target.id = self.variable() self.variables_overwrite[left_res.id] = next_operand.target.id next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, (ast.Compare, ast.BoolOp)): diff --git a/src/_pytest/assertion/util.py b/src/_pytest/assertion/util.py index e5bf201c7eb..fc5dfdbd5ba 100644 --- a/src/_pytest/assertion/util.py +++ b/src/_pytest/assertion/util.py @@ -520,7 +520,3 @@ def running_on_ci() -> bool: """Check if we're currently running on a CI system.""" env_vars = ["CI", "BUILD_NUMBER"] return any(var in os.environ for var in env_vars) - - -def compose_temp_variable(original_variable: str) -> str: - return f"pytest_{original_variable}_temp" From c8c85976334009c69a4275025ba589adc9b687d5 Mon Sep 17 00:00:00 2001 From: Alessio Izzo Date: Thu, 2 Mar 2023 22:53:22 +0100 Subject: [PATCH 08/10] fix usage of NamedExpr if python_version >= 3.8 --- src/_pytest/assertion/rewrite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 14e6a2d5248..8b182347052 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -44,7 +44,7 @@ if TYPE_CHECKING: from _pytest.assertion import AssertionState -if sys.version_info > (3, 8): +if sys.version_info >= (3, 8): namedExpr = ast.NamedExpr else: namedExpr = ast.Expr From 0bd69f4e5c09c2edc619a68679ab72bb000c9e2f Mon Sep 17 00:00:00 2001 From: Alessio Izzo Date: Tue, 7 Mar 2023 23:09:51 +0100 Subject: [PATCH 09/10] fix changelog after review --- changelog/10743.bugfix.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog/10743.bugfix.rst b/changelog/10743.bugfix.rst index db8eb73b072..3fafd3a1121 100644 --- a/changelog/10743.bugfix.rst +++ b/changelog/10743.bugfix.rst @@ -1 +1 @@ -Fixed different behavior from std lib unittest of asserts with expression that contains the walrus operator in it that changes the value of a variable. +The assertion rewriting mechanism now works correctly when assertion expressions contain the walrus operator. \ No newline at end of file From ea18dd80bcf0b8d332aa395822eefdfbcb5b060f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Mar 2023 22:11:58 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- changelog/10743.bugfix.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog/10743.bugfix.rst b/changelog/10743.bugfix.rst index 3fafd3a1121..ad5c63e80ee 100644 --- a/changelog/10743.bugfix.rst +++ b/changelog/10743.bugfix.rst @@ -1 +1 @@ -The assertion rewriting mechanism now works correctly when assertion expressions contain the walrus operator. \ No newline at end of file +The assertion rewriting mechanism now works correctly when assertion expressions contain the walrus operator.