Skip to content

Commit 3180da9

Browse files
KDr2yebai
authored andcommitted
Using new Selector type instead of gid (#720)
* Using new Selector type instead of gid * reset selector when resume from saved state * update comment and test assertion abount selector * use timestamp as gid, remove DEFAULT_GID * map vars in VarInfo to multiple selector * use a Tuple GID to ensure its uniqueness * Remove the default selector * Argument `parent` in constructor of Sampler * use Select.tag instead of Sampler.parent
1 parent 8f6aee6 commit 3180da9

File tree

20 files changed

+165
-162
lines changed

20 files changed

+165
-162
lines changed

src/Turing.jl

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using Markdown, Libtask, MacroTools
1515
@reexport using Distributions, MCMCChains, Libtask
1616
using Flux.Tracker: Tracker
1717

18-
import Base: ~, convert, promote_rule, rand, getindex, setindex!
18+
import Base: ~, ==, convert, hash, promote_rule, rand, getindex, setindex!
1919
import Distributions: sample
2020
import MCMCChains: AbstractChains, Chains
2121

@@ -62,6 +62,15 @@ end
6262
(model::Model)(args...; kwargs...) = model.f(args..., model; kwargs...)
6363
function runmodel! end
6464

65+
struct Selector
66+
gid :: UInt64
67+
tag :: Ref{Symbol} # :default, :invalid, :Gibbs, :HMC, etc.
68+
end
69+
Selector() = Selector(time_ns(), Ref(:default))
70+
Selector(tag::Symbol) = Selector(time_ns(), Ref(tag))
71+
hash(s::Selector) = hash(s.gid)
72+
==(s1::Selector, s2::Selector) = s1.gid == s2.gid
73+
6574
abstract type AbstractSampler end
6675

6776
"""
@@ -83,10 +92,12 @@ Turing translates models to chunks that call the modelling functions at specifie
8392
then include that file at the end of this one.
8493
"""
8594
mutable struct Sampler{T} <: AbstractSampler
86-
alg :: T
87-
info :: Dict{Symbol, Any} # sampler infomation
95+
alg :: T
96+
info :: Dict{Symbol, Any} # sampler infomation
97+
selector :: Selector
8898
end
89-
Sampler(alg, model) = Sampler(alg)
99+
Sampler(alg, model::Model) = Sampler(alg)
100+
Sampler(alg, info::Dict{Symbol, Any}) = Sampler(alg, info, Selector())
90101

91102
include("utilities/Utilities.jl")
92103
using .Utilities

src/core/VarReplay.jl

Lines changed: 31 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module VarReplay
22

33
using ...Turing: Turing, CACHERESET, CACHEIDCS, CACHERANGES, Model,
4-
AbstractSampler, Sampler, SampleFromPrior
4+
AbstractSampler, Sampler, SampleFromPrior,
5+
Selector
56
using ...Utilities: vectorize, reconstruct, reconstruct!
67
using Bijectors: SimplexDistribution
78
using Distributions
@@ -70,7 +71,7 @@ mutable struct VarInfo
7071
vals :: Vector{Real}
7172
rvs :: Dict{Union{VarName,Vector{VarName}},Any}
7273
dists :: Vector{Distributions.Distribution}
73-
gids :: Vector{Int}
74+
gids :: Vector{Set{Selector}}
7475
logp :: Real
7576
pred :: Dict{Symbol,Any}
7677
num_produce :: Int # num of produce calls from trace, each produce corresponds to an observe.
@@ -139,8 +140,7 @@ getsym(vi::VarInfo, vn::VarName) = vi.vns[getidx(vi, vn)].sym
139140
getdist(vi::VarInfo, vn::VarName) = vi.dists[getidx(vi, vn)]
140141

141142
getgid(vi::VarInfo, vn::VarName) = vi.gids[getidx(vi, vn)]
142-
143-
setgid!(vi::VarInfo, gid::Int, vn::VarName) = vi.gids[getidx(vi, vn)] = gid
143+
setgid!(vi::VarInfo, gid::Selector, vn::VarName) = push!(vi.gids[getidx(vi, vn)], gid)
144144

145145
istrans(vi::VarInfo, vn::VarName) = is_flagged(vi, vn, "trans")
146146
settrans!(vi::VarInfo, trans::Bool, vn::VarName) = trans ? set_flag!(vi, vn, "trans") : unset_flag!(vi, vn, "trans")
@@ -207,6 +207,9 @@ end
207207
Base.getindex(vi::VarInfo, vview::VarView) = copy(getval(vi, vview))
208208
Base.setindex!(vi::VarInfo, val::Any, vview::VarView) = setval!(vi, val, vview)
209209

210+
Base.getindex(vi::VarInfo, s::Selector) = copy(getval(vi, getranges(vi, s)))
211+
Base.setindex!(vi::VarInfo, val::Any, s::Selector) = setval!(vi, val, getranges(vi, s))
212+
210213
Base.getindex(vi::VarInfo, spl::Sampler) = copy(getval(vi, getranges(vi, spl)))
211214
Base.setindex!(vi::VarInfo, val::Any, spl::Sampler) = setval!(vi, val, getranges(vi, spl))
212215

@@ -237,7 +240,9 @@ function Base.show(io::IO, vi::VarInfo)
237240
end
238241

239242
# Add a new entry to VarInfo
240-
function push!(vi::VarInfo, vn::VarName, r::Any, dist::Distributions.Distribution, gid::Int)
243+
push!(vi::VarInfo, vn::VarName, r::Any, dist::Distributions.Distribution) = push!(vi, vn, r, dist, Set{Selector}([]))
244+
push!(vi::VarInfo, vn::VarName, r::Any, dist::Distributions.Distribution, gid::Selector) = push!(vi, vn, r, dist, Set([gid]))
245+
function push!(vi::VarInfo, vn::VarName, r::Any, dist::Distributions.Distribution, gidset::Set{Selector})
241246

242247
@assert ~(vn in vns(vi)) "[push!] attempt to add an exisitng variable $(sym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gid"
243248

@@ -249,7 +254,7 @@ function push!(vi::VarInfo, vn::VarName, r::Any, dist::Distributions.Distributio
249254
push!(vi.ranges, l+1:l+n)
250255
append!(vi.vals, val)
251256
push!(vi.dists, dist)
252-
push!(vi.gids, gid)
257+
push!(vi.gids, gidset)
253258
push!(vi.orders, vi.num_produce)
254259
push!(vi.flags["del"], false)
255260
push!(vi.flags["trans"], false)
@@ -296,9 +301,10 @@ end
296301
# vi.logp = vi.logp[end:end]
297302
# end
298303

299-
# Get all indices of variables belonging to gid or 0
300-
getidcs(vi::VarInfo) = getidcs(vi, nothing)
301-
getidcs(vi::VarInfo, ::SampleFromPrior) = filter(i -> vi.gids[i] == 0, 1:length(vi.gids))
304+
# Get all indices of variables belonging to SampleFromPrior:
305+
# if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to
306+
# the SampleFromPrior sampler
307+
getidcs(vi::VarInfo, ::SampleFromPrior) = filter(i -> isempty(vi.gids[i]) , 1:length(vi.gids))
302308
function getidcs(vi::VarInfo, spl::Sampler)
303309
# NOTE: 0b00 is the sanity flag for
304310
# |\____ getidcs (mask = 0b10)
@@ -309,12 +315,18 @@ function getidcs(vi::VarInfo, spl::Sampler)
309315
else
310316
spl.info[:cache_updated] = spl.info[:cache_updated] | CACHEIDCS
311317
spl.info[:idcs] = filter(i ->
312-
(vi.gids[i] == spl.alg.gid || vi.gids[i] == 0) && (isempty(spl.alg.space) || is_inside(vi.vns[i], spl.alg.space)),
318+
(spl.selector in vi.gids[i] || isempty(vi.gids[i])) && (isempty(spl.alg.space) || is_inside(vi.vns[i], spl.alg.space)),
313319
1:length(vi.gids)
314320
)
315321
end
316322
end
317323

324+
# Get all indices of variables belonging to a given selector
325+
function getidcs(vi::VarInfo, s::Selector, space::Set=Set())
326+
filter(i -> (s in vi.gids[i] || isempty(vi.gids[i])) && (isempty(space) || is_inside(vi.vns[i], space)),
327+
1:length(vi.gids))
328+
end
329+
318330
function is_inside(vn::VarName, space::Set)::Bool
319331
if vn.sym in space
320332
return true
@@ -327,15 +339,13 @@ function is_inside(vn::VarName, space::Set)::Bool
327339
end
328340
end
329341

330-
# Get all values of variables belonging to gid or 0
331-
getvals(vi::VarInfo) = getvals(vi, nothing)
342+
# Get all values of variables belonging to spl.selector
332343
getvals(vi::VarInfo, spl::AbstractSampler) = view(vi.vals, getidcs(vi, spl))
333344

334-
# Get all vns of variables belonging to gid or 0
335-
getvns(vi::VarInfo) = getvns(vi, nothing)
345+
# Get all vns of variables belonging to spl.selector
336346
getvns(vi::VarInfo, spl::AbstractSampler) = view(vi.vns, getidcs(vi, spl))
337347

338-
# Get all vns of variables belonging to gid or 0
348+
# Get all vns of variables belonging to spl.selector
339349
function getranges(vi::VarInfo, spl::Sampler)
340350
if ~haskey(spl.info, :cache_updated) spl.info[:cache_updated] = CACHERESET end
341351
if haskey(spl.info, :ranges) && (spl.info[:cache_updated] & CACHERANGES) > 0
@@ -346,6 +356,10 @@ function getranges(vi::VarInfo, spl::Sampler)
346356
end
347357
end
348358

359+
function getranges(vi::VarInfo, s::Selector)
360+
union(map(i -> vi.ranges[i], getidcs(vi, s))...)
361+
end
362+
349363
# NOTE: this function below is not used anywhere but test files.
350364
# we can safely remove it if we want.
351365
function getretain(vi::VarInfo, spl::AbstractSampler)
@@ -381,8 +395,8 @@ function set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler)
381395
end
382396

383397
function updategid!(vi::VarInfo, vn::VarName, spl::Sampler)
384-
if ~isempty(spl.alg.space) && getgid(vi, vn) == 0 && getsym(vi, vn) in spl.alg.space
385-
setgid!(vi, spl.alg.gid, vn)
398+
if ~isempty(spl.alg.space) && isempty(getgid(vi, vn)) && getsym(vi, vn) in spl.alg.space
399+
setgid!(vi, spl.selector, vn)
386400
end
387401
end
388402

src/inference/Inference.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ function assume(spl::A,
115115
r = vi[vn]
116116
else
117117
r = isa(spl, SampleFromUniform) ? init(dist) : rand(dist)
118-
push!(vi, vn, r, dist, 0)
118+
push!(vi, vn, r, dist)
119119
end
120120
# NOTE: The importance weight is not correctly computed here because
121121
# r is genereated from some uniform distribution which is different from the prior
@@ -144,13 +144,13 @@ function assume(spl::A,
144144

145145
if isa(dist, UnivariateDistribution) || isa(dist, MatrixDistribution)
146146
for i = 1:n
147-
push!(vi, vns[i], rs[i], dist, 0)
147+
push!(vi, vns[i], rs[i], dist)
148148
end
149149
@assert size(var) == size(rs) "Turing.assume: variable and random number dimension unmatched"
150150
var = rs
151151
elseif isa(dist, MultivariateDistribution)
152152
for i = 1:n
153-
push!(vi, vns[i], rs[:,i], dist, 0)
153+
push!(vi, vns[i], rs[:,i], dist)
154154
end
155155
if isa(var, Vector)
156156
@assert length(var) == size(rs)[2] "Turing.assume: variable and random number dimension unmatched"

src/inference/dynamichmc.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
struct DynamicNUTS{AD, T} <: Hamiltonian{AD}
22
n_iters :: Integer # number of samples
33
space :: Set{T} # sampling space, emtpy means all
4-
gid :: Integer # group ID
54
end
65

76
"""
@@ -30,7 +29,7 @@ chn = sample(gdemo(1.5, 2.0), DynamicNUTS(2000))
3029
DynamicNUTS(args...) = DynamicNUTS{ADBackend()}(args...)
3130
function DynamicNUTS{AD}(n_iters::Integer, space...) where AD
3231
_space = isa(space, Symbol) ? Set([space]) : Set(space)
33-
DynamicNUTS{AD, eltype(_space)}(n_iters, _space, 0)
32+
DynamicNUTS{AD, eltype(_space)}(n_iters, _space)
3433
end
3534

3635
function Sampler(alg::DynamicNUTS{T}) where T <: Hamiltonian
@@ -53,7 +52,7 @@ function sample(model::Model,
5352
vi = VarInfo()
5453
model(vi, SampleFromUniform())
5554

56-
if spl.alg.gid == 0
55+
if spl.selector.tag[] == :default
5756
link!(vi, spl)
5857
runmodel!(model, vi, spl)
5958
end

src/inference/gibbs.jl

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,24 @@ mutable struct Gibbs{A} <: InferenceAlgorithm
2828
n_iters :: Int # number of Gibbs iterations
2929
algs :: A # component sampling algorithms
3030
thin :: Bool # if thinning to output only after a whole Gibbs sweep
31-
gid :: Int
3231
end
33-
Gibbs(n_iters::Int, algs...; thin=true) = Gibbs(n_iters, algs, thin, 0)
34-
Gibbs(alg::Gibbs, new_gid) = Gibbs(alg.n_iters, alg.algs, alg.thin, new_gid)
32+
Gibbs(n_iters::Int, algs...; thin=true) = Gibbs(n_iters, algs, thin)
3533

3634
const GibbsComponent = Union{Hamiltonian,MH,PG}
3735

3836
function Sampler(alg::Gibbs, model::Model)
37+
info = Dict{Symbol, Any}()
38+
spl = Sampler(alg, info)
39+
3940
n_samplers = length(alg.algs)
4041
samplers = Array{Sampler}(undef, n_samplers)
41-
4242
space = Set{Symbol}()
4343

4444
for i in 1:n_samplers
4545
sub_alg = alg.algs[i]
4646
if isa(sub_alg, GibbsComponent)
47-
samplers[i] = Sampler(typeof(sub_alg)(sub_alg, i), model)
47+
samplers[i] = Sampler(sub_alg, model)
48+
samplers[i].selector.tag[] = Symbol(typeof(sub_alg))
4849
else
4950
@error("[Gibbs] unsupport base sampling algorithm $alg")
5051
end
@@ -58,10 +59,9 @@ function Sampler(alg::Gibbs, model::Model)
5859
@warn("[Gibbs] extra parameters specified by samplers don't exist in model: $(setdiff(space, Set(get_pvars(model))))")
5960
end
6061

61-
info = Dict{Symbol, Any}()
6262
info[:samplers] = samplers
6363

64-
Sampler(alg, info)
64+
return spl
6565
end
6666

6767
function sample(
@@ -73,8 +73,17 @@ function sample(
7373
)
7474

7575
# Init the (master) Gibbs sampler
76-
spl = reuse_spl_n > 0 ? resume_from.info[:spl] : Sampler(alg, model)
77-
76+
if reuse_spl_n > 0
77+
spl = resume_from.info[:spl]
78+
else
79+
spl = Sampler(alg, model)
80+
if resume_from != nothing
81+
spl.selector = resume_from.info[:spl].selector
82+
for i in 1:length(spl.info[:samplers])
83+
spl.info[:samplers][i].selector = resume_from.info[:spl].info[:samplers][i].selector
84+
end
85+
end
86+
end
7887
@assert typeof(spl.alg) == typeof(alg) "[Turing] alg type mismatch; please use resume() to re-use spl"
7988

8089
# Initialize samples

src/inference/hmc.jl

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -48,25 +48,18 @@ mutable struct HMC{AD, T} <: StaticHamiltonian{AD}
4848
epsilon :: Float64 # leapfrog step size
4949
tau :: Int # leapfrog step number
5050
space :: Set{T} # sampling space, emtpy means all
51-
gid :: Int # group ID
5251
end
5352
HMC(args...) = HMC{ADBackend()}(args...)
5453
function HMC{AD}(epsilon::Float64, tau::Int, space...) where AD
5554
_space = isa(space, Symbol) ? Set([space]) : Set(space)
56-
return HMC{AD, eltype(_space)}(1, epsilon, tau, _space, 0)
55+
return HMC{AD, eltype(_space)}(1, epsilon, tau, _space)
5756
end
5857
function HMC{AD}(n_iters::Int, epsilon::Float64, tau::Int) where AD
59-
return HMC{AD, Any}(n_iters, epsilon, tau, Set(), 0)
58+
return HMC{AD, Any}(n_iters, epsilon, tau, Set())
6059
end
6160
function HMC{AD}(n_iters::Int, epsilon::Float64, tau::Int, space...) where AD
6261
_space = isa(space, Symbol) ? Set([space]) : Set(space)
63-
return HMC{AD, eltype(_space)}(n_iters, epsilon, tau, _space, 0)
64-
end
65-
function HMC{AD1}(alg::HMC{AD2, T}, new_gid::Int) where {AD1, AD2, T}
66-
return HMC{AD1, T}(alg.n_iters, alg.epsilon, alg.tau, alg.space, new_gid)
67-
end
68-
function HMC{AD, T}(alg::HMC, new_gid::Int) where {AD, T}
69-
return HMC{AD, T}(alg.n_iters, alg.epsilon, alg.tau, alg.space, new_gid)
62+
return HMC{AD, eltype(_space)}(n_iters, epsilon, tau, _space)
7063
end
7164

7265
function hmc_step(θ, lj, lj_func, grad_func, H_func, ϵ, alg::HMC, momentum_sampler::Function;
@@ -108,6 +101,9 @@ function sample(model::Model, alg::Hamiltonian;
108101
spl = reuse_spl_n > 0 ?
109102
resume_from.info[:spl] :
110103
Sampler(alg, adapt_conf)
104+
if resume_from != nothing
105+
spl.selector = resume_from.info[:spl].selector
106+
end
111107

112108
@assert isa(spl.alg, Hamiltonian) "[Turing] alg type mismatch; please use resume() to re-use spl"
113109

@@ -137,7 +133,7 @@ function sample(model::Model, alg::Hamiltonian;
137133
deepcopy(resume_from.info[:vi])
138134
end
139135

140-
if spl.alg.gid == 0
136+
if spl.selector.tag[] == :default
141137
link!(vi, spl)
142138
runmodel!(model, vi, spl)
143139
end
@@ -189,7 +185,7 @@ function sample(model::Model, alg::Hamiltonian;
189185
c = Chain(0.0, samples) # wrap the result by Chain
190186
if save_state # save state
191187
# Convert vi back to X if vi is required to be saved
192-
if spl.alg.gid == 0 invlink!(vi, spl) end
188+
spl.selector.tag[] == :default && invlink!(vi, spl)
193189
c = save(c, spl, model, vi, samples)
194190
end
195191
return c
@@ -201,11 +197,11 @@ function step(model, spl::Sampler{<:StaticHamiltonian}, vi::VarInfo, is_first::V
201197
end
202198

203199
function step(model, spl::Sampler{<:AdaptiveHamiltonian}, vi::VarInfo, is_first::Val{true})
204-
spl.alg.gid != 0 && link!(vi, spl)
200+
spl.selector.tag[] != :default && link!(vi, spl)
205201
epsilon = find_good_eps(model, spl, vi) # heuristically find good initial epsilon
206202
dim = length(vi[spl])
207203
spl.info[:wum] = ThreePhaseAdapter(spl, epsilon, dim)
208-
spl.alg.gid != 0 && invlink!(vi, spl)
204+
spl.selector.tag[] != :default && invlink!(vi, spl)
209205
return vi, true
210206
end
211207

@@ -219,7 +215,7 @@ function step(model, spl::Sampler{<:Hamiltonian}, vi::VarInfo, is_first::Val{fal
219215
spl.info[:eval_num] = 0
220216

221217
Turing.DEBUG && @debug "X-> R..."
222-
if spl.alg.gid != 0
218+
if spl.selector.tag[] != :default
223219
link!(vi, spl)
224220
runmodel!(model, vi, spl)
225221
end
@@ -245,7 +241,7 @@ function step(model, spl::Sampler{<:Hamiltonian}, vi::VarInfo, is_first::Val{fal
245241
setlogp!(vi, lj)
246242
end
247243

248-
if PROGRESS[] && spl.alg.gid == 0
244+
if PROGRESS[] && spl.selector.tag[] == :default
249245
std_str = string(spl.info[:wum].pc)
250246
std_str = length(std_str) >= 32 ? std_str[1:30]*"..." : std_str
251247
haskey(spl.info, :progress) && ProgressMeter.update!(
@@ -260,7 +256,7 @@ function step(model, spl::Sampler{<:Hamiltonian}, vi::VarInfo, is_first::Val{fal
260256
end
261257

262258
Turing.DEBUG && @debug "R -> X..."
263-
spl.alg.gid != 0 && invlink!(vi, spl)
259+
spl.selector.tag[] != :default && invlink!(vi, spl)
264260

265261
return vi, is_accept
266262
end

0 commit comments

Comments
 (0)