Skip to content

Commit e01c0ca

Browse files
authored
Try #360:
2 parents 8990bfb + 8b870dc commit e01c0ca

File tree

8 files changed

+223
-96
lines changed

8 files changed

+223
-96
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.17.3"
3+
version = "0.17.4"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/DynamicPPL.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,4 +176,8 @@ include("test_utils.jl")
176176
@deprecate acclogp!(vi, logp) acclogp!!(vi, logp)
177177
@deprecate resetlogp!(vi) resetlogp!!(vi)
178178

179+
@deprecate settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) settrans!!(
180+
vi, trans, vn
181+
)
182+
179183
end # module

src/context_implementations.jl

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ end
5555
function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi)
5656
if haskey(context.vars, getsym(vn))
5757
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
58-
settrans!(vi, false, vn)
58+
settrans!!(vi, false, vn)
5959
end
6060
return tilde_assume(PriorContext(), right, vn, vi)
6161
end
@@ -64,15 +64,15 @@ function tilde_assume(
6464
)
6565
if haskey(context.vars, getsym(vn))
6666
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
67-
settrans!(vi, false, vn)
67+
settrans!!(vi, false, vn)
6868
end
6969
return tilde_assume(rng, PriorContext(), sampler, right, vn, vi)
7070
end
7171

7272
function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi)
7373
if haskey(context.vars, getsym(vn))
7474
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
75-
settrans!(vi, false, vn)
75+
settrans!!(vi, false, vn)
7676
end
7777
return tilde_assume(LikelihoodContext(), right, vn, vi)
7878
end
@@ -86,7 +86,7 @@ function tilde_assume(
8686
)
8787
if haskey(context.vars, getsym(vn))
8888
vi = setindex!!(vi, vectorize(right, get(context.vars, vn)), vn)
89-
settrans!(vi, false, vn)
89+
settrans!!(vi, false, vn)
9090
end
9191
return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi)
9292
end
@@ -194,7 +194,9 @@ end
194194

195195
# fallback without sampler
196196
function assume(dist::Distribution, vn::VarName, vi)
197-
r = vi[vn]
197+
# x = vi[vn]
198+
r_raw = getindex_raw(vi, vn)
199+
r = maybe_invlink(vi, vn, dist, r_raw)
198200
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
199201
end
200202

@@ -211,16 +213,23 @@ function assume(
211213
if sampler isa SampleFromUniform || is_flagged(vi, vn, "del")
212214
unset_flag!(vi, vn, "del")
213215
r = init(rng, dist, sampler)
214-
vi[vn] = vectorize(dist, r)
215-
settrans!(vi, false, vn)
216+
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r))
216217
setorder!(vi, vn, get_num_produce(vi))
217218
else
218-
r = vi[vn]
219+
# Otherwise we just extract it.
220+
# r = vi[vn]
221+
r_raw = getindex_raw(vi, vn)
222+
r = maybe_invlink(vi, vn, dist, r_raw)
219223
end
220224
else
221225
r = init(rng, dist, sampler)
222-
push!!(vi, vn, r, dist, sampler)
223-
settrans!(vi, false, vn)
226+
if istrans(vi)
227+
push!!(vi, vn, link(dist, r), dist, sampler)
228+
# By default `push!!` sets the transformed flag to `false`.
229+
settrans!!(vi, true, vn)
230+
else
231+
push!!(vi, vn, r, dist, sampler)
232+
end
224233
end
225234

