Skip to content

Commit 5f02ad9

Browse files
moved sample stats argument to partial call
1 parent 32f3da3 commit 5f02ad9

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pymc/sampling/jax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,11 +440,12 @@ def sample_blackjax_nuts(
440440
log_likelihood=log_likelihood,
441441
observed_data=find_observations(model),
442442
constant_data=find_constants(model),
443+
sample_stats=mcmc_stats,
443444
coords=coords,
444445
dims=dims,
445446
attrs=make_attrs(attrs, library=blackjax),
446447
)
447-
az_trace = to_trace(posterior=posterior, sample_stats=mcmc_stats, **idata_kwargs)
448+
az_trace = to_trace(posterior=posterior, **idata_kwargs)
448449

449450
return az_trace
450451

0 commit comments

Comments
 (0)