1010import sys
1111import types
1212
13+ import astor
1314import atomicwrites
1415
1516from _pytest ._io .saferepr import saferepr
@@ -134,7 +135,7 @@ def exec_module(self, module):
134135 co = _read_pyc (fn , pyc , state .trace )
135136 if co is None :
136137 state .trace ("rewriting {!r}" .format (fn ))
137- source_stat , co = _rewrite_test (fn )
138+ source_stat , co = _rewrite_test (fn , self . config )
138139 if write :
139140 self ._writing_pyc = True
140141 try :
@@ -278,13 +279,13 @@ def _write_pyc(state, co, source_stat, pyc):
278279 return True
279280
280281
281- def _rewrite_test (fn ):
282+ def _rewrite_test (fn , config ):
282283 """read and rewrite *fn* and return the code object."""
283284 stat = os .stat (fn )
284285 with open (fn , "rb" ) as f :
285286 source = f .read ()
286287 tree = ast .parse (source , filename = fn )
287- rewrite_asserts (tree , fn )
288+ rewrite_asserts (tree , fn , config )
288289 co = compile (tree , fn , "exec" , dont_inherit = True )
289290 return stat , co
290291
@@ -326,9 +327,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
326327 return co
327328
328329
329- def rewrite_asserts (mod , module_path = None ):
330+ def rewrite_asserts (mod , module_path = None , config = None ):
330331 """Rewrite the assert statements in mod."""
331- AssertionRewriter (module_path ).run (mod )
332+ AssertionRewriter (module_path , config ).run (mod )
332333
333334
334335def _saferepr (obj ):
@@ -401,6 +402,17 @@ def _call_reprcompare(ops, results, expls, each_obj):
401402 return expl
402403
403404
405+ def _call_assertion_pass (lineno , orig , expl ):
406+ if util ._assertion_pass is not None :
407+ util ._assertion_pass (lineno = lineno , orig = orig , expl = expl )
408+
409+
410+ def _check_if_assertion_pass_impl ():
411+ """Checks if any plugins implement the pytest_assertion_pass hook
412+ in order not to generate explanation unecessarily (might be expensive)"""
413+ return True if util ._assertion_pass else False
414+
415+
404416unary_map = {ast .Not : "not %s" , ast .Invert : "~%s" , ast .USub : "-%s" , ast .UAdd : "+%s" }
405417
406418binop_map = {
@@ -473,7 +485,8 @@ class AssertionRewriter(ast.NodeVisitor):
473485 original assert statement: it rewrites the test of an assertion
474486 to provide intermediate values and replace it with an if statement
475487 which raises an assertion error with a detailed explanation in
476- case the expression is false.
488+ case the expression is false and calls pytest_assertion_pass hook
489+ if expression is true.
477490
478491 For this .visit_Assert() uses the visitor pattern to visit all the
479492 AST nodes of the ast.Assert.test field, each visit call returning
@@ -491,9 +504,10 @@ class AssertionRewriter(ast.NodeVisitor):
491504 by statements. Variables are created using .variable() and
492505 have the form of "@py_assert0".
493506
494- :on_failure: The AST statements which will be executed if the
495- assertion test fails. This is the code which will construct
496- the failure message and raises the AssertionError.
507+ :expl_stmts: The AST statements which will be executed to get
508+ data from the assertion. This is the code which will construct
509+ the detailed assertion message that is used in the AssertionError
510+ or for the pytest_assertion_pass hook.
497511
498512 :explanation_specifiers: A dict filled by .explanation_param()
499513 with %-formatting placeholders and their corresponding
@@ -509,9 +523,16 @@ class AssertionRewriter(ast.NodeVisitor):
509523
510524 """
511525
512- def __init__ (self , module_path ):
526+ def __init__ (self , module_path , config ):
513527 super ().__init__ ()
514528 self .module_path = module_path
529+ self .config = config
530+ if config is not None :
531+ self .enable_assertion_pass_hook = config .getini (
532+ "enable_assertion_pass_hook"
533+ )
534+ else :
535+ self .enable_assertion_pass_hook = False
515536
516537 def run (self , mod ):
517538 """Find all assert statements in *mod* and rewrite them."""
@@ -642,7 +663,7 @@ def pop_format_context(self, expl_expr):
642663
643664 The expl_expr should be an ast.Str instance constructed from
644665 the %-placeholders created by .explanation_param(). This will
645- add the required code to format said string to .on_failure and
666+ add the required code to format said string to .expl_stmts and
646667 return the ast.Name instance of the formatted string.
647668
648669 """
@@ -653,7 +674,9 @@ def pop_format_context(self, expl_expr):
653674 format_dict = ast .Dict (keys , list (current .values ()))
654675 form = ast .BinOp (expl_expr , ast .Mod (), format_dict )
655676 name = "@py_format" + str (next (self .variable_counter ))
656- self .on_failure .append (ast .Assign ([ast .Name (name , ast .Store ())], form ))
677+ if self .enable_assertion_pass_hook :
678+ self .format_variables .append (name )
679+ self .expl_stmts .append (ast .Assign ([ast .Name (name , ast .Store ())], form ))
657680 return ast .Name (name , ast .Load ())
658681
659682 def generic_visit (self , node ):
@@ -687,8 +710,12 @@ def visit_Assert(self, assert_):
687710 self .statements = []
688711 self .variables = []
689712 self .variable_counter = itertools .count ()
713+
714+ if self .enable_assertion_pass_hook :
715+ self .format_variables = []
716+
690717 self .stack = []
691- self .on_failure = []
718+ self .expl_stmts = []
692719 self .push_format_context ()
693720 # Rewrite assert into a bunch of statements.
694721 top_condition , explanation = self .visit (assert_ .test )
@@ -699,24 +726,77 @@ def visit_Assert(self, assert_):
699726 top_condition , module_path = self .module_path , lineno = assert_ .lineno
700727 )
701728 )
702- # Create failure message.
703- body = self .on_failure
704- negation = ast .UnaryOp (ast .Not (), top_condition )
705- self .statements .append (ast .If (negation , body , []))
706- if assert_ .msg :
707- assertmsg = self .helper ("_format_assertmsg" , assert_ .msg )
708- explanation = "\n >assert " + explanation
709- else :
710- assertmsg = ast .Str ("" )
711- explanation = "assert " + explanation
712- template = ast .BinOp (assertmsg , ast .Add (), ast .Str (explanation ))
713- msg = self .pop_format_context (template )
714- fmt = self .helper ("_format_explanation" , msg )
715- err_name = ast .Name ("AssertionError" , ast .Load ())
716- exc = ast .Call (err_name , [fmt ], [])
717- raise_ = ast .Raise (exc , None )
718-
719- body .append (raise_ )
729+
730+ if self .enable_assertion_pass_hook : # Experimental pytest_assertion_pass hook
731+ negation = ast .UnaryOp (ast .Not (), top_condition )
732+ msg = self .pop_format_context (ast .Str (explanation ))
733+
734+ # Failed
735+ if assert_ .msg :
736+ assertmsg = self .helper ("_format_assertmsg" , assert_ .msg )
737+ gluestr = "\n >assert "
738+ else :
739+ assertmsg = ast .Str ("" )
740+ gluestr = "assert "
741+ err_explanation = ast .BinOp (ast .Str (gluestr ), ast .Add (), msg )
742+ err_msg = ast .BinOp (assertmsg , ast .Add (), err_explanation )
743+ err_name = ast .Name ("AssertionError" , ast .Load ())
744+ fmt = self .helper ("_format_explanation" , err_msg )
745+ exc = ast .Call (err_name , [fmt ], [])
746+ raise_ = ast .Raise (exc , None )
747+ statements_fail = []
748+ statements_fail .extend (self .expl_stmts )
749+ statements_fail .append (raise_ )
750+
751+ # Passed
752+ fmt_pass = self .helper ("_format_explanation" , msg )
753+ orig = astor .to_source (assert_ .test ).rstrip ("\n " ).lstrip ("(" ).rstrip (")" )
754+ hook_call_pass = ast .Expr (
755+ self .helper (
756+ "_call_assertion_pass" ,
757+ ast .Num (assert_ .lineno ),
758+ ast .Str (orig ),
759+ fmt_pass ,
760+ )
761+ )
762+ # If any hooks implement assert_pass hook
763+ hook_impl_test = ast .If (
764+ self .helper ("_check_if_assertion_pass_impl" ),
765+ self .expl_stmts + [hook_call_pass ],
766+ [],
767+ )
768+ statements_pass = [hook_impl_test ]
769+
770+ # Test for assertion condition
771+ main_test = ast .If (negation , statements_fail , statements_pass )
772+ self .statements .append (main_test )
773+ if self .format_variables :
774+ variables = [
775+ ast .Name (name , ast .Store ()) for name in self .format_variables
776+ ]
777+ clear_format = ast .Assign (variables , _NameConstant (None ))
778+ self .statements .append (clear_format )
779+
780+ else : # Original assertion rewriting
781+ # Create failure message.
782+ body = self .expl_stmts
783+ negation = ast .UnaryOp (ast .Not (), top_condition )
784+ self .statements .append (ast .If (negation , body , []))
785+ if assert_ .msg :
786+ assertmsg = self .helper ("_format_assertmsg" , assert_ .msg )
787+ explanation = "\n >assert " + explanation
788+ else :
789+ assertmsg = ast .Str ("" )
790+ explanation = "assert " + explanation
791+ template = ast .BinOp (assertmsg , ast .Add (), ast .Str (explanation ))
792+ msg = self .pop_format_context (template )
793+ fmt = self .helper ("_format_explanation" , msg )
794+ err_name = ast .Name ("AssertionError" , ast .Load ())
795+ exc = ast .Call (err_name , [fmt ], [])
796+ raise_ = ast .Raise (exc , None )
797+
798+ body .append (raise_ )
799+
720800 # Clear temporary variables by setting them to None.
721801 if self .variables :
722802 variables = [ast .Name (name , ast .Store ()) for name in self .variables ]
@@ -770,22 +850,22 @@ def visit_BoolOp(self, boolop):
770850 app = ast .Attribute (expl_list , "append" , ast .Load ())
771851 is_or = int (isinstance (boolop .op , ast .Or ))
772852 body = save = self .statements
773- fail_save = self .on_failure
853+ fail_save = self .expl_stmts
774854 levels = len (boolop .values ) - 1
775855 self .push_format_context ()
776856 # Process each operand, short-circuiting if needed.
777857 for i , v in enumerate (boolop .values ):
778858 if i :
779859 fail_inner = []
780860 # cond is set in a prior loop iteration below
781- self .on_failure .append (ast .If (cond , fail_inner , [])) # noqa
782- self .on_failure = fail_inner
861+ self .expl_stmts .append (ast .If (cond , fail_inner , [])) # noqa
862+ self .expl_stmts = fail_inner
783863 self .push_format_context ()
784864 res , expl = self .visit (v )
785865 body .append (ast .Assign ([ast .Name (res_var , ast .Store ())], res ))
786866 expl_format = self .pop_format_context (ast .Str (expl ))
787867 call = ast .Call (app , [expl_format ], [])
788- self .on_failure .append (ast .Expr (call ))
868+ self .expl_stmts .append (ast .Expr (call ))
789869 if i < levels :
790870 cond = res
791871 if is_or :
@@ -794,7 +874,7 @@ def visit_BoolOp(self, boolop):
794874 self .statements .append (ast .If (cond , inner , []))
795875 self .statements = body = inner
796876 self .statements = save
797- self .on_failure = fail_save
877+ self .expl_stmts = fail_save
798878 expl_template = self .helper ("_format_boolop" , expl_list , ast .Num (is_or ))
799879 expl = self .pop_format_context (expl_template )
800880 return ast .Name (res_var , ast .Load ()), self .explanation_param (expl )
0 commit comments