11module VarReplay
22
33using ... Turing: Turing, CACHERESET, CACHEIDCS, CACHERANGES, Model,
4- AbstractSampler, Sampler, SampleFromPrior
4+ AbstractSampler, Sampler, SampleFromPrior,
5+ Selector
56using ... Utilities: vectorize, reconstruct, reconstruct!
67using Bijectors: SimplexDistribution
78using Distributions
@@ -70,7 +71,7 @@ mutable struct VarInfo
7071 vals :: Vector{Real}
7172 rvs :: Dict{Union{VarName,Vector{VarName}},Any}
7273 dists :: Vector{Distributions.Distribution}
73- gids :: Vector{Int }
74+ gids :: Vector{Set{Selector} }
7475 logp :: Real
7576 pred :: Dict{Symbol,Any}
7677 num_produce :: Int # num of produce calls from trace, each produce corresponds to an observe.
@@ -139,8 +140,7 @@ getsym(vi::VarInfo, vn::VarName) = vi.vns[getidx(vi, vn)].sym
139140getdist (vi:: VarInfo , vn:: VarName ) = vi. dists[getidx (vi, vn)]
140141
141142getgid (vi:: VarInfo , vn:: VarName ) = vi. gids[getidx (vi, vn)]
142-
143- setgid! (vi:: VarInfo , gid:: Int , vn:: VarName ) = vi. gids[getidx (vi, vn)] = gid
143+ setgid! (vi:: VarInfo , gid:: Selector , vn:: VarName ) = push! (vi. gids[getidx (vi, vn)], gid)
144144
145145istrans (vi:: VarInfo , vn:: VarName ) = is_flagged (vi, vn, " trans" )
146146settrans! (vi:: VarInfo , trans:: Bool , vn:: VarName ) = trans ? set_flag! (vi, vn, " trans" ) : unset_flag! (vi, vn, " trans" )
207207Base. getindex (vi:: VarInfo , vview:: VarView ) = copy (getval (vi, vview))
208208Base. setindex! (vi:: VarInfo , val:: Any , vview:: VarView ) = setval! (vi, val, vview)
209209
210+ Base. getindex (vi:: VarInfo , s:: Selector ) = copy (getval (vi, getranges (vi, s)))
211+ Base. setindex! (vi:: VarInfo , val:: Any , s:: Selector ) = setval! (vi, val, getranges (vi, s))
212+
210213Base. getindex (vi:: VarInfo , spl:: Sampler ) = copy (getval (vi, getranges (vi, spl)))
211214Base. setindex! (vi:: VarInfo , val:: Any , spl:: Sampler ) = setval! (vi, val, getranges (vi, spl))
212215
@@ -237,7 +240,9 @@ function Base.show(io::IO, vi::VarInfo)
237240end
238241
239242# Add a new entry to VarInfo
240- function push! (vi:: VarInfo , vn:: VarName , r:: Any , dist:: Distributions.Distribution , gid:: Int )
243+ push! (vi:: VarInfo , vn:: VarName , r:: Any , dist:: Distributions.Distribution ) = push! (vi, vn, r, dist, Set {Selector} ([]))
244+ push! (vi:: VarInfo , vn:: VarName , r:: Any , dist:: Distributions.Distribution , gid:: Selector ) = push! (vi, vn, r, dist, Set ([gid]))
245+ function push! (vi:: VarInfo , vn:: VarName , r:: Any , dist:: Distributions.Distribution , gidset:: Set{Selector} )
241246
242247 @assert ~ (vn in vns (vi)) " [push!] attempt to add an exisitng variable $(sym (vn)) ($(vn) ) to VarInfo (keys=$(keys (vi)) ) with dist=$dist , gid=$gid "
243248
@@ -249,7 +254,7 @@ function push!(vi::VarInfo, vn::VarName, r::Any, dist::Distributions.Distributio
249254 push! (vi. ranges, l+ 1 : l+ n)
250255 append! (vi. vals, val)
251256 push! (vi. dists, dist)
252- push! (vi. gids, gid )
257+ push! (vi. gids, gidset )
253258 push! (vi. orders, vi. num_produce)
254259 push! (vi. flags[" del" ], false )
255260 push! (vi. flags[" trans" ], false )
296301# vi.logp = vi.logp[end:end]
297302# end
298303
299- # Get all indices of variables belonging to gid or 0
300- getidcs (vi:: VarInfo ) = getidcs (vi, nothing )
301- getidcs (vi:: VarInfo , :: SampleFromPrior ) = filter (i -> vi. gids[i] == 0 , 1 : length (vi. gids))
304+ # Get all indices of variables belonging to SampleFromPrior:
305+ # if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to
306+ # the SampleFromPrior sampler
307+ getidcs (vi:: VarInfo , :: SampleFromPrior ) = filter (i -> isempty (vi. gids[i]) , 1 : length (vi. gids))
302308function getidcs (vi:: VarInfo , spl:: Sampler )
303309 # NOTE: 0b00 is the sanity flag for
304310 # |\____ getidcs (mask = 0b10)
@@ -309,12 +315,18 @@ function getidcs(vi::VarInfo, spl::Sampler)
309315 else
310316 spl. info[:cache_updated ] = spl. info[:cache_updated ] | CACHEIDCS
311317 spl. info[:idcs ] = filter (i ->
312- (vi. gids[i] == spl . alg . gid || vi. gids[i] == 0 ) && (isempty (spl. alg. space) || is_inside (vi. vns[i], spl. alg. space)),
318+ (spl . selector in vi. gids[i] || isempty ( vi. gids[i]) ) && (isempty (spl. alg. space) || is_inside (vi. vns[i], spl. alg. space)),
313319 1 : length (vi. gids)
314320 )
315321 end
316322end
317323
324+ # Get all indices of variables belonging to a given selector
325+ function getidcs (vi:: VarInfo , s:: Selector , space:: Set = Set ())
326+ filter (i -> (s in vi. gids[i] || isempty (vi. gids[i])) && (isempty (space) || is_inside (vi. vns[i], space)),
327+ 1 : length (vi. gids))
328+ end
329+
318330function is_inside (vn:: VarName , space:: Set ):: Bool
319331 if vn. sym in space
320332 return true
@@ -327,15 +339,13 @@ function is_inside(vn::VarName, space::Set)::Bool
327339 end
328340end
329341
330- # Get all values of variables belonging to gid or 0
331- getvals (vi:: VarInfo ) = getvals (vi, nothing )
342+ # Get all values of variables belonging to spl.selector
332343getvals (vi:: VarInfo , spl:: AbstractSampler ) = view (vi. vals, getidcs (vi, spl))
333344
334- # Get all vns of variables belonging to gid or 0
335- getvns (vi:: VarInfo ) = getvns (vi, nothing )
345+ # Get all vns of variables belonging to spl.selector
336346getvns (vi:: VarInfo , spl:: AbstractSampler ) = view (vi. vns, getidcs (vi, spl))
337347
338- # Get all vns of variables belonging to gid or 0
348+ # Get all vns of variables belonging to spl.selector
339349function getranges (vi:: VarInfo , spl:: Sampler )
340350 if ~ haskey (spl. info, :cache_updated ) spl. info[:cache_updated ] = CACHERESET end
341351 if haskey (spl. info, :ranges ) && (spl. info[:cache_updated ] & CACHERANGES) > 0
@@ -346,6 +356,10 @@ function getranges(vi::VarInfo, spl::Sampler)
346356 end
347357end
348358
359+ function getranges (vi:: VarInfo , s:: Selector )
360+ union (map (i -> vi. ranges[i], getidcs (vi, s))... )
361+ end
362+
349363# NOTE: this function below is not used anywhere but test files.
350364# we can safely remove it if we want.
351365function getretain (vi:: VarInfo , spl:: AbstractSampler )
@@ -381,8 +395,8 @@ function set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler)
381395end
382396
383397function updategid! (vi:: VarInfo , vn:: VarName , spl:: Sampler )
384- if ~ isempty (spl. alg. space) && getgid (vi, vn) == 0 && getsym (vi, vn) in spl. alg. space
385- setgid! (vi, spl. alg . gid , vn)
398+ if ~ isempty (spl. alg. space) && isempty ( getgid (vi, vn)) && getsym (vi, vn) in spl. alg. space
399+ setgid! (vi, spl. selector , vn)
386400 end
387401end
388402
0 commit comments