194194
195195# fallback without sampler
196196function assume (dist:: Distribution , vn:: VarName , vi)
197- r = vi[ vn, dist]
198- return r, Bijectors . logpdf_with_trans (dist, r, istrans (vi, vn)) , vi
197+ r, logp = invlink_with_logpdf (vi, vn, dist)
198+ return r, logp , vi
199199end
200200
201201# SampleFromPrior and SampleFromUniform
@@ -211,7 +211,9 @@ function assume(
211211 if sampler isa SampleFromUniform || is_flagged (vi, vn, " del" )
212212 unset_flag! (vi, vn, " del" )
213213 r = init (rng, dist, sampler)
214- BangBang. setindex!! (vi, vectorize (dist, maybe_link (vi, vn, dist, r)), vn)
214+ BangBang. setindex!! (
215+ vi, vectorize (dist, maybe_reconstruct_and_link (vi, vn, dist, r)), vn
216+ )
215217 setorder! (vi, vn, get_num_produce (vi))
216218 else
217219 # Otherwise we just extract it.
@@ -220,15 +222,17 @@ function assume(
220222 else
221223 r = init (rng, dist, sampler)
222224 if istrans (vi)
223- push!! (vi, vn, link (dist, r), dist, sampler)
225+ push!! (vi, vn, reconstruct_and_link (dist, r), dist, sampler)
224226 # By default `push!!` sets the transformed flag to `false`.
225227 settrans!! (vi, true , vn)
226228 else
227229 push!! (vi, vn, r, dist, sampler)
228230 end
229231 end
230232
231- return r, Bijectors. logpdf_with_trans (dist, r, istrans (vi, vn)), vi
233+ # HACK: The above code might involve an `invlink` somewhere, etc. so we need to correct.
234+ logjac = logabsdetjac (istrans (vi, vn) ? link_transform (dist) : identity, r)
235+ return r, logpdf (dist, r) - logjac, vi
232236end
233237
234238# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`)
@@ -470,7 +474,11 @@ function get_and_set_val!(
470474 r = init (rng, dist, spl, n)
471475 for i in 1 : n
472476 vn = vns[i]
473- setindex!! (vi, vectorize (dist, maybe_link (vi, vn, dist, r[:, i])), vn)
477+ setindex!! (
478+ vi,
479+ vectorize (dist, maybe_reconstruct_and_link (vi, vn, dist, r[:, i])),
480+ vn,
481+ )
474482 setorder! (vi, vn, get_num_produce (vi))
475483 end
476484 else
@@ -508,13 +516,17 @@ function get_and_set_val!(
508516 for i in eachindex (vns)
509517 vn = vns[i]
510518 dist = dists isa AbstractArray ? dists[i] : dists
511- setindex!! (vi, vectorize (dist, maybe_link (vi, vn, dist, r[i])), vn)
519+ setindex!! (
520+ vi, vectorize (dist, maybe_reconstruct_and_link (vi, vn, dist, r[i])), vn
521+ )
512522 setorder! (vi, vn, get_num_produce (vi))
513523 end
514524 else
515525 # r = reshape(vi[vec(vns)], size(vns))
526+ # FIXME : Remove `reconstruct` in `getindex_raw(::VarInfo, ...)`
527+ # and fix the lines below.
516528 r_raw = getindex_raw (vi, vec (vns))
517- r = maybe_invlink .((vi,), vns, dists, reshape (r_raw, size (vns)))
529+ r = maybe_invlink_and_reconstruct .((vi,), vns, dists, reshape (r_raw, size (vns)))
518530 end
519531 else
520532 f = (vn, dist) -> init (rng, dist, spl)
@@ -525,7 +537,7 @@ function get_and_set_val!(
525537 # 2. Define an anonymous function which returns `nothing`, which
526538 # we then broadcast. This will allocate a vector of `nothing` though.
527539 if istrans (vi)
528- push!! .((vi,), vns, link .((vi,), vns, dists, r), dists, (spl,))
540+ push!! .((vi,), vns, reconstruct_and_link .((vi,), vns, dists, r), dists, (spl,))
529541 # NOTE: Need to add the correction.
530542 acclogp!! (vi, sum (logabsdetjac .(bijector .(dists), r)))
531543 # `push!!` sets the trans-flag to `false` by default.
0 commit comments