@@ -151,24 +151,6 @@ def test_sample_does_not_rely_on_external_global_seeding(self):
151151 assert np .all (idata12 ["x" ] != idata22 ["x" ])
152152 assert np .all (idata13 ["x" ] != idata23 ["x" ])
153153
154- def test_sample (self ):
155- test_cores = [1 ]
156- with self .model :
157- for cores in test_cores :
158- for steps in [1 , 10 , 300 ]:
159- with warnings .catch_warnings ():
160- warnings .filterwarnings ("ignore" , ".*number of samples.*" , UserWarning )
161- warnings .filterwarnings (
162- "ignore" , "More chains .* than draws .*" , UserWarning
163- )
164- pm .sample (
165- steps ,
166- tune = 0 ,
167- step = self .step ,
168- cores = cores ,
169- random_seed = self .random_seed ,
170- )
171-
172154 def test_sample_init (self ):
173155 with self .model :
174156 for init in (
@@ -199,11 +181,11 @@ def test_sample_init(self):
199181 def test_sample_args (self ):
200182 with self .model :
201183 with pytest .raises (ValueError ) as excinfo :
202- pm .sample (50 , tune = 0 , foo = 1 )
184+ pm .sample (50 , tune = 0 , chains = 1 , step = pm . Metropolis (), foo = 1 )
203185 assert "'foo'" in str (excinfo .value )
204186
205187 with pytest .raises (ValueError ) as excinfo :
206- pm .sample (50 , tune = 0 , foo = {})
188+ pm .sample (50 , tune = 0 , chains = 1 , step = pm . Metropolis (), foo = {})
207189 assert "foo" in str (excinfo .value )
208190
209191 def test_parallel_start (self ):
@@ -232,6 +214,7 @@ def test_sample_tune_len(self):
232214 draws = 100 ,
233215 tune = 50 ,
234216 cores = 1 ,
217+ step = pm .Metropolis (),
235218 return_inferencedata = False ,
236219 discard_tuned_samples = False ,
237220 )
@@ -387,18 +370,7 @@ def test_sequential_backend(self):
387370 backend = NDArray ()
388371 with warnings .catch_warnings ():
389372 warnings .filterwarnings ("ignore" , ".*number of samples.*" , UserWarning )
390- pm .sample (10 , cores = 1 , chains = 2 , trace = backend )
391-
392- def test_exceptions (self ):
393- # Test iteration over MultiTrace NotImplementedError
394- with pm .Model () as model :
395- mu = pm .Normal ("mu" , 0.0 , 1.0 )
396- a = pm .Normal ("a" , mu = mu , sigma = 1 , observed = np .array ([0.5 , 0.2 ]))
397- with warnings .catch_warnings ():
398- warnings .filterwarnings ("ignore" , ".*number of samples.*" , UserWarning )
399- trace = pm .sample (tune = 0 , draws = 10 , chains = 2 , return_inferencedata = False )
400- with pytest .raises (NotImplementedError ):
401- xvars = [t ["mu" ] for t in trace ]
373+ pm .sample (10 , tune = 5 , cores = 1 , chains = 2 , step = pm .Metropolis (), trace = backend )
402374
403375 @pytest .mark .parametrize ("symbolic_rv" , (False , True ))
404376 def test_deterministic_of_unobserved (self , symbolic_rv ):
0 commit comments