Skip to content

Commit 62cb709

Browse files
Speed up test_mcmc
1 parent a3df225 commit 62cb709

File tree

2 files changed

+10
-32
lines changed

2 files changed

+10
-32
lines changed

tests/backends/test_ndarray.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ def test_multitrace_nonunique(self):
123123
with pytest.raises(ValueError):
124124
base.MultiTrace([self.strace0, self.strace1])
125125

126+
def test_multitrace_iter_notimplemented(self):
127+
mtrace = base.MultiTrace([self.strace0])
128+
with pytest.raises(NotImplementedError):
129+
for _ in mtrace:
130+
pass
131+
126132

127133
class TestSqueezeCat:
128134
def setup_method(self):

tests/sampling/test_mcmc.py

Lines changed: 4 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -151,24 +151,6 @@ def test_sample_does_not_rely_on_external_global_seeding(self):
151151
assert np.all(idata12["x"] != idata22["x"])
152152
assert np.all(idata13["x"] != idata23["x"])
153153

154-
def test_sample(self):
155-
test_cores = [1]
156-
with self.model:
157-
for cores in test_cores:
158-
for steps in [1, 10, 300]:
159-
with warnings.catch_warnings():
160-
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
161-
warnings.filterwarnings(
162-
"ignore", "More chains .* than draws .*", UserWarning
163-
)
164-
pm.sample(
165-
steps,
166-
tune=0,
167-
step=self.step,
168-
cores=cores,
169-
random_seed=self.random_seed,
170-
)
171-
172154
def test_sample_init(self):
173155
with self.model:
174156
for init in (
@@ -199,11 +181,11 @@ def test_sample_init(self):
199181
def test_sample_args(self):
200182
with self.model:
201183
with pytest.raises(ValueError) as excinfo:
202-
pm.sample(50, tune=0, foo=1)
184+
pm.sample(50, tune=0, chains=1, step=pm.Metropolis(), foo=1)
203185
assert "'foo'" in str(excinfo.value)
204186

205187
with pytest.raises(ValueError) as excinfo:
206-
pm.sample(50, tune=0, foo={})
188+
pm.sample(50, tune=0, chains=1, step=pm.Metropolis(), foo={})
207189
assert "foo" in str(excinfo.value)
208190

209191
def test_parallel_start(self):
@@ -232,6 +214,7 @@ def test_sample_tune_len(self):
232214
draws=100,
233215
tune=50,
234216
cores=1,
217+
step=pm.Metropolis(),
235218
return_inferencedata=False,
236219
discard_tuned_samples=False,
237220
)
@@ -387,18 +370,7 @@ def test_sequential_backend(self):
387370
backend = NDArray()
388371
with warnings.catch_warnings():
389372
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
390-
pm.sample(10, cores=1, chains=2, trace=backend)
391-
392-
def test_exceptions(self):
393-
# Test iteration over MultiTrace NotImplementedError
394-
with pm.Model() as model:
395-
mu = pm.Normal("mu", 0.0, 1.0)
396-
a = pm.Normal("a", mu=mu, sigma=1, observed=np.array([0.5, 0.2]))
397-
with warnings.catch_warnings():
398-
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
399-
trace = pm.sample(tune=0, draws=10, chains=2, return_inferencedata=False)
400-
with pytest.raises(NotImplementedError):
401-
xvars = [t["mu"] for t in trace]
373+
pm.sample(10, tune=5, cores=1, chains=2, step=pm.Metropolis(), trace=backend)
402374

403375
@pytest.mark.parametrize("symbolic_rv", (False, True))
404376
def test_deterministic_of_unobserved(self, symbolic_rv):

0 commit comments

Comments
 (0)