@@ -396,7 +396,6 @@ def test_idata_contains_stats(sampler_name: str):
396396 "acceptance_rate" : (n_chains , n_draws ),
397397 "diverging" : (n_chains , n_draws ),
398398 "energy" : (n_chains , n_draws ),
399- "n_steps" : (n_chains , n_draws ),
400399 "tree_depth" : (n_chains , n_draws ),
401400 "lp" : (n_chains , n_draws ),
402401 }
@@ -406,9 +405,12 @@ def test_idata_contains_stats(sampler_name: str):
406405 stat_vars = expected_stat_vars | blackjax_special_vars
407406 # Stats only expected for numpyro nuts
408407 elif sampler_name == "sample_numpyro_nuts" :
409- numpyro_special_vars = {"step_size" : (n_chains , n_draws )}
408+ numpyro_special_vars = {
409+ "step_size" : (n_chains , n_draws ),
410+ "n_steps" : (n_chains , n_draws ),
411+ }
410412 stat_vars = expected_stat_vars | numpyro_special_vars
411413 # test existence and dimensionality
412414 for stat_var , stat_var_dims in stat_vars .items ():
413- assert stat_var in stats
415+ assert stat_var in stats . variables
414416 assert stats .get (stat_var ).values .shape == stat_var_dims
0 commit comments