Skip to content

Commit 32f3da3

Browse files
record blackjax sample stats
1 parent 1dd85c1 commit 32f3da3

File tree

1 file changed

+38
-4
lines changed

1 file changed

+38
-4
lines changed

pymc/sampling/jax.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,40 @@ def _sample_stats_to_xarray(posterior):
142142
return data
143143

144144

145+
def _blackjax_stats_to_dict(sample_stats, potential_energy) -> Dict:
146+
"""Extract compatible stats from blackjax NUTS sampler
147+
with PyMC/Arviz naming conventions.
148+
149+
Parameters
150+
----------
151+
sample_stats: NUTSInfo
152+
Blackjax NUTSInfo object containing sampler statistics
153+
potential_energy: ArrayLike
154+
Potential energy values of sampled positions.
155+
156+
Returns
157+
-------
158+
Dict[str, ArrayLike]
159+
Dictionary of sampler statistics.
160+
"""
161+
rename_key = {
162+
"is_divergent": "diverging",
163+
"energy": "energy",
164+
"num_trajectory_expansions": "tree_depth",
165+
"num_integration_steps": "n_steps",
166+
"acceptance_rate": "acceptance_rate", # naming here is
167+
"acceptance_probability": "acceptance_rate", # depending on blackjax version
168+
}
169+
converted_stats = {}
170+
converted_stats["lp"] = potential_energy
171+
for old_name, new_name in rename_key.items():
172+
value = getattr(sample_stats, old_name, None)
173+
if value is None:
174+
continue
175+
converted_stats[new_name] = value
176+
return converted_stats
177+
178+
145179
def _get_log_likelihood(model: Model, samples, backend=None) -> Dict:
146180
"""Compute log-likelihood for all observations"""
147181
elemwise_logp = model.logp(model.observed_RVs, sum=False)
@@ -360,9 +394,9 @@ def sample_blackjax_nuts(
360394
"Only supporting the following methods to draw chains:" ' "parallel" or "vectorized"'
361395
)
362396

363-
states, _ = map_fn(get_posterior_samples)(keys, init_params)
397+
states, stats = map_fn(get_posterior_samples)(keys, init_params)
364398
raw_mcmc_samples = states.position
365-
399+
potential_energy = states.potential_energy
366400
tic3 = datetime.now()
367401
print("Sampling time = ", tic3 - tic2, file=sys.stdout)
368402

@@ -372,7 +406,7 @@ def sample_blackjax_nuts(
372406
*jax.device_put(raw_mcmc_samples, jax.devices(postprocessing_backend)[0])
373407
)
374408
mcmc_samples = {v.name: r for v, r in zip(vars_to_sample, result)}
375-
409+
mcmc_stats = _blackjax_stats_to_dict(stats, potential_energy)
376410
tic4 = datetime.now()
377411
print("Transformation time = ", tic4 - tic3, file=sys.stdout)
378412

@@ -410,7 +444,7 @@ def sample_blackjax_nuts(
410444
dims=dims,
411445
attrs=make_attrs(attrs, library=blackjax),
412446
)
413-
az_trace = to_trace(posterior=posterior, **idata_kwargs)
447+
az_trace = to_trace(posterior=posterior, sample_stats=mcmc_stats, **idata_kwargs)
414448

415449
return az_trace
416450

0 commit comments

Comments
 (0)