Skip to content

Commit 6a2ed47

Browse files
Consolidate tests of return options
Removes a regression test added in #3821 because it took 14 seconds.
1 parent a386095 commit 6a2ed47

File tree

1 file changed

+79
-107
lines changed

1 file changed

+79
-107
lines changed

tests/sampling/test_mcmc.py

Lines changed: 79 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -233,22 +233,20 @@ def test_sample_start_good_shape(self, start):
233233
def test_sample_callback(self):
234234
callback = mock.Mock()
235235
test_cores = [1, 2]
236-
test_chains = [1, 2]
237236
with self.model:
238237
for cores in test_cores:
239-
for chain in test_chains:
240-
with warnings.catch_warnings():
241-
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
242-
pm.sample(
243-
10,
244-
tune=0,
245-
chains=chain,
246-
step=self.step,
247-
cores=cores,
248-
random_seed=self.random_seed,
249-
callback=callback,
250-
)
251-
assert callback.called
238+
with warnings.catch_warnings():
239+
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
240+
pm.sample(
241+
10,
242+
tune=0,
243+
chains=2,
244+
step=self.step,
245+
cores=cores,
246+
random_seed=self.random_seed,
247+
callback=callback,
248+
)
249+
assert callback.called
252250

253251
def test_callback_can_cancel(self):
254252
trace_cancel_length = 5
@@ -339,107 +337,81 @@ def astep(self, q0):
339337
class TestSampleReturn:
340338
"""Tests related to kwargs that parametrize how `pm.sample` results are returned."""
341339

342-
def test_sample_tune_len(self):
343-
with pm.Model():
340+
def test_sample_return_lengths(self):
341+
with pm.Model() as model:
344342
pm.Normal("n")
345-
with warnings.catch_warnings():
346-
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
347-
warnings.filterwarnings("ignore", "Tuning samples will be included.*", UserWarning)
348-
trace = pm.sample(
349-
draws=100, tune=50, cores=1, step=pm.Metropolis(), return_inferencedata=False
350-
)
351-
assert len(trace) == 100
352-
trace = pm.sample(
353-
draws=100,
354-
tune=50,
355-
cores=1,
356-
step=pm.Metropolis(),
357-
return_inferencedata=False,
358-
discard_tuned_samples=False,
359-
)
360-
assert len(trace) == 150
361-
trace = pm.sample(
362-
draws=100,
363-
tune=50,
364-
cores=4,
365-
step=pm.Metropolis(),
366-
return_inferencedata=False,
367-
)
368-
assert len(trace) == 100
369343

