1616import aesara .tensor as at
1717import numpy as np
1818import pytest
19+ import scipy .stats as st
1920
2021from arviz .data .inference_data import InferenceData
2122
2223import pymc3 as pm
2324
25+ from pymc3 .aesaraf import floatX
2426from pymc3 .backends .base import MultiTrace
27+ from pymc3 .smc .smc import SMC
2528from pymc3 .tests .helpers import SeededTest
2629
2730
@@ -64,10 +67,6 @@ def two_gaussians(x):
6467 x = pm .Normal ("x" , 0 , 1 )
6568 y = pm .Normal ("y" , x , 1 , observed = 0 )
6669
67- with pm .Model () as self .slow_model :
68- x = pm .Normal ("x" , 0 , 1 )
69- y = pm .Normal ("y" , x , 1 , observed = 100 )
70-
7170 def test_sample (self ):
7271 with self .SMC_test :
7372 mtrace = pm .sample_smc (draws = self .samples , return_inferencedata = False )
@@ -76,12 +75,43 @@ def test_sample(self):
7675 mu1d = np .abs (x ).mean (axis = 0 )
7776 np .testing .assert_allclose (self .muref , mu1d , rtol = 0.0 , atol = 0.03 )
7877
79- def test_discrete_continuous (self ):
80- with pm .Model () as model :
81- a = pm .Poisson ("a" , 5 )
82- b = pm .HalfNormal ("b" , 10 )
83- y = pm .Normal ("y" , a , b , observed = [1 , 2 , 3 , 4 ])
84- trace = pm .sample_smc (draws = 10 )
78+ def test_discrete_rounding_proposal (self ):
79+ """
80+ Test that discrete variable values are automatically rounded
81+ in SMC logp functions
82+ """
83+
84+ with pm .Model () as m :
85+ z = pm .Bernoulli ("z" , p = 0.7 )
86+ like = pm .Potential ("like" , z * 1.0 )
87+
88+ smc = SMC (model = m )
89+ smc .initialize_population ()
90+ smc .setup_kernel ()
91+ smc .initialize_logp ()
92+
93+ assert smc .prior_logp_func (floatX (np .array ([- 0.51 ]))) == - np .inf
94+ assert np .isclose (smc .prior_logp_func (floatX (np .array ([- 0.49 ]))), np .log (0.3 ))
95+ assert np .isclose (smc .prior_logp_func (floatX (np .array ([0.49 ]))), np .log (0.3 ))
96+ assert np .isclose (smc .prior_logp_func (floatX (np .array ([0.51 ]))), np .log (0.7 ))
97+ assert smc .prior_logp_func (floatX (np .array ([1.51 ]))) == - np .inf
98+
99+ def test_unobserved_discrete (self ):
100+ n = 10
101+ rng = self .get_random_state ()
102+
103+ z_true = np .zeros (n , dtype = int )
104+ z_true [int (n / 2 ) :] = 1
105+ y = st .norm (np .array ([- 1 , 1 ])[z_true ], 0.25 ).rvs (random_state = rng )
106+
107+ with pm .Model () as m :
108+ z = pm .Bernoulli ("z" , p = 0.5 , size = n )
109+ mu = pm .math .switch (z , 1.0 , - 1.0 )
110+ like = pm .Normal ("like" , mu = mu , sigma = 0.25 , observed = y )
111+
112+ trace = pm .sample_smc (chains = 1 , return_inferencedata = False )
113+
114+ assert np .all (np .median (trace ["z" ], axis = 0 ) == z_true )
85115
86116 def test_ml (self ):
87117 data = np .repeat ([1 , 0 ], [50 , 50 ])
@@ -109,14 +139,6 @@ def test_start(self):
109139 }
110140 trace = pm .sample_smc (500 , chains = 1 , start = start )
111141
112- def test_slowdown_warning (self ):
113- with aesara .config .change_flags (floatX = "float32" ):
114- with pytest .warns (UserWarning , match = "SMC sampling may run slower due to" ):
115- with pm .Model () as model :
116- a = pm .Poisson ("a" , 5 )
117- y = pm .Normal ("y" , a , 5 , observed = [1 , 2 , 3 , 4 ])
118- trace = pm .sample_smc (draws = 100 , chains = 2 , cores = 1 )
119-
120142 @pytest .mark .parametrize ("chains" , (1 , 2 ))
121143 def test_return_datatype (self , chains ):
122144 draws = 10
0 commit comments