226235
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi
@@ -286,7 +295,7 @@ function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left,
286295
var = get(context.vars, vn)
287296
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
288297
set_val!(vi, _vns, _right, _left)
289-
settrans!.(Ref(vi), false, _vns)
298+
settrans!!.(Ref(vi), false, _vns)
290299
dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi)
291300
else
292301
dot_tilde_assume(LikelihoodContext(), right, left, vn, vi)
@@ -305,7 +314,7 @@ function dot_tilde_assume(
305314
var = get(context.vars, vn)
306315
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
307316
set_val!(vi, _vns, _right, _left)
308-
settrans!.(Ref(vi), false, _vns)
317+
settrans!!.(Ref(vi), false, _vns)
309318
dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi)
310319
else
311320
dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi)
@@ -326,7 +335,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn,
326335
var = get(context.vars, vn)
327336
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
328337
set_val!(vi, _vns, _right, _left)
329-
settrans!.(Ref(vi), false, _vns)
338+
settrans!!.(Ref(vi), false, _vns)
330339
dot_tilde_assume(PriorContext(), _right, _left, _vns, vi)
331340
else
332341
dot_tilde_assume(PriorContext(), right, left, vn, vi)
@@ -345,7 +354,7 @@ function dot_tilde_assume(
345354
var = get(context.vars, vn)
346355
_right, _left, _vns = unwrap_right_left_vns(right, var, vn)
347356
set_val!(vi, _vns, _right, _left)
348-
settrans!.(Ref(vi), false, _vns)
357+
settrans!!.(Ref(vi), false, _vns)
349358
dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi)
350359
else
351360
dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi)
@@ -390,7 +399,9 @@ function dot_assume(
390399
# m .~ Normal()
391400
#
392401
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
393-
r = vi[vns]
402+
# r = vi[vns]
403+
r_raw = getindex_raw(vi, vns)
404+
r = maybe_invlink(vi, vn, dist, r_raw)
394405
lp = sum(zip(vns, eachcol(r))) do (vn, ri)
395406
return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn))
396407
end
@@ -423,7 +434,8 @@ function dot_assume(
423434
# m .~ Normal()
424435
#
425436
# in which case `var` will have `undef` elements, even if `m` is present in `vi`.
426-
r = reshape(vi[vec(vns)], size(vns))
437+
r_raw = getindex_raw(vi, vec(vns))
438+
r = reshape(maybe_invlink.(Ref(vi), vns, dists, r_raw), size(vns))
427439
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
428440
return r, lp, vi
429441
end
@@ -462,19 +474,24 @@ function get_and_set_val!(
462474
r = init(rng, dist, spl, n)
463475
for i in 1:n
464476
vn = vns[i]
465-
vi[vn] = vectorize(dist, r[:, i])
466-
settrans!(vi, false, vn)
477+
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[:, i]))
467478
setorder!(vi, vn, get_num_produce(vi))
468479
end
469480
else
470-
r = vi[vns]
481+
r_raw = getindex_raw(vi, vns)
482+
r = maybe_invlink(vi, vns, dist, r_raw)
471483
end
472484
else
473485
r = init(rng, dist, spl, n)
474486
for i in 1:n
475487
vn = vns[i]
476-
push!!(vi, vn, r[:, i], dist, spl)
477-
settrans!(vi, false, vn)
488+
if istrans(vi)
489+
push!!(vi, vn, maybe_link(vi, vn, dist, r[:, i]), dist, spl)
490+
# `push!!` sets the trans-flag to `false` by default.
491+
setttrans!!(vi, true, vn)
492+
else
493+
push!!(vi, vn, r[:, i], dist, spl)
494+
end
478495
end
479496
end
480497
return r
@@ -496,12 +513,13 @@ function get_and_set_val!(
496513
for i in eachindex(vns)
497514
vn = vns[i]
498515
dist = dists isa AbstractArray ? dists[i] : dists
499-
vi[vn] = vectorize(dist, r[i])
500-
settrans!(vi, false, vn)
516+
vi[vn] = vectorize(dist, maybe_link(vi, vn, dist, r[i]))
501517
setorder!(vi, vn, get_num_produce(vi))
502518
end
503519
else
504-
r = reshape(vi[vec(vns)], size(vns))
520+
# r = reshape(vi[vec(vns)], size(vns))
521+
r_raw = getindex_raw(vi, vec(vns))
522+
r = maybe_invlink.(Ref(vi), vns, dists, reshape(r_raw, size(vns)))
505523
end
506524
else
507525
f = (vn, dist) -> init(rng, dist, spl)
@@ -511,8 +529,13 @@ function get_and_set_val!(
511529
# 1. Figure out the broadcast size and use a `foreach`.
512530
# 2. Define an anonymous function which returns `nothing`, which
513531
# we then broadcast. This will allocate a vector of `nothing` though.
514-
push!!.(Ref(vi), vns, r, dists, Ref(spl))
515-
settrans!.(Ref(vi), false, vns)
532+
if istrans(vi)
533+
push!!.(Ref(vi), vns, link.(Ref(vi), vns, dists, r), dists, Ref(spl))
534+
# `push!!` sets the trans-flag to `false` by default.
535+
settrans!!.(Ref(vi), true, vns)
536+
else
537+
push!!.(Ref(vi), vns, r, dists, Ref(spl))
538+
end
516539
end
517540
return r
518541
end

0 commit comments

Comments
 (0)