Skip to content

Commit 9806f3f

Browse files
tests for added sample statistics
1 parent b5db350 commit 9806f3f

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

pymc/tests/sampling/test_jax.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,3 +365,50 @@ def test_numpyro_nuts_kwargs_are_used(mocked: mock.MagicMock):
365365
assert nuts_sampler._adapt_step_size == adapt_step_size
366366
assert nuts_sampler._adapt_mass_matrix
367367
assert nuts_sampler._target_accept_prob == target_accept
368+
369+
370+
@pytest.mark.parametrize(
371+
"sampler_name",
372+
[
373+
"sample_blackjax_nuts",
374+
"sample_numpyro_nuts",
375+
],
376+
)
377+
def test_idata_contains_stats(sampler_name: str):
378+
"""Tests whether sampler statistics were written to sample_stats
379+
group of InferenceData"""
380+
if sampler_name == "sample_blackjax_nuts":
381+
sampler = sample_blackjax_nuts
382+
elif sampler_name == "sample_numpyro_nuts":
383+
sampler = sample_numpyro_nuts
384+
385+
with pm.Model():
386+
pm.Normal("a")
387+
idata = sampler(draws=10, tune=10)
388+
389+
stats = idata.get("sample_stats")
390+
assert stats is not None
391+
n_chains = stats.dims["chain"]
392+
n_draws = stats.dims["draw"]
393+
394+
# Stats vars expected for both samplers
395+
expected_stat_vars = {
396+
"acceptance_rate": (n_chains, n_draws),
397+
"diverging": (n_chains, n_draws),
398+
"energy": (n_chains, n_draws),
399+
"n_steps": (n_chains, n_draws),
400+
"tree_depth": (n_chains, n_draws),
401+
"lp": (n_chains, n_draws),
402+
}
403+
# Stats only expected for blackjax nuts
404+
if sampler_name == "sample_blackjax_nuts":
405+
blackjax_special_vars = {}
406+
stat_vars = expected_stat_vars | blackjax_special_vars
407+
# Stats only expected for numpyro nuts
408+
elif sampler_name == "sample_numpyro_nuts":
409+
numpyro_special_vars = {"step_size": (n_chains, n_draws)}
410+
stat_vars = expected_stat_vars | numpyro_special_vars
411+
# test existence and dimensionality
412+
for stat_var, stat_var_dims in stat_vars.items():
413+
assert stat_var in stats
414+
assert stats.get(stat_var).values.shape == stat_var_dims

0 commit comments

Comments
 (0)