370-
@pytest.mark.parametrize("discard", [True, False])
371-
def test_trace_report(self, discard):
372-
with pm.Model():
373-
pm.Uniform("uni")
374-
with warnings.catch_warnings():
375-
warnings.filterwarnings(
376-
"ignore", ".*Tuning samples will be included.*", UserWarning
377-
)
378-
trace = pm.sample(
344+
# Get a MultiTrace with warmup
345+
with pytest.warns(UserWarning, match="will be included"):
346+
mtrace = pm.sample(
379347
draws=100,
380348
tune=50,
381349
cores=1,
382-
discard_tuned_samples=discard,
350+
chains=3,
383351
step=pm.Metropolis(),
384-
compute_convergence_checks=False,
385352
return_inferencedata=False,
386-
)
387-
assert trace.report.n_tune == 50
388-
assert trace.report.n_draws == 100
389-
assert isinstance(trace.report.t_sampling, float)
390-
391-
def test_return_inferencedata(self):
392-
model, _, step, _ = simple_init()
393-
with model:
394-
kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=step)
395-
396-
# trace with tuning
397-
with pytest.warns(UserWarning, match="will be included"):
398-
result = pm.sample(
399-
**kwargs, return_inferencedata=False, discard_tuned_samples=False
400-
)
401-
assert isinstance(result, pm.backends.base.MultiTrace)
402-
assert len(result) == 150
403-
404-
# inferencedata with tuning
405-
result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=False)
406-
assert isinstance(result, InferenceData)
407-
assert result.posterior.sizes["draw"] == 100
408-
assert result.posterior.sizes["chain"] == 2
409-
assert len(result._groups_warmup) > 0
410-
411-
# inferencedata without tuning, with idata_kwargs
412-
prior = pm.sample_prior_predictive(return_inferencedata=False)
413-
result = pm.sample(
414-
**kwargs,
415-
return_inferencedata=True,
416-
discard_tuned_samples=True,
417-
idata_kwargs={"prior": prior},
418-
random_seed=-1,
419-
)
420-
assert "prior" in result
421-
assert isinstance(result, InferenceData)
422-
assert result.posterior.sizes["draw"] == 100
423-
assert result.posterior.sizes["chain"] == 2
424-
assert len(result._groups_warmup) == 0
425-
426-
@pytest.mark.parametrize("cores", [1, 2])
427-
def test_sampler_stat_tune(self, cores):
428-
with pm.Model():
429-
pm.Normal("n")
430-
with warnings.catch_warnings():
431-
warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
432-
warnings.filterwarnings("ignore", "Tuning samples will be included.*", UserWarning)
433-
tune_stat = pm.sample(
434-
tune=5,
435-
draws=7,
436-
cores=cores,
437353
discard_tuned_samples=False,
438-
return_inferencedata=False,
439-
step=pm.Metropolis(),
440-
).get_sampler_stats("tune", chains=1)
441-
assert list(tune_stat).count(True) == 5
442-
assert list(tune_stat).count(False) == 7
354+
)
355+
assert isinstance(mtrace, pm.backends.base.MultiTrace)
356+
assert len(mtrace) == 150
357+
358+
# Now instead of running more MCMCs, we'll test the other return
359+
# options using the basetraces inside the MultiTrace.
360+
traces = list(mtrace._straces.values())
361+
assert len(traces) == 3
362+
363+
# MultiTrace without warmup
364+
mtrace_pst = pm.sampling.mcmc._sample_return(
365+
traces=traces,
366+
tune=50,
367+
t_sampling=123.4,
368+
discard_tuned_samples=True,
369+
return_inferencedata=False,
370+
compute_convergence_checks=False,
371+
keep_warning_stat=True,
372+
idata_kwargs={},
373+
model=model,
374+
)
375+
assert isinstance(mtrace_pst, pm.backends.base.MultiTrace)
376+
assert len(mtrace_pst) == 100
377+
assert mtrace_pst.report.t_sampling == 123.4
378+
assert mtrace_pst.report.n_tune == 50
379+
assert mtrace_pst.report.n_draws == 100
380+
381+
# InferenceData with warmup
382+
idata_w = pm.sampling.mcmc._sample_return(
383+
traces=traces,
384+
tune=50,
385+
t_sampling=123.4,
386+
discard_tuned_samples=False,
387+
compute_convergence_checks=False,
388+
return_inferencedata=True,
389+
keep_warning_stat=True,
390+
idata_kwargs={},
391+
model=model,
392+
)
393+
assert isinstance(idata_w, InferenceData)
394+
assert hasattr(idata_w, "warmup_posterior")
395+
assert idata_w.warmup_posterior.sizes["draw"] == 50
396+
assert idata_w.posterior.sizes["draw"] == 100
397+
assert idata_w.posterior.sizes["chain"] == 3
398+
399+
# InferenceData without warmup
400+
idata = pm.sampling.mcmc._sample_return(
401+
traces=traces,
402+
tune=50,
403+
t_sampling=123.4,
404+
discard_tuned_samples=True,
405+
compute_convergence_checks=False,
406+
return_inferencedata=True,
407+
keep_warning_stat=False,
408+
idata_kwargs={},
409+
model=model,
410+
)
411+
assert isinstance(idata, InferenceData)
412+
assert not hasattr(idata, "warmup_posterior")
413+
assert idata.posterior.sizes["draw"] == 100
414+
assert idata.posterior.sizes["chain"] == 3
443415

444416
@pytest.mark.parametrize("cores", [1, 2])
445417
def test_logs_sampler_warnings(self, caplog, cores):

0 commit comments

Comments
 (0)