Skip to content

Commit e585eea

Browse files
KDr2yebai
authored andcommitted
change Selector.tag from Ref{Symbol} to Symbol (#726)
1 parent 3180da9 commit e585eea

File tree

14 files changed

+60
-68
lines changed

14 files changed

+60
-68
lines changed

src/Turing.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,10 @@ function runmodel! end
6464

6565
struct Selector
6666
gid :: UInt64
67-
tag :: Ref{Symbol} # :default, :invalid, :Gibbs, :HMC, etc.
67+
tag :: Symbol # :default, :invalid, :Gibbs, :HMC, etc.
6868
end
69-
Selector() = Selector(time_ns(), Ref(:default))
70-
Selector(tag::Symbol) = Selector(time_ns(), Ref(tag))
69+
Selector() = Selector(time_ns(), :default)
70+
Selector(tag::Symbol) = Selector(time_ns(), tag)
7171
hash(s::Selector) = hash(s.gid)
7272
==(s1::Selector, s2::Selector) = s1.gid == s2.gid
7373

@@ -96,8 +96,9 @@ mutable struct Sampler{T} <: AbstractSampler
9696
info :: Dict{Symbol, Any} # sampler infomation
9797
selector :: Selector
9898
end
99-
Sampler(alg, model::Model) = Sampler(alg)
100-
Sampler(alg, info::Dict{Symbol, Any}) = Sampler(alg, info, Selector())
99+
Sampler(alg) = Sampler(alg, Selector())
100+
Sampler(alg, model::Model) = Sampler(alg, model, Selector())
101+
Sampler(alg, model::Model, s::Selector) = Sampler(alg, s)
101102

102103
include("utilities/Utilities.jl")
103104
using .Utilities

src/core/VarReplay.jl

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -207,11 +207,8 @@ end
207207
Base.getindex(vi::VarInfo, vview::VarView) = copy(getval(vi, vview))
208208
Base.setindex!(vi::VarInfo, val::Any, vview::VarView) = setval!(vi, val, vview)
209209

210-
Base.getindex(vi::VarInfo, s::Selector) = copy(getval(vi, getranges(vi, s)))
211-
Base.setindex!(vi::VarInfo, val::Any, s::Selector) = setval!(vi, val, getranges(vi, s))
212-
213-
Base.getindex(vi::VarInfo, spl::Sampler) = copy(getval(vi, getranges(vi, spl)))
214-
Base.setindex!(vi::VarInfo, val::Any, spl::Sampler) = setval!(vi, val, getranges(vi, spl))
210+
Base.getindex(vi::VarInfo, s::Union{Selector, Sampler}) = copy(getval(vi, getranges(vi, s)))
211+
Base.setindex!(vi::VarInfo, val::Any, s::Union{Selector, Sampler}) = setval!(vi, val, getranges(vi, s))
215212

216213
Base.getindex(vi::VarInfo, ::SampleFromPrior) = copy(getall(vi))
217214
Base.setindex!(vi::VarInfo, val::Any, ::SampleFromPrior) = setall!(vi, val)
@@ -314,10 +311,7 @@ function getidcs(vi::VarInfo, spl::Sampler)
314311
spl.info[:idcs]
315312
else
316313
spl.info[:cache_updated] = spl.info[:cache_updated] | CACHEIDCS
317-
spl.info[:idcs] = filter(i ->
318-
(spl.selector in vi.gids[i] || isempty(vi.gids[i])) && (isempty(spl.alg.space) || is_inside(vi.vns[i], spl.alg.space)),
319-
1:length(vi.gids)
320-
)
314+
spl.info[:idcs] = getidcs(vi, spl.selector, spl.alg.space)
321315
end
322316
end
323317

@@ -352,12 +346,12 @@ function getranges(vi::VarInfo, spl::Sampler)
352346
spl.info[:ranges]
353347
else
354348
spl.info[:cache_updated] = spl.info[:cache_updated] | CACHERANGES
355-
spl.info[:ranges] = union(map(i -> vi.ranges[i], getidcs(vi, spl))...)
349+
spl.info[:ranges] = getranges(vi, spl.selector, spl.alg.space)
356350
end
357351
end
358352

359-
function getranges(vi::VarInfo, s::Selector)
360-
union(map(i -> vi.ranges[i], getidcs(vi, s))...)
353+
function getranges(vi::VarInfo, s::Selector, space::Set=Set())
354+
union(map(i -> vi.ranges[i], getidcs(vi, s, space))...)
361355
end
362356

363357
# NOTE: this function below is not used anywhere but test files.

src/inference/Inference.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ using Distributions, Libtask, Bijectors
55
using ProgressMeter, LinearAlgebra
66
using ..Turing: PROGRESS, CACHERESET, AbstractSampler
77
using ..Turing: Model, runmodel!, get_pvars, get_dvars,
8-
Sampler, SampleFromPrior, SampleFromUniform
8+
Sampler, SampleFromPrior, SampleFromUniform,
9+
Selector
910
using ..Turing: in_pvars, in_dvars, Turing
1011
using StatsFuns: logsumexp
1112

src/inference/dynamichmc.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ function DynamicNUTS{AD}(n_iters::Integer, space...) where AD
3232
DynamicNUTS{AD, eltype(_space)}(n_iters, _space)
3333
end
3434

35-
function Sampler(alg::DynamicNUTS{T}) where T <: Hamiltonian
36-
return Sampler(alg, Dict{Symbol,Any}())
35+
function Sampler(alg::DynamicNUTS{T}, s::Selector) where T <: Hamiltonian
36+
return Sampler(alg, Dict{Symbol,Any}(), s)
3737
end
3838

3939
function sample(model::Model,
@@ -52,7 +52,7 @@ function sample(model::Model,
5252
vi = VarInfo()
5353
model(vi, SampleFromUniform())
5454

55-
if spl.selector.tag[] == :default
55+
if spl.selector.tag == :default
5656
link!(vi, spl)
5757
runmodel!(model, vi, spl)
5858
end

src/inference/gibbs.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ Gibbs(n_iters::Int, algs...; thin=true) = Gibbs(n_iters, algs, thin)
3333

3434
const GibbsComponent = Union{Hamiltonian,MH,PG}
3535

36-
function Sampler(alg::Gibbs, model::Model)
36+
function Sampler(alg::Gibbs, model::Model, s::Selector)
3737
info = Dict{Symbol, Any}()
38-
spl = Sampler(alg, info)
38+
spl = Sampler(alg, info, s)
3939

4040
n_samplers = length(alg.algs)
4141
samplers = Array{Sampler}(undef, n_samplers)
@@ -44,8 +44,7 @@ function Sampler(alg::Gibbs, model::Model)
4444
for i in 1:n_samplers
4545
sub_alg = alg.algs[i]
4646
if isa(sub_alg, GibbsComponent)
47-
samplers[i] = Sampler(sub_alg, model)
48-
samplers[i].selector.tag[] = Symbol(typeof(sub_alg))
47+
samplers[i] = Sampler(sub_alg, model, Selector(Symbol(typeof(sub_alg))))
4948
else
5049
@error("[Gibbs] unsupport base sampling algorithm $alg")
5150
end

src/inference/hmc.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,12 @@ end
7474
DEFAULT_ADAPT_CONF_TYPE = Nothing
7575
STAN_DEFAULT_ADAPT_CONF = nothing
7676

77-
Sampler(alg::Hamiltonian) = Sampler(alg, nothing)
78-
function Sampler(alg::Hamiltonian, adapt_conf::Nothing)
79-
return _sampler(alg::Hamiltonian, adapt_conf)
77+
Sampler(alg::Hamiltonian, s::Selector) = Sampler(alg, nothing, s)
78+
Sampler(alg::Hamiltonian, adapt_conf::Nothing) = Sampler(alg, adapt_conf, Selector())
79+
function Sampler(alg::Hamiltonian, adapt_conf::Nothing, s::Selector)
80+
return _sampler(alg::Hamiltonian, adapt_conf, s)
8081
end
81-
function _sampler(alg::Hamiltonian, adapt_conf)
82+
function _sampler(alg::Hamiltonian, adapt_conf, s::Selector)
8283
info=Dict{Symbol, Any}()
8384

8485
# For state infomation
@@ -88,7 +89,7 @@ function _sampler(alg::Hamiltonian, adapt_conf)
8889
# Adapt configuration
8990
info[:adapt_conf] = adapt_conf
9091

91-
Sampler(alg, info)
92+
Sampler(alg, info, s)
9293
end
9394

9495
function sample(model::Model, alg::Hamiltonian;
@@ -133,7 +134,7 @@ function sample(model::Model, alg::Hamiltonian;
133134
deepcopy(resume_from.info[:vi])
134135
end
135136

136-
if spl.selector.tag[] == :default
137+
if spl.selector.tag == :default
137138
link!(vi, spl)
138139
runmodel!(model, vi, spl)
139140
end
@@ -185,7 +186,7 @@ function sample(model::Model, alg::Hamiltonian;
185186
c = Chain(0.0, samples) # wrap the result by Chain
186187
if save_state # save state
187188
# Convert vi back to X if vi is required to be saved
188-
spl.selector.tag[] == :default && invlink!(vi, spl)
189+
spl.selector.tag == :default && invlink!(vi, spl)
189190
c = save(c, spl, model, vi, samples)
190191
end
191192
return c
@@ -197,11 +198,11 @@ function step(model, spl::Sampler{<:StaticHamiltonian}, vi::VarInfo, is_first::V
197198
end
198199

199200
function step(model, spl::Sampler{<:AdaptiveHamiltonian}, vi::VarInfo, is_first::Val{true})
200-
spl.selector.tag[] != :default && link!(vi, spl)
201+
spl.selector.tag != :default && link!(vi, spl)
201202
epsilon = find_good_eps(model, spl, vi) # heuristically find good initial epsilon
202203
dim = length(vi[spl])
203204
spl.info[:wum] = ThreePhaseAdapter(spl, epsilon, dim)
204-
spl.selector.tag[] != :default && invlink!(vi, spl)
205+
spl.selector.tag != :default && invlink!(vi, spl)
205206
return vi, true
206207
end
207208

@@ -215,7 +216,7 @@ function step(model, spl::Sampler{<:Hamiltonian}, vi::VarInfo, is_first::Val{fal
215216
spl.info[:eval_num] = 0
216217

217218
Turing.DEBUG && @debug "X-> R..."
218-
if spl.selector.tag[] != :default
219+
if spl.selector.tag != :default
219220
link!(vi, spl)
220221
runmodel!(model, vi, spl)
221222
end
@@ -241,7 +242,7 @@ function step(model, spl::Sampler{<:Hamiltonian}, vi::VarInfo, is_first::Val{fal
241242
setlogp!(vi, lj)
242243
end
243244

244-
if PROGRESS[] && spl.selector.tag[] == :default
245+
if PROGRESS[] && spl.selector.tag == :default
245246
std_str = string(spl.info[:wum].pc)
246247
std_str = length(std_str) >= 32 ? std_str[1:30]*"..." : std_str
247248
haskey(spl.info, :progress) && ProgressMeter.update!(
@@ -256,7 +257,7 @@ function step(model, spl::Sampler{<:Hamiltonian}, vi::VarInfo, is_first::Val{fal
256257
end
257258

258259
Turing.DEBUG && @debug "R -> X..."
259-
spl.selector.tag[] != :default && invlink!(vi, spl)
260+
spl.selector.tag != :default && invlink!(vi, spl)
260261

261262
return vi, is_accept
262263
end

src/inference/ipmcmc.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,20 @@ function IPMCMC(n1::Int, n2::Int, n3::Int, n4::Int, space...)
5252
IPMCMC(n1, n2, n3, n4, resample_systematic, _space)
5353
end
5454

55-
function Sampler(alg::IPMCMC)
55+
function Sampler(alg::IPMCMC, s::Selector)
5656
info = Dict{Symbol, Any}()
57-
spl = Sampler(alg, info)
57+
spl = Sampler(alg, info, s)
5858
# Create SMC and CSMC nodes
5959
samplers = Array{Sampler}(undef, alg.n_nodes)
6060
# Use resampler_threshold=1.0 for SMC since adaptive resampling is invalid in this setting
6161
default_CSMC = CSMC(alg.n_particles, 1, alg.resampler, alg.space)
6262
default_SMC = SMC(alg.n_particles, alg.resampler, 1.0, false, alg.space)
6363

6464
for i in 1:alg.n_csmc_nodes
65-
samplers[i] = Sampler(default_CSMC)
66-
samplers[i].selector.tag[] = Symbol(typeof(default_CSMC))
65+
samplers[i] = Sampler(default_CSMC, Selector(Symbol(typeof(default_CSMC))))
6766
end
6867
for i in (alg.n_csmc_nodes+1):alg.n_nodes
69-
samplers[i] = Sampler(default_SMC)
70-
samplers[i].selector.tag[] = Symbol(typeof(default_SMC))
68+
samplers[i] = Sampler(default_SMC, Symbol(typeof(default_SMC)))
7169
end
7270

7371
info[:samplers] = samplers

src/inference/is.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,9 @@ mutable struct IS <: InferenceAlgorithm
3535
n_particles :: Int
3636
end
3737

38-
function Sampler(alg::IS)
38+
function Sampler(alg::IS, s::Selector)
3939
info = Dict{Symbol, Any}()
40-
Sampler(alg, info)
40+
Sampler(alg, info, s)
4141
end
4242

4343
function sample(model::Model, alg::IS)

src/inference/mh.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,11 @@ function MH(n_iters::Int, space...)
4848
MH{eltype(set)}(n_iters, proposals, set)
4949
end
5050

51-
function Sampler(alg::MH, model::Model)
51+
function Sampler(alg::MH, model::Model, s::Selector)
5252
alg_str = "MH"
5353

5454
# Sanity check for space
55-
# TODO: if (this_sampler.selector.tag[] == :default) && !isempty(alg.space)
56-
if false && !isempty(alg.space)
55+
if (s.tag == :default) && !isempty(alg.space)
5756
@assert issubset(Set(get_pvars(model)), alg.space) "[$alg_str] symbols specified to samplers ($alg.space) doesn't cover the model parameters ($(Set(get_pvars(model))))"
5857
if Set(get_pvars(model)) != alg.space
5958
warn("[$alg_str] extra parameters specified by samplers don't exist in model: $(setdiff(alg.space, Set(get_pvars(model))))")
@@ -65,7 +64,7 @@ function Sampler(alg::MH, model::Model)
6564
info[:prior_prob] = 0.0
6665
info[:violating_support] = false
6766

68-
return Sampler(alg, info)
67+
return Sampler(alg, info, s)
6968
end
7069

7170
function propose(model, spl::Sampler{<:MH}, vi::VarInfo)
@@ -80,7 +79,7 @@ function step(model, spl::Sampler{<:MH}, vi::VarInfo, is_first::Val{true})
8079
end
8180

8281
function step(model, spl::Sampler{<:MH}, vi::VarInfo, is_first::Val{false})
83-
if spl.selector.tag[] != :default # Recompute joint in logp
82+
if spl.selector.tag != :default # Recompute joint in logp
8483
runmodel!(model, vi)
8584
end
8685
old_θ = copy(vi[spl])
@@ -137,7 +136,7 @@ function sample(model::Model, alg::MH;
137136
resume_from.info[:vi]
138137
end
139138

140-
if spl.selector.tag[] == :default
139+
if spl.selector.tag == :default
141140
runmodel!(model, vi, spl)
142141
end
143142

src/inference/pgibbs.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ end
4141

4242
const CSMC = PG # type alias of PG as Conditional SMC
4343

44-
function Sampler(alg::PG)
44+
function Sampler(alg::PG, s::Selector)
4545
info = Dict{Symbol, Any}()
4646
info[:logevidence] = []
47-
Sampler(alg, info)
47+
Sampler(alg, info, s)
4848
end
4949

5050
step(model, spl::Sampler{<:PG}, vi::VarInfo, _) = step(model, spl, vi)
@@ -117,7 +117,7 @@ function sample( model::Model,
117117

118118
time_total += time_elapsed
119119

120-
if PROGRESS[] && spl.selector.tag[] == :default
120+
if PROGRESS[] && spl.selector.tag == :default
121121
ProgressMeter.next!(spl.info[:progress])
122122
end
123123
end

0 commit comments

Comments
 (0)