File tree Expand file tree Collapse file tree 2 files changed +26
-4
lines changed Expand file tree Collapse file tree 2 files changed +26
-4
lines changed Original file line number Diff line number Diff line change @@ -753,13 +753,23 @@ function transitions_from_chain(
753753 md = vi. metadata
754754 for v in keys (md)
755755 for vn in md[v]. vns
756- vn_symbol = Symbol (vn)
757- if vn_symbol ∈ c. name_map. parameters
758- val = c[vn_symbol]
756+ vn_sym = Symbol (vn)
757+
758+ # Cannot use `vn_sym` to index in the chain
759+ # so we have to extract the corresponding "linear"
760+ # indices and use those.
761+ # `ks` is empty if `vn_sym` not in `c`.
762+ ks = MCMCChains. namesingroup (c, vn_sym)
763+
764+ if ! isempty (ks)
765+ # 1st dimension is of size 1 since `c`
766+ # only contains a single sample, and the
767+ # last dimension is of size 1 since
768+ # we're assuming we're working with a single chain.
769+ val = copy (vec (c[ks]. value))
759770 DynamicPPL. setval! (vi, val, vn)
760771 DynamicPPL. settrans! (vi, false , vn)
761772 else
762- # delete so we can sample from prior
763773 DynamicPPL. set_flag! (vi, vn, " del" )
764774 end
765775 end
Original file line number Diff line number Diff line change @@ -11,6 +11,11 @@ using Random
1111 end
1212 end
1313
14+ @model function linear_reg_vec (x, y, σ = 0.1 )
15+ β ~ Normal (0 , 1 )
16+ y ~ MvNormal (β .* x, σ)
17+ end
18+
1419 f (x) = 2 * x + 0.1 * randn ()
1520
1621 Δ = 0.1
@@ -28,4 +33,11 @@ using Random
2833 ys_pred = vec (mean (Array (group (predictions, :y )); dims = 1 ))
2934
3035 @test sum (abs2, ys_test - ys_pred) ≤ 0.1
36+
37+ # Predict on two last indices for vectorized
38+ m_lin_reg_test = linear_reg_vec (xs_test, missing );
39+ predictions_vec = Turing. Inference. predict (m_lin_reg_test, chain_lin_reg)
40+ ys_pred_vec = vec (mean (Array (group (predictions_vec, :y )); dims = 1 ))
41+
42+ @test sum (abs2, ys_test - ys_pred_vec) ≤ 0.1
3143end
You can’t perform that action at this time.
0 commit comments