
Commit a386095

Group tests related to pm.sample return parameters
1 parent 62cb709 commit a386095
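
For orientation, a minimal sketch (not part of the commit) of the `pm.sample` kwargs whose return behavior the regrouped tests pin down; the toy model and numbers mirror the assertions in the diff below:

import pymc as pm

with pm.Model():
    pm.Normal("n")
    # InferenceData return: tuning draws are discarded by default.
    idata = pm.sample(draws=100, tune=50, chains=2, cores=1)
    # Legacy MultiTrace return: keeping tuning draws lengthens the trace.
    trace = pm.sample(
        draws=100,
        tune=50,
        chains=2,
        cores=1,
        return_inferencedata=False,
        discard_tuned_samples=False,
    )

assert idata.posterior.sizes["draw"] == 100
assert len(trace) == 150  # 100 draws + 50 tuning steps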

1 file changed: +182 −171 lines changed

tests/sampling/test_mcmc.py

Lines changed: 182 additions & 171 deletions
@@ -203,25 +203,6 @@ def test_parallel_start(self):
         assert idata.warmup_posterior["x"].sel(chain=0, draw=0).values[0] > 0
         assert idata.warmup_posterior["x"].sel(chain=1, draw=0).values[0] < 0
 
-    def test_sample_tune_len(self):
-        with self.model:
-            with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
-                warnings.filterwarnings("ignore", "Tuning samples will be included.*", UserWarning)
-                trace = pm.sample(draws=100, tune=50, cores=1, return_inferencedata=False)
-                assert len(trace) == 100
-                trace = pm.sample(
-                    draws=100,
-                    tune=50,
-                    cores=1,
-                    step=pm.Metropolis(),
-                    return_inferencedata=False,
-                    discard_tuned_samples=False,
-                )
-                assert len(trace) == 150
-                trace = pm.sample(draws=100, tune=50, cores=4, return_inferencedata=False)
-                assert len(trace) == 100
-
     def test_reset_tuning(self):
         with self.model:
             tune = 50
@@ -233,80 +214,6 @@ def test_reset_tuning(self):
         assert step.potential._n_samples == tune
         assert step.step_adapt._count == tune + 1
 
-    @pytest.mark.parametrize("step_cls", [pm.NUTS, pm.Metropolis, pm.Slice])
-    @pytest.mark.parametrize("discard", [True, False])
-    def test_trace_report(self, step_cls, discard):
-        with self.model:
-            # add more variables, because stats are 2D with CompoundStep!
-            pm.Uniform("uni")
-            with warnings.catch_warnings():
-                warnings.filterwarnings(
-                    "ignore", ".*Tuning samples will be included.*", UserWarning
-                )
-                trace = pm.sample(
-                    draws=100,
-                    tune=50,
-                    cores=1,
-                    discard_tuned_samples=discard,
-                    step=step_cls(),
-                    compute_convergence_checks=False,
-                    return_inferencedata=False,
-                )
-            assert trace.report.n_tune == 50
-            assert trace.report.n_draws == 100
-            assert isinstance(trace.report.t_sampling, float)
-
-    def test_return_inferencedata(self):
-        with self.model:
-            kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=pm.Metropolis())
-
-            # trace with tuning
-            with pytest.warns(UserWarning, match="will be included"):
-                result = pm.sample(
-                    **kwargs, return_inferencedata=False, discard_tuned_samples=False
-                )
-            assert isinstance(result, pm.backends.base.MultiTrace)
-            assert len(result) == 150
-
-            # inferencedata with tuning
-            result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=False)
-            assert isinstance(result, InferenceData)
-            assert result.posterior.sizes["draw"] == 100
-            assert result.posterior.sizes["chain"] == 2
-            assert len(result._groups_warmup) > 0
-
-            # inferencedata without tuning, with idata_kwargs
-            prior = pm.sample_prior_predictive(return_inferencedata=False)
-            result = pm.sample(
-                **kwargs,
-                return_inferencedata=True,
-                discard_tuned_samples=True,
-                idata_kwargs={"prior": prior},
-                random_seed=-1,
-            )
-            assert "prior" in result
-            assert isinstance(result, InferenceData)
-            assert result.posterior.sizes["draw"] == 100
-            assert result.posterior.sizes["chain"] == 2
-            assert len(result._groups_warmup) == 0
-
-    @pytest.mark.parametrize("cores", [1, 2])
-    def test_sampler_stat_tune(self, cores):
-        with self.model:
-            with warnings.catch_warnings():
-                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
-                warnings.filterwarnings("ignore", "Tuning samples will be included.*", UserWarning)
-                tune_stat = pm.sample(
-                    tune=5,
-                    draws=7,
-                    cores=cores,
-                    discard_tuned_samples=False,
-                    return_inferencedata=False,
-                    step=pm.Metropolis(),
-                ).get_sampler_stats("tune", chains=1)
-            assert list(tune_stat).count(True) == 5
-            assert list(tune_stat).count(False) == 7
-
     @pytest.mark.parametrize(
         "start, error",
         [
@@ -411,6 +318,188 @@ def test_transform_with_rv_dependency(self, symbolic_rv):
         assert np.allclose(scipy.special.expit(trace["y_interval__"]), trace["y"])
 
 
+class ApocalypticMetropolis(pm.Metropolis):
+    """A stepper that warns in every iteration."""
+
+    stats_dtypes_shapes = {
+        **pm.Metropolis.stats_dtypes_shapes,
+        "warning": (SamplerWarning, None),
+    }
+
+    def astep(self, q0):
+        draw, stats = super().astep(q0)
+        stats[0]["warning"] = SamplerWarning(
+            WarningType.BAD_ENERGY,
+            "Asteroid incoming!",
+            "warn",
+        )
+        return draw, stats
+
+
+class TestSampleReturn:
+    """Tests related to kwargs that parametrize how `pm.sample` results are returned."""
+
+    def test_sample_tune_len(self):
+        with pm.Model():
+            pm.Normal("n")
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
+                warnings.filterwarnings("ignore", "Tuning samples will be included.*", UserWarning)
+                trace = pm.sample(
+                    draws=100, tune=50, cores=1, step=pm.Metropolis(), return_inferencedata=False
+                )
+                assert len(trace) == 100
+                trace = pm.sample(
+                    draws=100,
+                    tune=50,
+                    cores=1,
+                    step=pm.Metropolis(),
+                    return_inferencedata=False,
+                    discard_tuned_samples=False,
+                )
+                assert len(trace) == 150
+                trace = pm.sample(
+                    draws=100,
+                    tune=50,
+                    cores=4,
+                    step=pm.Metropolis(),
+                    return_inferencedata=False,
+                )
+                assert len(trace) == 100
+
+    @pytest.mark.parametrize("discard", [True, False])
+    def test_trace_report(self, discard):
+        with pm.Model():
+            pm.Uniform("uni")
+            with warnings.catch_warnings():
+                warnings.filterwarnings(
+                    "ignore", ".*Tuning samples will be included.*", UserWarning
+                )
+                trace = pm.sample(
+                    draws=100,
+                    tune=50,
+                    cores=1,
+                    discard_tuned_samples=discard,
+                    step=pm.Metropolis(),
+                    compute_convergence_checks=False,
+                    return_inferencedata=False,
+                )
+            assert trace.report.n_tune == 50
+            assert trace.report.n_draws == 100
+            assert isinstance(trace.report.t_sampling, float)
+
+    def test_return_inferencedata(self):
+        model, _, step, _ = simple_init()
+        with model:
+            kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=step)
+
+            # trace with tuning
+            with pytest.warns(UserWarning, match="will be included"):
+                result = pm.sample(
+                    **kwargs, return_inferencedata=False, discard_tuned_samples=False
+                )
+            assert isinstance(result, pm.backends.base.MultiTrace)
+            assert len(result) == 150
+
+            # inferencedata with tuning
+            result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=False)
+            assert isinstance(result, InferenceData)
+            assert result.posterior.sizes["draw"] == 100
+            assert result.posterior.sizes["chain"] == 2
+            assert len(result._groups_warmup) > 0
+
+            # inferencedata without tuning, with idata_kwargs
+            prior = pm.sample_prior_predictive(return_inferencedata=False)
+            result = pm.sample(
+                **kwargs,
+                return_inferencedata=True,
+                discard_tuned_samples=True,
+                idata_kwargs={"prior": prior},
+                random_seed=-1,
+            )
+            assert "prior" in result
+            assert isinstance(result, InferenceData)
+            assert result.posterior.sizes["draw"] == 100
+            assert result.posterior.sizes["chain"] == 2
+            assert len(result._groups_warmup) == 0
+
+    @pytest.mark.parametrize("cores", [1, 2])
+    def test_sampler_stat_tune(self, cores):
+        with pm.Model():
+            pm.Normal("n")
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", ".*number of samples.*", UserWarning)
+                warnings.filterwarnings("ignore", "Tuning samples will be included.*", UserWarning)
+                tune_stat = pm.sample(
+                    tune=5,
+                    draws=7,
+                    cores=cores,
+                    discard_tuned_samples=False,
+                    return_inferencedata=False,
+                    step=pm.Metropolis(),
+                ).get_sampler_stats("tune", chains=1)
+            assert list(tune_stat).count(True) == 5
+            assert list(tune_stat).count(False) == 7
+
+    @pytest.mark.parametrize("cores", [1, 2])
+    def test_logs_sampler_warnings(self, caplog, cores):
+        """Asserts that "warning" sampler stats are logged during sampling."""
+        with pm.Model():
+            pm.Normal("n")
+            with caplog.at_level(logging.WARNING):
+                idata = pm.sample(
+                    tune=2,
+                    draws=3,
+                    cores=cores,
+                    chains=cores,
+                    step=ApocalypticMetropolis(),
+                    compute_convergence_checks=False,
+                    discard_tuned_samples=False,
+                    keep_warning_stat=True,
+                )
+
+        # Sampler warnings should be logged
+        nwarns = sum("Asteroid" in rec.message for rec in caplog.records)
+        assert nwarns == (2 + 3) * cores
+
+    @pytest.mark.parametrize("keep_warning_stat", [None, True])
+    def test_keep_warning_stat_setting(self, keep_warning_stat):
+        """The ``keep_warning_stat`` kwarg (aka "Adrian's kwarg") enables users
+        to keep the ``SamplerWarning`` objects from the ``sample_stats.warning`` group.
+        This breaks ``idata.to_netcdf()`` which is why it defaults to ``False``.
+        """
+        sample_kwargs = dict(
+            tune=2,
+            draws=3,
+            chains=1,
+            compute_convergence_checks=False,
+            discard_tuned_samples=False,
+            keep_warning_stat=keep_warning_stat,
+        )
+        if keep_warning_stat:
+            sample_kwargs["keep_warning_stat"] = True
+        with pm.Model():
+            pm.Normal("n")
+            idata = pm.sample(step=ApocalypticMetropolis(), **sample_kwargs)
+
+        if keep_warning_stat:
+            assert "warning" in idata.warmup_sample_stats
+            assert "warning" in idata.sample_stats
+            # And end up in the InferenceData
+            assert "warning" in idata.sample_stats
+            # NOTE: The stats are squeezed by default but this does not always work.
+            # This test flattens so we don't have to be exact in accessing (non-)squeezed items.
+            # Also see https://github.com/pymc-devs/pymc/issues/6207.
+            warn_objs = list(idata.sample_stats.warning.sel(chain=0).values.flatten())
+            assert any(isinstance(w, SamplerWarning) for w in warn_objs)
+            assert any("Asteroid" in w.message for w in warn_objs)
+        else:
+            assert "warning" not in idata.warmup_sample_stats
+            assert "warning" not in idata.sample_stats
+            assert "warning_dim_0" not in idata.warmup_sample_stats
+            assert "warning_dim_0" not in idata.sample_stats
+
+
 def test_sample_find_MAP_does_not_modify_start():
     # see https://github.com/pymc-devs/pymc/pull/4458
     with pm.Model():
@@ -603,84 +692,6 @@ def test_step_args():
     npt.assert_allclose(idata1.sample_stats.scaling, 0)
 
 
-class ApocalypticMetropolis(pm.Metropolis):
-    """A stepper that warns in every iteration."""
-
-    stats_dtypes_shapes = {
-        **pm.Metropolis.stats_dtypes_shapes,
-        "warning": (SamplerWarning, None),
-    }
-
-    def astep(self, q0):
-        draw, stats = super().astep(q0)
-        stats[0]["warning"] = SamplerWarning(
-            WarningType.BAD_ENERGY,
-            "Asteroid incoming!",
-            "warn",
-        )
-        return draw, stats
-
-
-@pytest.mark.parametrize("cores", [1, 2])
-def test_logs_sampler_warnings(caplog, cores):
-    """Asserts that "warning" sampler stats are logged during sampling."""
-    with pm.Model():
-        pm.Normal("n")
-        with caplog.at_level(logging.WARNING):
-            idata = pm.sample(
-                tune=2,
-                draws=3,
-                cores=cores,
-                chains=cores,
-                step=ApocalypticMetropolis(),
-                compute_convergence_checks=False,
-                discard_tuned_samples=False,
-                keep_warning_stat=True,
-            )
-
-    # Sampler warnings should be logged
-    nwarns = sum("Asteroid" in rec.message for rec in caplog.records)
-    assert nwarns == (2 + 3) * cores
-
-
-@pytest.mark.parametrize("keep_warning_stat", [None, True])
-def test_keep_warning_stat_setting(keep_warning_stat):
-    """The ``keep_warning_stat`` kwarg (aka "Adrian's kwarg") enables users
-    to keep the ``SamplerWarning`` objects from the ``sample_stats.warning`` group.
-    This breaks ``idata.to_netcdf()`` which is why it defaults to ``False``.
-    """
-    sample_kwargs = dict(
-        tune=2,
-        draws=3,
-        chains=1,
-        compute_convergence_checks=False,
-        discard_tuned_samples=False,
-        keep_warning_stat=keep_warning_stat,
-    )
-    if keep_warning_stat:
-        sample_kwargs["keep_warning_stat"] = True
-    with pm.Model():
-        pm.Normal("n")
-        idata = pm.sample(step=ApocalypticMetropolis(), **sample_kwargs)
-
-    if keep_warning_stat:
-        assert "warning" in idata.warmup_sample_stats
-        assert "warning" in idata.sample_stats
-        # And end up in the InferenceData
-        assert "warning" in idata.sample_stats
-        # NOTE: The stats are squeezed by default but this does not always work.
-        # This test flattens so we don't have to be exact in accessing (non-)squeezed items.
-        # Also see https://github.com/pymc-devs/pymc/issues/6207.
-        warn_objs = list(idata.sample_stats.warning.sel(chain=0).values.flatten())
-        assert any(isinstance(w, SamplerWarning) for w in warn_objs)
-        assert any("Asteroid" in w.message for w in warn_objs)
-    else:
-        assert "warning" not in idata.warmup_sample_stats
-        assert "warning" not in idata.sample_stats
-        assert "warning_dim_0" not in idata.warmup_sample_stats
-        assert "warning_dim_0" not in idata.sample_stats
-
-
 def test_init_nuts(caplog):
     with pm.Model() as model:
         a = pm.Normal("a")
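
Usage note, as a hedged sketch rather than part of the commit: with `keep_warning_stat=True` the `SamplerWarning` objects are kept as raw Python objects in `sample_stats.warning` (which is why `idata.to_netcdf()` would fail), and the tests flatten the stat before inspecting it to sidestep (non-)squeezed dimensions:

import pymc as pm

with pm.Model():
    pm.Normal("n")
    idata = pm.sample(
        tune=2,
        draws=3,
        chains=1,
        step=ApocalypticMetropolis(),  # the warning-emitting stepper defined in this diff
        keep_warning_stat=True,
        compute_convergence_checks=False,
        discard_tuned_samples=False,
    )

# Flatten before inspecting; see pymc issue #6207 referenced in the test.
warn_objs = list(idata.sample_stats.warning.sel(chain=0).values.flatten())
print([w.message for w in warn_objs if w is not None])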
