5555function tilde_assume (context:: PriorContext{<:NamedTuple} , right, vn, vi)
5656 if haskey (context. vars, getsym (vn))
5757 vi = setindex!! (vi, vectorize (right, get (context. vars, vn)), vn)
58- settrans! (vi, false , vn)
58+ settrans!! (vi, false , vn)
5959 end
6060 return tilde_assume (PriorContext (), right, vn, vi)
6161end
@@ -64,15 +64,15 @@ function tilde_assume(
6464)
6565 if haskey (context. vars, getsym (vn))
6666 vi = setindex!! (vi, vectorize (right, get (context. vars, vn)), vn)
67- settrans! (vi, false , vn)
67+ settrans!! (vi, false , vn)
6868 end
6969 return tilde_assume (rng, PriorContext (), sampler, right, vn, vi)
7070end
7171
7272function tilde_assume (context:: LikelihoodContext{<:NamedTuple} , right, vn, vi)
7373 if haskey (context. vars, getsym (vn))
7474 vi = setindex!! (vi, vectorize (right, get (context. vars, vn)), vn)
75- settrans! (vi, false , vn)
75+ settrans!! (vi, false , vn)
7676 end
7777 return tilde_assume (LikelihoodContext (), right, vn, vi)
7878end
@@ -86,7 +86,7 @@ function tilde_assume(
8686)
8787 if haskey (context. vars, getsym (vn))
8888 vi = setindex!! (vi, vectorize (right, get (context. vars, vn)), vn)
89- settrans! (vi, false , vn)
89+ settrans!! (vi, false , vn)
9090 end
9191 return tilde_assume (rng, LikelihoodContext (), sampler, right, vn, vi)
9292end
194194
195195# fallback without sampler
196196function assume (dist:: Distribution , vn:: VarName , vi)
197- r = vi[vn]
197+ r = vi[vn, dist ]
198198 return r, Bijectors. logpdf_with_trans (dist, r, istrans (vi, vn)), vi
199199end
200200
@@ -211,16 +211,21 @@ function assume(
211211 if sampler isa SampleFromUniform || is_flagged (vi, vn, " del" )
212212 unset_flag! (vi, vn, " del" )
213213 r = init (rng, dist, sampler)
214- vi[vn] = vectorize (dist, r)
215- settrans! (vi, false , vn)
214+ vi[vn] = vectorize (dist, maybe_link (vi, vn, dist, r))
216215 setorder! (vi, vn, get_num_produce (vi))
217216 else
218- r = vi[vn]
217+ # Otherwise we just extract it.
218+ r = vi[vn, dist]
219219 end
220220 else
221221 r = init (rng, dist, sampler)
222- push!! (vi, vn, r, dist, sampler)
223- settrans! (vi, false , vn)
222+ if istrans (vi)
223+ push!! (vi, vn, link (dist, r), dist, sampler)
224+ # By default `push!!` sets the transformed flag to `false`.
225+ settrans!! (vi, true , vn)
226+ else
227+ push!! (vi, vn, r, dist, sampler)
228+ end
224229 end
225230
226231 return r, Bijectors. logpdf_with_trans (dist, r, istrans (vi, vn)), vi
@@ -286,7 +291,7 @@ function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left,
286291 var = get (context. vars, vn)
287292 _right, _left, _vns = unwrap_right_left_vns (right, var, vn)
288293 set_val! (vi, _vns, _right, _left)
289- settrans! .( Ref (vi), false , _vns)
294+ settrans!! .( (vi, ), false , _vns)
290295 dot_tilde_assume (LikelihoodContext (), _right, _left, _vns, vi)
291296 else
292297 dot_tilde_assume (LikelihoodContext (), right, left, vn, vi)
@@ -305,19 +310,20 @@ function dot_tilde_assume(
305310 var = get (context. vars, vn)
306311 _right, _left, _vns = unwrap_right_left_vns (right, var, vn)
307312 set_val! (vi, _vns, _right, _left)
308- settrans! .( Ref (vi), false , _vns)
313+ settrans!! .( (vi, ), false , _vns)
309314 dot_tilde_assume (rng, LikelihoodContext (), sampler, _right, _left, _vns, vi)
310315 else
311316 dot_tilde_assume (rng, LikelihoodContext (), sampler, right, left, vn, vi)
312317 end
313318end
319+
314320function dot_tilde_assume (context:: LikelihoodContext , right, left, vn, vi)
315- return dot_assume (NoDist . (right), left, vn, vi)
321+ return dot_assume (nodist (right), left, vn, vi)
316322end
317323function dot_tilde_assume (
318324 rng:: Random.AbstractRNG , context:: LikelihoodContext , sampler, right, left, vn, vi
319325)
320- return dot_assume (rng, sampler, NoDist . (right), vn, left, vi)
326+ return dot_assume (rng, sampler, nodist (right), vn, left, vi)
321327end
322328
323329# `PriorContext`
@@ -326,7 +332,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn,
326332 var = get (context. vars, vn)
327333 _right, _left, _vns = unwrap_right_left_vns (right, var, vn)
328334 set_val! (vi, _vns, _right, _left)
329- settrans! .( Ref (vi), false , _vns)
335+ settrans!! .( (vi, ), false , _vns)
330336 dot_tilde_assume (PriorContext (), _right, _left, _vns, vi)
331337 else
332338 dot_tilde_assume (PriorContext (), right, left, vn, vi)
@@ -345,7 +351,7 @@ function dot_tilde_assume(
345351 var = get (context. vars, vn)
346352 _right, _left, _vns = unwrap_right_left_vns (right, var, vn)
347353 set_val! (vi, _vns, _right, _left)
348- settrans! .( Ref (vi), false , _vns)
354+ settrans!! .( (vi, ), false , _vns)
349355 dot_tilde_assume (rng, PriorContext (), sampler, _right, _left, _vns, vi)
350356 else
351357 dot_tilde_assume (rng, PriorContext (), sampler, right, left, vn, vi)
@@ -383,14 +389,14 @@ function dot_assume(
383389 vns:: AbstractVector{<:VarName} ,
384390 vi:: AbstractVarInfo ,
385391)
386- @assert length (dist) == size (var, 1 )
392+ @assert length (dist) == size (var, 1 ) " dimensionality of `var` ( $( size (var, 1 )) ) is incompatible with dimensionality of `dist` $( length (dist)) "
387393 # NOTE: We cannot work with `var` here because we might have a model of the form
388394 #
389395 # m = Vector{Float64}(undef, n)
390396 # m .~ Normal()
391397 #
392398 # in which case `var` will have `undef` elements, even if `m` is present in `vi`.
393- r = vi[vns]
399+ r = vi[vns, dist ]
394400 lp = sum (zip (vns, eachcol (r))) do (vn, ri)
395401 return Bijectors. logpdf_with_trans (dist, ri, istrans (vi, vn))
396402 end
@@ -412,19 +418,21 @@ function dot_assume(
412418end
413419
414420function dot_assume (
415- dists:: Union{Distribution,AbstractArray{<:Distribution}} ,
421+ dist:: Distribution , var:: AbstractArray , vns:: AbstractArray{<:VarName} , vi
422+ )
423+ r = getindex .((vi,), vns, (dist,))
424+ lp = sum (Bijectors. logpdf_with_trans .((dist,), r, istrans .((vi,), vns)))
425+ return r, lp, vi
426+ end
427+
428+ function dot_assume (
429+ dists:: AbstractArray{<:Distribution} ,
416430 var:: AbstractArray ,
417431 vns:: AbstractArray{<:VarName} ,
418432 vi,
419433)
420- # NOTE: We cannot work with `var` here because we might have a model of the form
421- #
422- # m = Vector{Float64}(undef, n)
423- # m .~ Normal()
424- #
425- # in which case `var` will have `undef` elements, even if `m` is present in `vi`.
426- r = reshape (vi[vec (vns)], size (vns))
427- lp = sum (Bijectors. logpdf_with_trans .(dists, r, istrans (vi, vns[1 ])))
434+ r = getindex .((vi,), vns, dists)
435+ lp = sum (Bijectors. logpdf_with_trans .(dists, r, istrans .((vi,), vns)))
428436 return r, lp, vi
429437end
430438
@@ -438,7 +446,7 @@ function dot_assume(
438446)
439447 r = get_and_set_val! (rng, vi, vns, dists, spl)
440448 # Make sure `r` is not a matrix for multivariate distributions
441- lp = sum (Bijectors. logpdf_with_trans .(dists, r, istrans ( vi, vns[ 1 ] )))
449+ lp = sum (Bijectors. logpdf_with_trans .(dists, r, istrans .(( vi,), vns)))
442450 return r, lp, vi
443451end
444452function dot_assume (rng, spl:: Sampler , :: Any , :: AbstractArray{<:VarName} , :: Any , :: Any )
@@ -462,19 +470,23 @@ function get_and_set_val!(
462470 r = init (rng, dist, spl, n)
463471 for i in 1 : n
464472 vn = vns[i]
465- vi[vn] = vectorize (dist, r[:, i])
466- settrans! (vi, false , vn)
473+ vi[vn] = vectorize (dist, maybe_link (vi, vn, dist, r[:, i]))
467474 setorder! (vi, vn, get_num_produce (vi))
468475 end
469476 else
470- r = vi[vns]
477+ r = vi[vns, dist ]
471478 end
472479 else
473480 r = init (rng, dist, spl, n)
474481 for i in 1 : n
475482 vn = vns[i]
476- push!! (vi, vn, r[:, i], dist, spl)
477- settrans! (vi, false , vn)
483+ if istrans (vi)
484+ push!! (vi, vn, Bijectors. link (dist, r[:, i]), dist, spl)
485+ # `push!!` sets the trans-flag to `false` by default.
486+ settrans!! (vi, true , vn)
487+ else
488+ push!! (vi, vn, r[:, i], dist, spl)
489+ end
478490 end
479491 end
480492 return r
@@ -496,12 +508,13 @@ function get_and_set_val!(
496508 for i in eachindex (vns)
497509 vn = vns[i]
498510 dist = dists isa AbstractArray ? dists[i] : dists
499- vi[vn] = vectorize (dist, r[i])
500- settrans! (vi, false , vn)
511+ vi[vn] = vectorize (dist, maybe_link (vi, vn, dist, r[i]))
501512 setorder! (vi, vn, get_num_produce (vi))
502513 end
503514 else
504- r = reshape (vi[vec (vns)], size (vns))
515+ # r = reshape(vi[vec(vns)], size(vns))
516+ r_raw = getindex_raw (vi, vec (vns))
517+ r = maybe_invlink .((vi,), vns, dists, reshape (r_raw, size (vns)))
505518 end
506519 else
507520 f = (vn, dist) -> init (rng, dist, spl)
@@ -511,8 +524,13 @@ function get_and_set_val!(
511524 # 1. Figure out the broadcast size and use a `foreach`.
512525 # 2. Define an anonymous function which returns `nothing`, which
513526 # we then broadcast. This will allocate a vector of `nothing` though.
514- push!! .(Ref (vi), vns, r, dists, Ref (spl))
515- settrans! .(Ref (vi), false , vns)
527+ if istrans (vi)
528+ push!! .((vi,), vns, link .((vi,), vns, dists, r), dists, (spl,))
529+ # `push!!` sets the trans-flag to `false` by default.
530+ settrans!! .((vi,), true , vns)
531+ else
532+ push!! .((vi,), vns, r, dists, (spl,))
533+ end
516534 end
517535 return r
518536end
0 commit comments