11""" recording warnings during test function execution. """
2- import inspect
32import re
43import warnings
4+ from types import TracebackType
5+ from typing import Any
6+ from typing import Callable
7+ from typing import Iterator
8+ from typing import List
9+ from typing import Optional
10+ from typing import overload
11+ from typing import Pattern
12+ from typing import Tuple
13+ from typing import Union
514
615from _pytest .fixtures import yield_fixture
716from _pytest .outcomes import fail
817
18+ if False : # TYPE_CHECKING
19+ from typing import Type
20+
921
1022@yield_fixture
1123def recwarn ():
@@ -42,7 +54,32 @@ def deprecated_call(func=None, *args, **kwargs):
4254 return warns ((DeprecationWarning , PendingDeprecationWarning ), * args , ** kwargs )
4355
4456
45- def warns (expected_warning , * args , match = None , ** kwargs ):
57+ @overload
58+ def warns (
59+ expected_warning : Union ["Type[Warning]" , Tuple ["Type[Warning]" , ...]],
60+ * ,
61+ match : Optional [Union [str , Pattern ]] = ...
62+ ) -> "WarningsChecker" :
63+ ... # pragma: no cover
64+
65+
66+ @overload
67+ def warns (
68+ expected_warning : Union ["Type[Warning]" , Tuple ["Type[Warning]" , ...]],
69+ func : Callable ,
70+ * args : Any ,
71+ match : Optional [Union [str , Pattern ]] = ...,
72+ ** kwargs : Any
73+ ) -> Union [Any ]:
74+ ... # pragma: no cover
75+
76+
77+ def warns (
78+ expected_warning : Union ["Type[Warning]" , Tuple ["Type[Warning]" , ...]],
79+ * args : Any ,
80+ match : Optional [Union [str , Pattern ]] = None ,
81+ ** kwargs : Any
82+ ) -> Union ["WarningsChecker" , Any ]:
4683 r"""Assert that code raises a particular class of warning.
4784
4885 Specifically, the parameter ``expected_warning`` can be a warning class or
@@ -101,81 +138,107 @@ class WarningsRecorder(warnings.catch_warnings):
101138 def __init__ (self ):
102139 super ().__init__ (record = True )
103140 self ._entered = False
104- self ._list = []
141+ self ._list = [] # type: List[warnings._Record]
105142
106143 @property
107- def list (self ):
144+ def list (self ) -> List [ "warnings._Record" ] :
108145 """The list of recorded warnings."""
109146 return self ._list
110147
111- def __getitem__ (self , i ) :
148+ def __getitem__ (self , i : int ) -> "warnings._Record" :
112149 """Get a recorded warning by index."""
113150 return self ._list [i ]
114151
115- def __iter__ (self ):
152+ def __iter__ (self ) -> Iterator [ "warnings._Record" ] :
116153 """Iterate through the recorded warnings."""
117154 return iter (self ._list )
118155
119- def __len__ (self ):
156+ def __len__ (self ) -> int :
120157 """The number of recorded warnings."""
121158 return len (self ._list )
122159
123- def pop (self , cls = Warning ):
160+ def pop (self , cls : "Type[Warning]" = Warning ) -> "warnings._Record" :
124161 """Pop the first recorded warning, raise exception if not exists."""
125162 for i , w in enumerate (self ._list ):
126163 if issubclass (w .category , cls ):
127164 return self ._list .pop (i )
128165 __tracebackhide__ = True
129166 raise AssertionError ("%r not found in warning list" % cls )
130167
131- def clear (self ):
168+ def clear (self ) -> None :
132169 """Clear the list of recorded warnings."""
133170 self ._list [:] = []
134171
135- def __enter__ (self ):
172+ # Type ignored because it doesn't exactly warnings.catch_warnings.__enter__
173+ # -- it returns a List but we only emulate one.
174+ def __enter__ (self ) -> "WarningsRecorder" : # type: ignore
136175 if self ._entered :
137176 __tracebackhide__ = True
138177 raise RuntimeError ("Cannot enter %r twice" % self )
139- self ._list = super ().__enter__ ()
178+ _list = super ().__enter__ ()
179+ # record=True means it's None.
180+ assert _list is not None
181+ self ._list = _list
140182 warnings .simplefilter ("always" )
141183 return self
142184
143- def __exit__ (self , * exc_info ):
185+ def __exit__ (
186+ self ,
187+ exc_type : Optional ["Type[BaseException]" ],
188+ exc_val : Optional [BaseException ],
189+ exc_tb : Optional [TracebackType ],
190+ ) -> bool :
144191 if not self ._entered :
145192 __tracebackhide__ = True
146193 raise RuntimeError ("Cannot exit %r without entering first" % self )
147194
148- super ().__exit__ (* exc_info )
195+ super ().__exit__ (exc_type , exc_val , exc_tb )
149196
150197 # Built-in catch_warnings does not reset entered state so we do it
151198 # manually here for this context manager to become reusable.
152199 self ._entered = False
153200
201+ return False
202+
154203
155204class WarningsChecker (WarningsRecorder ):
156- def __init__ (self , expected_warning = None , match_expr = None ):
205+ def __init__ (
206+ self ,
207+ expected_warning : Optional [
208+ Union ["Type[Warning]" , Tuple ["Type[Warning]" , ...]]
209+ ] = None ,
210+ match_expr : Optional [Union [str , Pattern ]] = None ,
211+ ) -> None :
157212 super ().__init__ ()
158213
159214 msg = "exceptions must be derived from Warning, not %s"
160- if isinstance (expected_warning , tuple ):
215+ if expected_warning is None :
216+ expected_warning_tup = None
217+ elif isinstance (expected_warning , tuple ):
161218 for exc in expected_warning :
162- if not inspect . isclass (exc ):
219+ if not issubclass (exc , Warning ):
163220 raise TypeError (msg % type (exc ))
164- elif inspect .isclass (expected_warning ):
165- expected_warning = (expected_warning ,)
166- elif expected_warning is not None :
221+ expected_warning_tup = expected_warning
222+ elif issubclass (expected_warning , Warning ):
223+ expected_warning_tup = (expected_warning ,)
224+ else :
167225 raise TypeError (msg % type (expected_warning ))
168226
169- self .expected_warning = expected_warning
227+ self .expected_warning = expected_warning_tup
170228 self .match_expr = match_expr
171229
172- def __exit__ (self , * exc_info ):
173- super ().__exit__ (* exc_info )
230+ def __exit__ (
231+ self ,
232+ exc_type : Optional ["Type[BaseException]" ],
233+ exc_val : Optional [BaseException ],
234+ exc_tb : Optional [TracebackType ],
235+ ) -> bool :
236+ super ().__exit__ (exc_type , exc_val , exc_tb )
174237
175238 __tracebackhide__ = True
176239
177240 # only check if we're not currently handling an exception
178- if all ( a is None for a in exc_info ) :
241+ if exc_type is None and exc_val is None and exc_tb is None :
179242 if self .expected_warning is not None :
180243 if not any (issubclass (r .category , self .expected_warning ) for r in self ):
181244 __tracebackhide__ = True
@@ -200,3 +263,4 @@ def __exit__(self, *exc_info):
200263 [each .message for each in self ],
201264 )
202265 )
266+ return False
0 commit comments