@@ -365,3 +365,50 @@ def test_numpyro_nuts_kwargs_are_used(mocked: mock.MagicMock):
365365 assert nuts_sampler ._adapt_step_size == adapt_step_size
366366 assert nuts_sampler ._adapt_mass_matrix
367367 assert nuts_sampler ._target_accept_prob == target_accept
368+
369+
370+ @pytest .mark .parametrize (
371+ "sampler_name" ,
372+ [
373+ "sample_blackjax_nuts" ,
374+ "sample_numpyro_nuts" ,
375+ ],
376+ )
377+ def test_idata_contains_stats (sampler_name : str ):
378+ """Tests whether sampler statistics were written to sample_stats
379+ group of InferenceData"""
380+ if sampler_name == "sample_blackjax_nuts" :
381+ sampler = sample_blackjax_nuts
382+ elif sampler_name == "sample_numpyro_nuts" :
383+ sampler = sample_numpyro_nuts
384+
385+ with pm .Model ():
386+ pm .Normal ("a" )
387+ idata = sampler (draws = 10 , tune = 10 )
388+
389+ stats = idata .get ("sample_stats" )
390+ assert stats is not None
391+ n_chains = stats .dims ["chain" ]
392+ n_draws = stats .dims ["draw" ]
393+
394+ # Stats vars expected for both samplers
395+ expected_stat_vars = {
396+ "acceptance_rate" : (n_chains , n_draws ),
397+ "diverging" : (n_chains , n_draws ),
398+ "energy" : (n_chains , n_draws ),
399+ "n_steps" : (n_chains , n_draws ),
400+ "tree_depth" : (n_chains , n_draws ),
401+ "lp" : (n_chains , n_draws ),
402+ }
403+ # Stats only expected for blackjax nuts
404+ if sampler_name == "sample_blackjax_nuts" :
405+ blackjax_special_vars = {}
406+ stat_vars = expected_stat_vars | blackjax_special_vars
407+ # Stats only expected for numpyro nuts
408+ elif sampler_name == "sample_numpyro_nuts" :
409+ numpyro_special_vars = {"step_size" : (n_chains , n_draws )}
410+ stat_vars = expected_stat_vars | numpyro_special_vars
411+ # test existence and dimensionality
412+ for stat_var , stat_var_dims in stat_vars .items ():
413+ assert stat_var in stats
414+ assert stats .get (stat_var ).values .shape == stat_var_dims
0 commit comments