1313import sys
1414import tokenize
1515import types
16+ from collections import defaultdict
1617from pathlib import Path
1718from pathlib import PurePath
1819from typing import Callable
5253PYC_EXT = ".py" + (__debug__ and "c" or "o" )
5354PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
5455
56+ _SCOPE_END_MARKER = object ()
57+
5558
5659class AssertionRewritingHook (importlib .abc .MetaPathFinder , importlib .abc .Loader ):
5760 """PEP302/PEP451 import hook which rewrites asserts."""
@@ -634,6 +637,8 @@ class AssertionRewriter(ast.NodeVisitor):
634637 .push_format_context() and .pop_format_context() which allows
635638 to build another %-formatted string while already building one.
636639
640+ :scope: A tuple containing the current scope used for variables_overwrite.
641+
637642 :variables_overwrite: A dict filled with references to variables
638643 that change value within an assert. This happens when a variable is
639644 reassigned with the walrus operator
@@ -655,7 +660,8 @@ def __init__(
655660 else :
656661 self .enable_assertion_pass_hook = False
657662 self .source = source
658- self .variables_overwrite : Dict [str , str ] = {}
663+ self .scope : tuple [ast .AST , ...] = ()
664+ self .variables_overwrite : defaultdict [tuple [ast .AST , ...], Dict [str , str ]] = defaultdict (dict )
659665
660666 def run (self , mod : ast .Module ) -> None :
661667 """Find all assert statements in *mod* and rewrite them."""
@@ -719,9 +725,17 @@ def run(self, mod: ast.Module) -> None:
719725 mod .body [pos :pos ] = imports
720726
721727 # Collect asserts.
722- nodes : List [ast .AST ] = [mod ]
728+ self .scope = (mod ,)
729+ nodes : List [Union [ast .AST , object ]] = [mod ]
723730 while nodes :
724731 node = nodes .pop ()
732+ if isinstance (node , (ast .FunctionDef , ast .ClassDef )):
733+ self .scope = tuple ((* self .scope , node ))
734+ nodes .append (_SCOPE_END_MARKER )
735+ if node == _SCOPE_END_MARKER :
736+ self .scope = self .scope [:- 1 ]
737+ continue
738+ assert isinstance (node , ast .AST )
725739 for name , field in ast .iter_fields (node ):
726740 if isinstance (field , list ):
727741 new : List [ast .AST ] = []
@@ -992,7 +1006,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
9921006 ]
9931007 ):
9941008 pytest_temp = self .variable ()
995- self .variables_overwrite [
1009+ self .variables_overwrite [self . scope ][
9961010 v .left .target .id
9971011 ] = v .left # type:ignore[assignment]
9981012 v .left .target .id = pytest_temp
@@ -1035,17 +1049,17 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
10351049 new_args = []
10361050 new_kwargs = []
10371051 for arg in call .args :
1038- if isinstance (arg , ast .Name ) and arg .id in self .variables_overwrite :
1039- arg = self .variables_overwrite [arg .id ] # type:ignore[assignment]
1052+ if isinstance (arg , ast .Name ) and arg .id in self .variables_overwrite . get ( self . scope , {}) :
1053+ arg = self .variables_overwrite [self . scope ][ arg .id ] # type:ignore[assignment]
10401054 res , expl = self .visit (arg )
10411055 arg_expls .append (expl )
10421056 new_args .append (res )
10431057 for keyword in call .keywords :
10441058 if (
10451059 isinstance (keyword .value , ast .Name )
1046- and keyword .value .id in self .variables_overwrite
1060+ and keyword .value .id in self .variables_overwrite . get ( self . scope , {})
10471061 ):
1048- keyword .value = self .variables_overwrite [
1062+ keyword .value = self .variables_overwrite [self . scope ][
10491063 keyword .value .id
10501064 ] # type:ignore[assignment]
10511065 res , expl = self .visit (keyword .value )
@@ -1081,12 +1095,12 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
10811095 def visit_Compare (self , comp : ast .Compare ) -> Tuple [ast .expr , str ]:
10821096 self .push_format_context ()
10831097 # We first check if we have overwritten a variable in the previous assert
1084- if isinstance (comp .left , ast .Name ) and comp .left .id in self .variables_overwrite :
1085- comp .left = self .variables_overwrite [
1098+ if isinstance (comp .left , ast .Name ) and comp .left .id in self .variables_overwrite . get ( self . scope , {}) :
1099+ comp .left = self .variables_overwrite [self . scope ][
10861100 comp .left .id
10871101 ] # type:ignore[assignment]
10881102 if isinstance (comp .left , ast .NamedExpr ):
1089- self .variables_overwrite [
1103+ self .variables_overwrite [self . scope ][
10901104 comp .left .target .id
10911105 ] = comp .left # type:ignore[assignment]
10921106 left_res , left_expl = self .visit (comp .left )
@@ -1106,7 +1120,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
11061120 and next_operand .target .id == left_res .id
11071121 ):
11081122 next_operand .target .id = self .variable ()
1109- self .variables_overwrite [
1123+ self .variables_overwrite [self . scope ][
11101124 left_res .id
11111125 ] = next_operand # type:ignore[assignment]
11121126 next_res , next_expl = self .visit (next_operand )
0 commit comments