Skip to content

Commit 1dd85c1

Browse files
redesigned test for older blackjax version
1 parent 1011837 commit 1dd85c1

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

pymc/tests/sampling/test_jax.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,6 @@ def test_idata_contains_stats(sampler_name: str):
396396
"acceptance_rate": (n_chains, n_draws),
397397
"diverging": (n_chains, n_draws),
398398
"energy": (n_chains, n_draws),
399-
"n_steps": (n_chains, n_draws),
400399
"tree_depth": (n_chains, n_draws),
401400
"lp": (n_chains, n_draws),
402401
}
@@ -406,9 +405,12 @@ def test_idata_contains_stats(sampler_name: str):
406405
stat_vars = expected_stat_vars | blackjax_special_vars
407406
# Stats only expected for numpyro nuts
408407
elif sampler_name == "sample_numpyro_nuts":
409-
numpyro_special_vars = {"step_size": (n_chains, n_draws)}
408+
numpyro_special_vars = {
409+
"step_size": (n_chains, n_draws),
410+
"n_steps": (n_chains, n_draws),
411+
}
410412
stat_vars = expected_stat_vars | numpyro_special_vars
411413
# test existence and dimensionality
412414
for stat_var, stat_var_dims in stat_vars.items():
413-
assert stat_var in stats
415+
assert stat_var in stats.variables
414416
assert stats.get(stat_var).values.shape == stat_var_dims

0 commit comments

Comments
 (0)