@@ -142,6 +142,40 @@ def _sample_stats_to_xarray(posterior):
142142 return data
143143
144144
145+ def _blackjax_stats_to_dict (sample_stats , potential_energy ) -> Dict :
146+ """Extract compatible stats from blackjax NUTS sampler
147+ with PyMC/Arviz naming conventions.
148+
149+ Parameters
150+ ----------
151+ sample_stats: NUTSInfo
152+ Blackjax NUTSInfo object containing sampler statistics
153+ potential_energy: ArrayLike
154+ Potential energy values of sampled positions.
155+
156+ Returns
157+ -------
158+ Dict[str, ArrayLike]
159+ Dictionary of sampler statistics.
160+ """
161+ rename_key = {
162+ "is_divergent" : "diverging" ,
163+ "energy" : "energy" ,
164+ "num_trajectory_expansions" : "tree_depth" ,
165+ "num_integration_steps" : "n_steps" ,
166+ "acceptance_rate" : "acceptance_rate" , # naming here is
167+ "acceptance_probability" : "acceptance_rate" , # depending on blackjax version
168+ }
169+ converted_stats = {}
170+ converted_stats ["lp" ] = potential_energy
171+ for old_name , new_name in rename_key .items ():
172+ value = getattr (sample_stats , old_name , None )
173+ if value is None :
174+ continue
175+ converted_stats [new_name ] = value
176+ return converted_stats
177+
178+
145179def _get_log_likelihood (model : Model , samples , backend = None ) -> Dict :
146180 """Compute log-likelihood for all observations"""
147181 elemwise_logp = model .logp (model .observed_RVs , sum = False )
@@ -360,9 +394,9 @@ def sample_blackjax_nuts(
360394 "Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
361395 )
362396
363- states , _ = map_fn (get_posterior_samples )(keys , init_params )
397+ states , stats = map_fn (get_posterior_samples )(keys , init_params )
364398 raw_mcmc_samples = states .position
365-
399+ potential_energy = states . potential_energy
366400 tic3 = datetime .now ()
367401 print ("Sampling time = " , tic3 - tic2 , file = sys .stdout )
368402
@@ -372,7 +406,7 @@ def sample_blackjax_nuts(
372406 * jax .device_put (raw_mcmc_samples , jax .devices (postprocessing_backend )[0 ])
373407 )
374408 mcmc_samples = {v .name : r for v , r in zip (vars_to_sample , result )}
375-
409+ mcmc_stats = _blackjax_stats_to_dict ( stats , potential_energy )
376410 tic4 = datetime .now ()
377411 print ("Transformation time = " , tic4 - tic3 , file = sys .stdout )
378412
@@ -410,7 +444,7 @@ def sample_blackjax_nuts(
410444 dims = dims ,
411445 attrs = make_attrs (attrs , library = blackjax ),
412446 )
413- az_trace = to_trace (posterior = posterior , ** idata_kwargs )
447+ az_trace = to_trace (posterior = posterior , sample_stats = mcmc_stats , ** idata_kwargs )
414448
415449 return az_trace
416450
0 commit comments