From bffdee6a1bd4adb84bed71c5cc0c1450c808b87a Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Fri, 17 Jul 2020 21:31:55 +0200 Subject: [PATCH 1/4] transitions_from_chain now compatible with upstream updates --- src/inference/Inference.jl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index d529085515..099f9f8329 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -753,9 +753,17 @@ function transitions_from_chain( md = vi.metadata for v in keys(md) for vn in md[v].vns - vn_symbol = Symbol(vn) - if vn_symbol ∈ c.name_map.parameters - val = c[vn_symbol] + vn_sym = Symbol(vn) + + # This returns `()` if `vn` not present in `c` + # Otherwise it returns `(a = ..., )` even if + # `a` represents non-univariate. + res = get(c, vn_sym; flatten = false) + if !isempty(res) + # FIXME: this does not handle the cases where + # only a subset of the indices are set, e.g. + # if `a[1]` is in `chain` but `a[2]` is not. + val = copy.(vec(c[vn_sym].value)) DynamicPPL.setval!(vi, val, vn) DynamicPPL.settrans!(vi, false, vn) else From 3bc416760243b8d31a484fac91d712d237f12f60 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Fri, 17 Jul 2020 22:16:20 +0200 Subject: [PATCH 2/4] transitions_from_chain compatible with MCMCChains v4 --- src/inference/Inference.jl | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 099f9f8329..6ff1ac0cc7 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -755,19 +755,21 @@ function transitions_from_chain( for vn in md[v].vns vn_sym = Symbol(vn) - # This returns `()` if `vn` not present in `c` - # Otherwise it returns `(a = ..., )` even if - # `a` represents non-univariate. - res = get(c, vn_sym; flatten = false) - if !isempty(res) - # FIXME: this does not handle the cases where - # only a subset of the indices are set, e.g. - # if `a[1]` is in `chain` but `a[2]` is not. - val = copy.(vec(c[vn_sym].value)) + # Cannot use `vn_sym` to index in the chain + # so we have to extract the corresponding "linear" + # indices and use those. + # `ks` is empty if `vn_sym` not in `c`. + ks = MCMCChains.namesingroup(c, vn_sym) + + if !isempty(ks) + # 1st dimension is of size 1 since `c` + # only contains a single sample, and the + # last dimension is of size 1 since + # we're assuming we're working with a single chain. + val = copy.(vec(c[ks].value)) DynamicPPL.setval!(vi, val, vn) DynamicPPL.settrans!(vi, false, vn) else - # delete so we can sample from prior DynamicPPL.set_flag!(vi, vn, "del") end end From c8fade3b10b4665c103886026010717b3f39820d Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Fri, 17 Jul 2020 23:04:13 +0200 Subject: [PATCH 3/4] removed dot from copy --- src/inference/Inference.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 6ff1ac0cc7..69507d9380 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -766,7 +766,7 @@ function transitions_from_chain( # only contains a single sample, and the # last dimension is of size 1 since # we're assuming we're working with a single chain. - val = copy.(vec(c[ks].value)) + val = copy(vec(c[ks].value)) DynamicPPL.setval!(vi, val, vn) DynamicPPL.settrans!(vi, false, vn) else From a23b6f6f36c7d8c67065cf684c66c41b0961eea1 Mon Sep 17 00:00:00 2001 From: "tor.erlend95@gmail.com" Date: Fri, 17 Jul 2020 23:04:27 +0200 Subject: [PATCH 4/4] added test for predict with model containing multivariate variable --- test/inference/utilities.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/inference/utilities.jl b/test/inference/utilities.jl index c9a95341d5..d96c54b3ef 100644 --- a/test/inference/utilities.jl +++ b/test/inference/utilities.jl @@ -11,6 +11,11 @@ using Random end end + @model function linear_reg_vec(x, y, σ = 0.1) + β ~ Normal(0, 1) + y ~ MvNormal(β .* x, σ) + end + f(x) = 2 * x + 0.1 * randn() Δ = 0.1 @@ -28,4 +33,11 @@ using Random ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1)) @test sum(abs2, ys_test - ys_pred) ≤ 0.1 + + # Predict on two last indices for vectorized + m_lin_reg_test = linear_reg_vec(xs_test, missing); + predictions_vec = Turing.Inference.predict(m_lin_reg_test, chain_lin_reg) + ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims = 1)) + + @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1 end