diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl
index d529085515..69507d9380 100644
--- a/src/inference/Inference.jl
+++ b/src/inference/Inference.jl
@@ -753,13 +753,23 @@ function transitions_from_chain(
     md = vi.metadata
     for v in keys(md)
         for vn in md[v].vns
-            vn_symbol = Symbol(vn)
-            if vn_symbol ∈ c.name_map.parameters
-                val = c[vn_symbol]
+            vn_sym = Symbol(vn)
+
+            # Cannot use `vn_sym` to index in the chain
+            # so we have to extract the corresponding "linear"
+            # indices and use those.
+            # `ks` is empty if `vn_sym` not in `c`.
+            ks = MCMCChains.namesingroup(c, vn_sym)
+
+            if !isempty(ks)
+                # 1st dimension is of size 1 since `c`
+                # only contains a single sample, and the
+                # last dimension is of size 1 since
+                # we're assuming we're working with a single chain.
+                val = copy(vec(c[ks].value))
                 DynamicPPL.setval!(vi, val, vn)
                 DynamicPPL.settrans!(vi, false, vn)
             else
-                # delete so we can sample from prior
                 DynamicPPL.set_flag!(vi, vn, "del")
             end
         end
diff --git a/test/inference/utilities.jl b/test/inference/utilities.jl
index c9a95341d5..d96c54b3ef 100644
--- a/test/inference/utilities.jl
+++ b/test/inference/utilities.jl
@@ -11,6 +11,11 @@ using Random
         end
     end
 
+    @model function linear_reg_vec(x, y, σ = 0.1)
+        β ~ Normal(0, 1)
+        y ~ MvNormal(β .* x, σ)
+    end
+
     f(x) = 2 * x + 0.1 * randn()
 
     Δ = 0.1
@@ -28,4 +33,11 @@ using Random
     ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1))
 
     @test sum(abs2, ys_test - ys_pred) ≤ 0.1
+
+    # Predict on two last indices for vectorized
+    m_lin_reg_test = linear_reg_vec(xs_test, missing);
+    predictions_vec = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg)
+    ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims = 1))
+
+    @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1
 end
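
A minimal sketch (not part of the diff above) of why `MCMCChains.namesingroup` is used here: a vectorized variable such as `m` is stored in a `Chains` object under per-element names like `m[1]` and `m[2]`, so `Symbol(vn)` alone cannot index the chain. The chain object and values below are hypothetical.

using MCMCChains

# Hypothetical single-sample, single-chain `Chains` holding a vectorized parameter `m`.
c = Chains(reshape([0.5, -1.2], 1, 2, 1), ["m[1]", "m[2]"])

# `namesingroup` recovers the per-element ("linear") names belonging to `:m`;
# it returns an empty collection if `:m` does not appear in the chain.
ks = MCMCChains.namesingroup(c, :m)   # [Symbol("m[1]"), Symbol("m[2]")]

# Mirrors `copy(vec(c[ks].value))` from the diff: flatten the 1×2×1 slice to a vector.
val = copy(vec(c[ks].value))          # [0.5, -1.2]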