Skip to content

Commit 1999645

Browse files
authored
Fix for issue #1352 (#1357)
* transitions_from_chain now compatible with upstream updates * transitions_from_chain compatible with MCMCChains v4 * removed dot from copy * added test for predict with model containing multivariate variable
1 parent 742dc2b commit 1999645

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

src/inference/Inference.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -753,13 +753,23 @@ function transitions_from_chain(
753753
md = vi.metadata
754754
for v in keys(md)
755755
for vn in md[v].vns
756-
vn_symbol = Symbol(vn)
757-
if vn_symbol c.name_map.parameters
758-
val = c[vn_symbol]
756+
vn_sym = Symbol(vn)
757+
758+
# Cannot use `vn_sym` to index in the chain
759+
# so we have to extract the corresponding "linear"
760+
# indices and use those.
761+
# `ks` is empty if `vn_sym` not in `c`.
762+
ks = MCMCChains.namesingroup(c, vn_sym)
763+
764+
if !isempty(ks)
765+
# 1st dimension is of size 1 since `c`
766+
# only contains a single sample, and the
767+
# last dimension is of size 1 since
768+
# we're assuming we're working with a single chain.
769+
val = copy(vec(c[ks].value))
759770
DynamicPPL.setval!(vi, val, vn)
760771
DynamicPPL.settrans!(vi, false, vn)
761772
else
762-
# delete so we can sample from prior
763773
DynamicPPL.set_flag!(vi, vn, "del")
764774
end
765775
end

test/inference/utilities.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ using Random
1111
end
1212
end
1313

14+
@model function linear_reg_vec(x, y, σ = 0.1)
15+
β ~ Normal(0, 1)
16+
y ~ MvNormal.* x, σ)
17+
end
18+
1419
f(x) = 2 * x + 0.1 * randn()
1520

1621
Δ = 0.1
@@ -28,4 +33,11 @@ using Random
2833
ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1))
2934

3035
@test sum(abs2, ys_test - ys_pred) 0.1
36+
37+
# Predict on two last indices for vectorized
38+
m_lin_reg_test = linear_reg_vec(xs_test, missing);
39+
predictions_vec = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg)
40+
ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims = 1))
41+
42+
@test sum(abs2, ys_test - ys_pred_vec) 0.1
3143
end

0 commit comments

Comments
 (0)