1313import sys
1414import tokenize
1515import types
16+ from collections import defaultdict
1617from pathlib import Path
1718from pathlib import PurePath
1819from typing import Callable
5253PYC_EXT = ".py" + (__debug__ and "c" or "o" )
5354PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
5455
56+ _SCOPE_END_MARKER = object ()
57+
5558
5659class AssertionRewritingHook (importlib .abc .MetaPathFinder , importlib .abc .Loader ):
5760 """PEP302/PEP451 import hook which rewrites asserts."""
@@ -634,6 +637,8 @@ class AssertionRewriter(ast.NodeVisitor):
634637 .push_format_context() and .pop_format_context() which allows
635638 to build another %-formatted string while already building one.
636639
640+ :scope: A tuple containing the current scope used for variables_overwrite.
641+
637642 :variables_overwrite: A dict filled with references to variables
638643 that change value within an assert. This happens when a variable is
639644 reassigned with the walrus operator
@@ -655,7 +660,10 @@ def __init__(
655660 else :
656661 self .enable_assertion_pass_hook = False
657662 self .source = source
658- self .variables_overwrite : Dict [str , str ] = {}
663+ self .scope : tuple [ast .AST , ...] = ()
664+ self .variables_overwrite : defaultdict [
665+ tuple [ast .AST , ...], Dict [str , str ]
666+ ] = defaultdict (dict )
659667
660668 def run (self , mod : ast .Module ) -> None :
661669 """Find all assert statements in *mod* and rewrite them."""
@@ -719,9 +727,17 @@ def run(self, mod: ast.Module) -> None:
719727 mod .body [pos :pos ] = imports
720728
721729 # Collect asserts.
722- nodes : List [ast .AST ] = [mod ]
730+ self .scope = (mod ,)
731+ nodes : List [Union [ast .AST , object ]] = [mod ]
723732 while nodes :
724733 node = nodes .pop ()
734+ if isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef , ast .ClassDef )):
735+ self .scope = tuple ((* self .scope , node ))
736+ nodes .append (_SCOPE_END_MARKER )
737+ if node == _SCOPE_END_MARKER :
738+ self .scope = self .scope [:- 1 ]
739+ continue
740+ assert isinstance (node , ast .AST )
725741 for name , field in ast .iter_fields (node ):
726742 if isinstance (field , list ):
727743 new : List [ast .AST ] = []
@@ -992,7 +1008,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
9921008 ]
9931009 ):
9941010 pytest_temp = self .variable ()
995- self .variables_overwrite [
1011+ self .variables_overwrite [self . scope ][
9961012 v .left .target .id
9971013 ] = v .left # type:ignore[assignment]
9981014 v .left .target .id = pytest_temp
@@ -1035,17 +1051,20 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
10351051 new_args = []
10361052 new_kwargs = []
10371053 for arg in call .args :
1038- if isinstance (arg , ast .Name ) and arg .id in self .variables_overwrite :
1039- arg = self .variables_overwrite [arg .id ] # type:ignore[assignment]
1054+ if isinstance (arg , ast .Name ) and arg .id in self .variables_overwrite .get (
1055+ self .scope , {}
1056+ ):
1057+ arg = self .variables_overwrite [self .scope ][
1058+ arg .id
1059+ ] # type:ignore[assignment]
10401060 res , expl = self .visit (arg )
10411061 arg_expls .append (expl )
10421062 new_args .append (res )
10431063 for keyword in call .keywords :
1044- if (
1045- isinstance (keyword .value , ast .Name )
1046- and keyword .value .id in self .variables_overwrite
1047- ):
1048- keyword .value = self .variables_overwrite [
1064+ if isinstance (
1065+ keyword .value , ast .Name
1066+ ) and keyword .value .id in self .variables_overwrite .get (self .scope , {}):
1067+ keyword .value = self .variables_overwrite [self .scope ][
10491068 keyword .value .id
10501069 ] # type:ignore[assignment]
10511070 res , expl = self .visit (keyword .value )
@@ -1081,12 +1100,14 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
10811100 def visit_Compare (self , comp : ast .Compare ) -> Tuple [ast .expr , str ]:
10821101 self .push_format_context ()
10831102 # We first check if we have overwritten a variable in the previous assert
1084- if isinstance (comp .left , ast .Name ) and comp .left .id in self .variables_overwrite :
1085- comp .left = self .variables_overwrite [
1103+ if isinstance (
1104+ comp .left , ast .Name
1105+ ) and comp .left .id in self .variables_overwrite .get (self .scope , {}):
1106+ comp .left = self .variables_overwrite [self .scope ][
10861107 comp .left .id
10871108 ] # type:ignore[assignment]
10881109 if isinstance (comp .left , ast .NamedExpr ):
1089- self .variables_overwrite [
1110+ self .variables_overwrite [self . scope ][
10901111 comp .left .target .id
10911112 ] = comp .left # type:ignore[assignment]
10921113 left_res , left_expl = self .visit (comp .left )
@@ -1106,7 +1127,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
11061127 and next_operand .target .id == left_res .id
11071128 ):
11081129 next_operand .target .id = self .variable ()
1109- self .variables_overwrite [
1130+ self .variables_overwrite [self . scope ][
11101131 left_res .id
11111132 ] = next_operand # type:ignore[assignment]
11121133 next_res , next_expl = self .visit (next_operand )
0 commit comments