5555function tilde_assume (context:: PriorContext{<:NamedTuple} , right, vn, vi)
5656 if haskey (context. vars, getsym (vn))
5757 vi = setindex!! (vi, vectorize (right, get (context. vars, vn)), vn)
58- settrans! (vi, false , vn)
58+ settrans!! (vi, false , vn)
5959 end
6060 return tilde_assume (PriorContext (), right, vn, vi)
6161end
@@ -64,15 +64,15 @@ function tilde_assume(
6464)
6565 if haskey (context. vars, getsym (vn))
6666 vi = setindex!! (vi, vectorize (right, get (context. vars, vn)), vn)
67- settrans! (vi, false , vn)
67+ settrans!! (vi, false , vn)
6868 end
6969 return tilde_assume (rng, PriorContext (), sampler, right, vn, vi)
7070end
7171
7272function tilde_assume (context:: LikelihoodContext{<:NamedTuple} , right, vn, vi)
7373 if haskey (context. vars, getsym (vn))
7474 vi = setindex!! (vi, vectorize (right, get (context. vars, vn)), vn)
75- settrans! (vi, false , vn)
75+ settrans!! (vi, false , vn)
7676 end
7777 return tilde_assume (LikelihoodContext (), right, vn, vi)
7878end
@@ -86,7 +86,7 @@ function tilde_assume(
8686)
8787 if haskey (context. vars, getsym (vn))
8888 vi = setindex!! (vi, vectorize (right, get (context. vars, vn)), vn)
89- settrans! (vi, false , vn)
89+ settrans!! (vi, false , vn)
9090 end
9191 return tilde_assume (rng, LikelihoodContext (), sampler, right, vn, vi)
9292end
194194
195195# fallback without sampler
196196function assume (dist:: Distribution , vn:: VarName , vi)
197- r = vi[vn]
197+ # x = vi[vn]
198+ r_raw = getindex_raw (vi, vn)
199+ r = maybe_invlink (vi, vn, dist, r_raw)
198200 return r, Bijectors. logpdf_with_trans (dist, r, istrans (vi, vn)), vi
199201end
200202
@@ -211,16 +213,23 @@ function assume(
211213 if sampler isa SampleFromUniform || is_flagged (vi, vn, " del" )
212214 unset_flag! (vi, vn, " del" )
213215 r = init (rng, dist, sampler)
214- vi[vn] = vectorize (dist, r)
215- settrans! (vi, false , vn)
216+ vi[vn] = vectorize (dist, maybe_link (vi, vn, dist, r))
216217 setorder! (vi, vn, get_num_produce (vi))
217218 else
218- r = vi[vn]
219+ # Otherwise we just extract it.
220+ # r = vi[vn]
221+ r_raw = getindex_raw (vi, vn)
222+ r = maybe_invlink (vi, vn, dist, r_raw)
219223 end
220224 else
221225 r = init (rng, dist, sampler)
222- push!! (vi, vn, r, dist, sampler)
223- settrans! (vi, false , vn)
226+ if istrans (vi)
227+ push!! (vi, vn, link (dist, r), dist, sampler)
228+ # By default `push!!` sets the transformed flag to `false`.
229+ settrans!! (vi, true , vn)
230+ else
231+ push!! (vi, vn, r, dist, sampler)
232+ end
224233 end
225234
226235 return r, Bijectors. logpdf_with_trans (dist, r, istrans (vi, vn)), vi
@@ -286,7 +295,7 @@ function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left,
286295 var = get (context. vars, vn)
287296 _right, _left, _vns = unwrap_right_left_vns (right, var, vn)
288297 set_val! (vi, _vns, _right, _left)
289- settrans! .(Ref (vi), false , _vns)
298+ settrans!! .(Ref (vi), false , _vns)
290299 dot_tilde_assume (LikelihoodContext (), _right, _left, _vns, vi)
291300 else
292301 dot_tilde_assume (LikelihoodContext (), right, left, vn, vi)
@@ -305,7 +314,7 @@ function dot_tilde_assume(
305314 var = get (context. vars, vn)
306315 _right, _left, _vns = unwrap_right_left_vns (right, var, vn)
307316 set_val! (vi, _vns, _right, _left)
308- settrans! .(Ref (vi), false , _vns)
317+ settrans!! .(Ref (vi), false , _vns)
309318 dot_tilde_assume (rng, LikelihoodContext (), sampler, _right, _left, _vns, vi)
310319 else
311320 dot_tilde_assume (rng, LikelihoodContext (), sampler, right, left, vn, vi)
@@ -326,7 +335,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn,
326335 var = get (context. vars, vn)
327336 _right, _left, _vns = unwrap_right_left_vns (right, var, vn)
328337 set_val! (vi, _vns, _right, _left)
329- settrans! .(Ref (vi), false , _vns)
338+ settrans!! .(Ref (vi), false , _vns)
330339 dot_tilde_assume (PriorContext (), _right, _left, _vns, vi)
331340 else
332341 dot_tilde_assume (PriorContext (), right, left, vn, vi)
@@ -345,7 +354,7 @@ function dot_tilde_assume(
345354 var = get (context. vars, vn)
346355 _right, _left, _vns = unwrap_right_left_vns (right, var, vn)
347356 set_val! (vi, _vns, _right, _left)
348- settrans! .(Ref (vi), false , _vns)
357+ settrans!! .(Ref (vi), false , _vns)
349358 dot_tilde_assume (rng, PriorContext (), sampler, _right, _left, _vns, vi)
350359 else
351360 dot_tilde_assume (rng, PriorContext (), sampler, right, left, vn, vi)
@@ -390,7 +399,9 @@ function dot_assume(
390399 # m .~ Normal()
391400 #
392401 # in which case `var` will have `undef` elements, even if `m` is present in `vi`.
393- r = vi[vns]
402+ # r = vi[vns]
403+ r_raw = getindex_raw (vi, vns)
404+ r = maybe_invlink (vi, vn, dist, r_raw)
394405 lp = sum (zip (vns, eachcol (r))) do (vn, ri)
395406 return Bijectors. logpdf_with_trans (dist, ri, istrans (vi, vn))
396407 end
@@ -423,7 +434,8 @@ function dot_assume(
423434 # m .~ Normal()
424435 #
425436 # in which case `var` will have `undef` elements, even if `m` is present in `vi`.
426- r = reshape (vi[vec (vns)], size (vns))
437+ r_raw = getindex_raw (vi, vec (vns))
438+ r = reshape (maybe_invlink .(Ref (vi), vns, dists, r_raw), size (vns))
427439 lp = sum (Bijectors. logpdf_with_trans .(dists, r, istrans (vi, vns[1 ])))
428440 return r, lp, vi
429441end
@@ -462,19 +474,24 @@ function get_and_set_val!(
462474 r = init (rng, dist, spl, n)
463475 for i in 1 : n
464476 vn = vns[i]
465- vi[vn] = vectorize (dist, r[:, i])
466- settrans! (vi, false , vn)
477+ vi[vn] = vectorize (dist, maybe_link (vi, vn, dist, r[:, i]))
467478 setorder! (vi, vn, get_num_produce (vi))
468479 end
469480 else
470- r = vi[vns]
481+ r_raw = getindex_raw (vi, vns)
482+ r = maybe_invlink (vi, vns, dist, r_raw)
471483 end
472484 else
473485 r = init (rng, dist, spl, n)
474486 for i in 1 : n
475487 vn = vns[i]
476- push!! (vi, vn, r[:, i], dist, spl)
477- settrans! (vi, false , vn)
488+ if istrans (vi)
489+ push!! (vi, vn, maybe_link (vi, vn, dist, r[:, i]), dist, spl)
490+ # `push!!` sets the trans-flag to `false` by default.
491+ setttrans!! (vi, true , vn)
492+ else
493+ push!! (vi, vn, r[:, i], dist, spl)
494+ end
478495 end
479496 end
480497 return r
@@ -496,12 +513,13 @@ function get_and_set_val!(
496513 for i in eachindex (vns)
497514 vn = vns[i]
498515 dist = dists isa AbstractArray ? dists[i] : dists
499- vi[vn] = vectorize (dist, r[i])
500- settrans! (vi, false , vn)
516+ vi[vn] = vectorize (dist, maybe_link (vi, vn, dist, r[i]))
501517 setorder! (vi, vn, get_num_produce (vi))
502518 end
503519 else
504- r = reshape (vi[vec (vns)], size (vns))
520+ # r = reshape(vi[vec(vns)], size(vns))
521+ r_raw = getindex_raw (vi, vec (vns))
522+ r = maybe_invlink .(Ref (vi), vns, dists, reshape (r_raw, size (vns)))
505523 end
506524 else
507525 f = (vn, dist) -> init (rng, dist, spl)
@@ -511,8 +529,13 @@ function get_and_set_val!(
511529 # 1. Figure out the broadcast size and use a `foreach`.
512530 # 2. Define an anonymous function which returns `nothing`, which
513531 # we then broadcast. This will allocate a vector of `nothing` though.
514- push!! .(Ref (vi), vns, r, dists, Ref (spl))
515- settrans! .(Ref (vi), false , vns)
532+ if istrans (vi)
533+ push!! .(Ref (vi), vns, link .(Ref (vi), vns, dists, r), dists, Ref (spl))
534+ # `push!!` sets the trans-flag to `false` by default.
535+ settrans!! .(Ref (vi), true , vns)
536+ else
537+ push!! .(Ref (vi), vns, r, dists, Ref (spl))
538+ end
516539 end
517540 return r
518541end
0 commit comments