We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 32f3da3 commit 5f02ad9Copy full SHA for 5f02ad9
pymc/sampling/jax.py
@@ -440,11 +440,12 @@ def sample_blackjax_nuts(
440
log_likelihood=log_likelihood,
441
observed_data=find_observations(model),
442
constant_data=find_constants(model),
443
+ sample_stats=mcmc_stats,
444
coords=coords,
445
dims=dims,
446
attrs=make_attrs(attrs, library=blackjax),
447
)
- az_trace = to_trace(posterior=posterior, sample_stats=mcmc_stats, **idata_kwargs)
448
+ az_trace = to_trace(posterior=posterior, **idata_kwargs)
449
450
return az_trace
451
0 commit comments