1313import sys
1414import tokenize
1515import types
16+ from collections import defaultdict
1617from pathlib import Path
1718from pathlib import PurePath
1819from typing import Callable
4546 from _pytest .assertion import AssertionState
4647
4748
49+ class Sentinel :
50+ pass
51+
52+
4853assertstate_key = StashKey ["AssertionState" ]()
4954
5055# pytest caches rewritten pycs in pycache dirs
5156PYTEST_TAG = f"{ sys .implementation .cache_tag } -pytest-{ version } "
5257PYC_EXT = ".py" + (__debug__ and "c" or "o" )
5358PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
5459
60+ # Special marker that denotes we have just left a scope definition
61+ _SCOPE_END_MARKER = Sentinel ()
62+
5563
5664class AssertionRewritingHook (importlib .abc .MetaPathFinder , importlib .abc .Loader ):
5765 """PEP302/PEP451 import hook which rewrites asserts."""
@@ -634,6 +642,8 @@ class AssertionRewriter(ast.NodeVisitor):
634642 .push_format_context() and .pop_format_context() which allows
635643 to build another %-formatted string while already building one.
636644
645+ :scope: A tuple containing the current scope used for variables_overwrite.
646+
637647 :variables_overwrite: A dict filled with references to variables
638648 that change value within an assert. This happens when a variable is
639649 reassigned with the walrus operator
@@ -655,7 +665,10 @@ def __init__(
655665 else :
656666 self .enable_assertion_pass_hook = False
657667 self .source = source
658- self .variables_overwrite : Dict [str , str ] = {}
668+ self .scope : tuple [ast .AST , ...] = ()
669+ self .variables_overwrite : defaultdict [
670+ tuple [ast .AST , ...], Dict [str , str ]
671+ ] = defaultdict (dict )
659672
660673 def run (self , mod : ast .Module ) -> None :
661674 """Find all assert statements in *mod* and rewrite them."""
@@ -719,9 +732,17 @@ def run(self, mod: ast.Module) -> None:
719732 mod .body [pos :pos ] = imports
720733
721734 # Collect asserts.
722- nodes : List [ast .AST ] = [mod ]
735+ self .scope = (mod ,)
736+ nodes : List [Union [ast .AST , Sentinel ]] = [mod ]
723737 while nodes :
724738 node = nodes .pop ()
739+ if isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef , ast .ClassDef )):
740+ self .scope = tuple ((* self .scope , node ))
741+ nodes .append (_SCOPE_END_MARKER )
742+ if node == _SCOPE_END_MARKER :
743+ self .scope = self .scope [:- 1 ]
744+ continue
745+ assert isinstance (node , ast .AST )
725746 for name , field in ast .iter_fields (node ):
726747 if isinstance (field , list ):
727748 new : List [ast .AST ] = []
@@ -992,7 +1013,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
9921013 ]
9931014 ):
9941015 pytest_temp = self .variable ()
995- self .variables_overwrite [
1016+ self .variables_overwrite [self . scope ][
9961017 v .left .target .id
9971018 ] = v .left # type:ignore[assignment]
9981019 v .left .target .id = pytest_temp
@@ -1035,17 +1056,20 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
10351056 new_args = []
10361057 new_kwargs = []
10371058 for arg in call .args :
1038- if isinstance (arg , ast .Name ) and arg .id in self .variables_overwrite :
1039- arg = self .variables_overwrite [arg .id ] # type:ignore[assignment]
1059+ if isinstance (arg , ast .Name ) and arg .id in self .variables_overwrite .get (
1060+ self .scope , {}
1061+ ):
1062+ arg = self .variables_overwrite [self .scope ][
1063+ arg .id
1064+ ] # type:ignore[assignment]
10401065 res , expl = self .visit (arg )
10411066 arg_expls .append (expl )
10421067 new_args .append (res )
10431068 for keyword in call .keywords :
1044- if (
1045- isinstance (keyword .value , ast .Name )
1046- and keyword .value .id in self .variables_overwrite
1047- ):
1048- keyword .value = self .variables_overwrite [
1069+ if isinstance (
1070+ keyword .value , ast .Name
1071+ ) and keyword .value .id in self .variables_overwrite .get (self .scope , {}):
1072+ keyword .value = self .variables_overwrite [self .scope ][
10491073 keyword .value .id
10501074 ] # type:ignore[assignment]
10511075 res , expl = self .visit (keyword .value )
@@ -1081,12 +1105,14 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
10811105 def visit_Compare (self , comp : ast .Compare ) -> Tuple [ast .expr , str ]:
10821106 self .push_format_context ()
10831107 # We first check if we have overwritten a variable in the previous assert
1084- if isinstance (comp .left , ast .Name ) and comp .left .id in self .variables_overwrite :
1085- comp .left = self .variables_overwrite [
1108+ if isinstance (
1109+ comp .left , ast .Name
1110+ ) and comp .left .id in self .variables_overwrite .get (self .scope , {}):
1111+ comp .left = self .variables_overwrite [self .scope ][
10861112 comp .left .id
10871113 ] # type:ignore[assignment]
10881114 if isinstance (comp .left , ast .NamedExpr ):
1089- self .variables_overwrite [
1115+ self .variables_overwrite [self . scope ][
10901116 comp .left .target .id
10911117 ] = comp .left # type:ignore[assignment]
10921118 left_res , left_expl = self .visit (comp .left )
@@ -1106,7 +1132,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
11061132 and next_operand .target .id == left_res .id
11071133 ):
11081134 next_operand .target .id = self .variable ()
1109- self .variables_overwrite [
1135+ self .variables_overwrite [self . scope ][
11101136 left_res .id
11111137 ] = next_operand # type:ignore[assignment]
11121138 next_res , next_expl = self .visit (next_operand )
0 commit comments