12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import warnings
16
+
15
17
from contextlib import ExitStack as does_not_raise
16
18
17
19
import aesara
@@ -655,10 +657,10 @@ def mixmixlogp(value, point):
655
657
assert_allclose (priorlogp + mixmixlogpg .sum (), model .logp (test_point ), rtol = rtol )
656
658
657
659
def test_iterable_single_component_warning (self ):
658
- with pytest .warns (None ) as record :
660
+ with warnings .catch_warnings ():
661
+ warnings .simplefilter ("error" )
659
662
Mixture .dist (w = [0.5 , 0.5 ], comp_dists = Normal .dist (size = 2 ))
660
663
Mixture .dist (w = [0.5 , 0.5 ], comp_dists = [Normal .dist (size = 2 ), Normal .dist (size = 2 )])
661
- assert not record
662
664
663
665
with pytest .warns (UserWarning , match = "Single component will be treated as a mixture" ):
664
666
Mixture .dist (w = [0.5 , 0.5 ], comp_dists = [Normal .dist (size = 2 )])
@@ -1303,9 +1305,9 @@ def test_logp(self):
1303
1305
def test_warning (self ):
1304
1306
with Model () as m :
1305
1307
comp_dists = [HalfNormal .dist (), Exponential .dist (1 )]
1306
- with pytest .warns (None ) as rec :
1308
+ with warnings .catch_warnings ():
1309
+ warnings .simplefilter ("error" )
1307
1310
Mixture ("mix1" , w = [0.5 , 0.5 ], comp_dists = comp_dists )
1308
- assert not rec
1309
1311
1310
1312
comp_dists = [Uniform .dist (0 , 1 ), Uniform .dist (0 , 2 )]
1311
1313
with pytest .warns (MixtureTransformWarning ):
@@ -1315,16 +1317,16 @@ def test_warning(self):
1315
1317
with pytest .warns (MixtureTransformWarning ):
1316
1318
Mixture ("mix3" , w = [0.5 , 0.5 ], comp_dists = comp_dists )
1317
1319
1318
- with pytest .warns (None ) as rec :
1320
+ with warnings .catch_warnings ():
1321
+ warnings .simplefilter ("error" )
1319
1322
Mixture ("mix4" , w = [0.5 , 0.5 ], comp_dists = comp_dists , transform = None )
1320
- assert not rec
1321
1323
1322
- with pytest .warns (None ) as rec :
1324
+ with warnings .catch_warnings ():
1325
+ warnings .simplefilter ("error" )
1323
1326
Mixture ("mix5" , w = [0.5 , 0.5 ], comp_dists = comp_dists , observed = 1 )
1324
- assert not rec
1325
1327
1326
1328
# Case where the appropriate default transform is None
1327
1329
comp_dists = [Normal .dist (), Normal .dist ()]
1328
- with pytest .warns (None ) as rec :
1330
+ with warnings .catch_warnings ():
1331
+ warnings .simplefilter ("error" )
1329
1332
Mixture ("mix6" , w = [0.5 , 0.5 ], comp_dists = comp_dists )
1330
- assert not rec
0 commit comments