1717from typing import List
1818from typing import Optional
1919from typing import Set
20+ from typing import Tuple
2021
2122import atomicwrites
2223
@@ -48,13 +49,13 @@ def __init__(self, config):
4849 except ValueError :
4950 self .fnpats = ["test_*.py" , "*_test.py" ]
5051 self .session = None
51- self ._rewritten_names = set ()
52- self ._must_rewrite = set ()
52+ self ._rewritten_names = set () # type: Set[str]
53+ self ._must_rewrite = set () # type: Set[str]
5354 # flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
5455 # which might result in infinite recursion (#3506)
5556 self ._writing_pyc = False
5657 self ._basenames_to_check_rewrite = {"conftest" }
57- self ._marked_for_rewrite_cache = {}
58+ self ._marked_for_rewrite_cache = {} # type: Dict[str, bool]
5859 self ._session_paths_checked = False
5960
6061 def set_session (self , session ):
@@ -203,7 +204,7 @@ def _should_rewrite(self, name, fn, state):
203204
204205 return self ._is_marked_for_rewrite (name , state )
205206
206- def _is_marked_for_rewrite (self , name , state ):
207+ def _is_marked_for_rewrite (self , name : str , state ):
207208 try :
208209 return self ._marked_for_rewrite_cache [name ]
209210 except KeyError :
@@ -218,7 +219,7 @@ def _is_marked_for_rewrite(self, name, state):
218219 self ._marked_for_rewrite_cache [name ] = False
219220 return False
220221
221- def mark_rewrite (self , * names ) :
222+ def mark_rewrite (self , * names : str ) -> None :
222223 """Mark import names as needing to be rewritten.
223224
224225 The named module or package as well as any nested modules will
@@ -385,6 +386,7 @@ def _format_boolop(explanations, is_or):
385386
386387
387388def _call_reprcompare (ops , results , expls , each_obj ):
389+ # type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str
388390 for i , res , expl in zip (range (len (ops )), results , expls ):
389391 try :
390392 done = not res
@@ -400,11 +402,13 @@ def _call_reprcompare(ops, results, expls, each_obj):
400402
401403
402404def _call_assertion_pass (lineno , orig , expl ):
405+ # type: (int, str, str) -> None
403406 if util ._assertion_pass is not None :
404- util ._assertion_pass (lineno = lineno , orig = orig , expl = expl )
407+ util ._assertion_pass (lineno , orig , expl )
405408
406409
407410def _check_if_assertion_pass_impl ():
411+ # type: () -> bool
408412 """Checks if any plugins implement the pytest_assertion_pass hook
409413 in order not to generate explanation unecessarily (might be expensive)"""
410414 return True if util ._assertion_pass else False
@@ -578,7 +582,7 @@ def __init__(self, module_path, config, source):
578582 def _assert_expr_to_lineno (self ):
579583 return _get_assertion_exprs (self .source )
580584
581- def run (self , mod ) :
585+ def run (self , mod : ast . Module ) -> None :
582586 """Find all assert statements in *mod* and rewrite them."""
583587 if not mod .body :
584588 # Nothing to do.
@@ -620,12 +624,12 @@ def run(self, mod):
620624 ]
621625 mod .body [pos :pos ] = imports
622626 # Collect asserts.
623- nodes = [mod ]
627+ nodes = [mod ] # type: List[ast.AST]
624628 while nodes :
625629 node = nodes .pop ()
626630 for name , field in ast .iter_fields (node ):
627631 if isinstance (field , list ):
628- new = []
632+ new = [] # type: List
629633 for i , child in enumerate (field ):
630634 if isinstance (child , ast .Assert ):
631635 # Transform assert.
@@ -699,7 +703,7 @@ def push_format_context(self):
699703 .explanation_param().
700704
701705 """
702- self .explanation_specifiers = {}
706+ self .explanation_specifiers = {} # type: Dict[str, ast.expr]
703707 self .stack .append (self .explanation_specifiers )
704708
705709 def pop_format_context (self , expl_expr ):
@@ -742,7 +746,8 @@ def visit_Assert(self, assert_):
742746 from _pytest .warning_types import PytestAssertRewriteWarning
743747 import warnings
744748
745- warnings .warn_explicit (
749+ # Ignore type: typeshed bug https://github.com/python/typeshed/pull/3121
750+ warnings .warn_explicit ( # type: ignore
746751 PytestAssertRewriteWarning (
747752 "assertion is always true, perhaps remove parentheses?"
748753 ),
@@ -751,15 +756,15 @@ def visit_Assert(self, assert_):
751756 lineno = assert_ .lineno ,
752757 )
753758
754- self .statements = []
755- self .variables = []
759+ self .statements = [] # type: List[ast.stmt]
760+ self .variables = [] # type: List[str]
756761 self .variable_counter = itertools .count ()
757762
758763 if self .enable_assertion_pass_hook :
759- self .format_variables = []
764+ self .format_variables = [] # type: List[str]
760765
761- self .stack = []
762- self .expl_stmts = []
766+ self .stack = [] # type: List[Dict[str, ast.expr]]
767+ self .expl_stmts = [] # type: List[ast.stmt]
763768 self .push_format_context ()
764769 # Rewrite assert into a bunch of statements.
765770 top_condition , explanation = self .visit (assert_ .test )
@@ -897,7 +902,7 @@ def visit_BoolOp(self, boolop):
897902 # Process each operand, short-circuiting if needed.
898903 for i , v in enumerate (boolop .values ):
899904 if i :
900- fail_inner = []
905+ fail_inner = [] # type: List[ast.stmt]
901906 # cond is set in a prior loop iteration below
902907 self .expl_stmts .append (ast .If (cond , fail_inner , [])) # noqa
903908 self .expl_stmts = fail_inner
@@ -908,10 +913,10 @@ def visit_BoolOp(self, boolop):
908913 call = ast .Call (app , [expl_format ], [])
909914 self .expl_stmts .append (ast .Expr (call ))
910915 if i < levels :
911- cond = res
916+ cond = res # type: ast.expr
912917 if is_or :
913918 cond = ast .UnaryOp (ast .Not (), cond )
914- inner = []
919+ inner = [] # type: List[ast.stmt]
915920 self .statements .append (ast .If (cond , inner , []))
916921 self .statements = body = inner
917922 self .statements = save
@@ -977,7 +982,7 @@ def visit_Attribute(self, attr):
977982 expl = pat % (res_expl , res_expl , value_expl , attr .attr )
978983 return res , expl
979984
980- def visit_Compare (self , comp ):
985+ def visit_Compare (self , comp : ast . Compare ):
981986 self .push_format_context ()
982987 left_res , left_expl = self .visit (comp .left )
983988 if isinstance (comp .left , (ast .Compare , ast .BoolOp )):
@@ -1010,7 +1015,7 @@ def visit_Compare(self, comp):
10101015 ast .Tuple (results , ast .Load ()),
10111016 )
10121017 if len (comp .ops ) > 1 :
1013- res = ast .BoolOp (ast .And (), load_names )
1018+ res = ast .BoolOp (ast .And (), load_names ) # type: ast.expr
10141019 else :
10151020 res = load_names [0 ]
10161021 return res , self .explanation_param (self .pop_format_context (expl_call ))
0 commit comments