Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ function runmodel! end

struct Selector
gid :: UInt64
tag :: Ref{Symbol} # :default, :invalid, :Gibbs, :HMC, etc.
tag :: Symbol # :default, :invalid, :Gibbs, :HMC, etc.
end
Selector() = Selector(time_ns(), Ref(:default))
Selector(tag::Symbol) = Selector(time_ns(), Ref(tag))
Selector() = Selector(time_ns(), :default)
Selector(tag::Symbol) = Selector(time_ns(), tag)
hash(s::Selector) = hash(s.gid)
==(s1::Selector, s2::Selector) = s1.gid == s2.gid

Expand Down Expand Up @@ -96,8 +96,9 @@ mutable struct Sampler{T} <: AbstractSampler
info :: Dict{Symbol, Any} # sampler infomation
selector :: Selector
end
Sampler(alg, model::Model) = Sampler(alg)
Sampler(alg, info::Dict{Symbol, Any}) = Sampler(alg, info, Selector())
Sampler(alg) = Sampler(alg, Selector())
Sampler(alg, model::Model) = Sampler(alg, model, Selector())
Sampler(alg, model::Model, s::Selector) = Sampler(alg, s)

include("utilities/Utilities.jl")
using .Utilities
Expand Down
18 changes: 6 additions & 12 deletions src/core/VarReplay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,8 @@ end
Base.getindex(vi::VarInfo, vview::VarView) = copy(getval(vi, vview))
Base.setindex!(vi::VarInfo, val::Any, vview::VarView) = setval!(vi, val, vview)

Base.getindex(vi::VarInfo, s::Selector) = copy(getval(vi, getranges(vi, s)))
Base.setindex!(vi::VarInfo, val::Any, s::Selector) = setval!(vi, val, getranges(vi, s))

Base.getindex(vi::VarInfo, spl::Sampler) = copy(getval(vi, getranges(vi, spl)))
Base.setindex!(vi::VarInfo, val::Any, spl::Sampler) = setval!(vi, val, getranges(vi, spl))
Base.getindex(vi::VarInfo, s::Union{Selector, Sampler}) = copy(getval(vi, getranges(vi, s)))
Base.setindex!(vi::VarInfo, val::Any, s::Union{Selector, Sampler}) = setval!(vi, val, getranges(vi, s))

Base.getindex(vi::VarInfo, ::SampleFromPrior) = copy(getall(vi))
Base.setindex!(vi::VarInfo, val::Any, ::SampleFromPrior) = setall!(vi, val)
Expand Down Expand Up @@ -314,10 +311,7 @@ function getidcs(vi::VarInfo, spl::Sampler)
spl.info[:idcs]
else
spl.info[:cache_updated] = spl.info[:cache_updated] | CACHEIDCS
spl.info[:idcs] = filter(i ->
(spl.selector in vi.gids[i] || isempty(vi.gids[i])) && (isempty(spl.alg.space) || is_inside(vi.vns[i], spl.alg.space)),
1:length(vi.gids)
)
spl.info[:idcs] = getidcs(vi, spl.selector, spl.alg.space)
end
end

Expand Down Expand Up @@ -352,12 +346,12 @@ function getranges(vi::VarInfo, spl::Sampler)
spl.info[:ranges]
else
spl.info[:cache_updated] = spl.info[:cache_updated] | CACHERANGES
spl.info[:ranges] = union(map(i -> vi.ranges[i], getidcs(vi, spl))...)
spl.info[:ranges] = getranges(vi, spl.selector, spl.alg.space)
end
end

function getranges(vi::VarInfo, s::Selector)
union(map(i -> vi.ranges[i], getidcs(vi, s))...)
function getranges(vi::VarInfo, s::Selector, space::Set=Set())
union(map(i -> vi.ranges[i], getidcs(vi, s, space))...)
end

# NOTE: this function below is not used anywhere but test files.
Expand Down
3 changes: 2 additions & 1 deletion src/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ using Distributions, Libtask, Bijectors
using ProgressMeter, LinearAlgebra
using ..Turing: PROGRESS, CACHERESET, AbstractSampler
using ..Turing: Model, runmodel!, get_pvars, get_dvars,
Sampler, SampleFromPrior, SampleFromUniform
Sampler, SampleFromPrior, SampleFromUniform,
Selector
using ..Turing: in_pvars, in_dvars, Turing
using StatsFuns: logsumexp

Expand Down
6 changes: 3 additions & 3 deletions src/inference/dynamichmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ function DynamicNUTS{AD}(n_iters::Integer, space...) where AD
DynamicNUTS{AD, eltype(_space)}(n_iters, _space)
end

function Sampler(alg::DynamicNUTS{T}) where T <: Hamiltonian
return Sampler(alg, Dict{Symbol,Any}())
function Sampler(alg::DynamicNUTS{T}, s::Selector) where T <: Hamiltonian
return Sampler(alg, Dict{Symbol,Any}(), s)
end

function sample(model::Model,
Expand All @@ -52,7 +52,7 @@ function sample(model::Model,
vi = VarInfo()
model(vi, SampleFromUniform())

if spl.selector.tag[] == :default
if spl.selector.tag == :default
link!(vi, spl)
runmodel!(model, vi, spl)
end
Expand Down
7 changes: 3 additions & 4 deletions src/inference/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ Gibbs(n_iters::Int, algs...; thin=true) = Gibbs(n_iters, algs, thin)

const GibbsComponent = Union{Hamiltonian,MH,PG}

function Sampler(alg::Gibbs, model::Model)
function Sampler(alg::Gibbs, model::Model, s::Selector)
info = Dict{Symbol, Any}()
spl = Sampler(alg, info)
spl = Sampler(alg, info, s)

n_samplers = length(alg.algs)
samplers = Array{Sampler}(undef, n_samplers)
Expand All @@ -44,8 +44,7 @@ function Sampler(alg::Gibbs, model::Model)
for i in 1:n_samplers
sub_alg = alg.algs[i]
if isa(sub_alg, GibbsComponent)
samplers[i] = Sampler(sub_alg, model)
samplers[i].selector.tag[] = Symbol(typeof(sub_alg))
samplers[i] = Sampler(sub_alg, model, Selector(Symbol(typeof(sub_alg))))
else
@error("[Gibbs] unsupport base sampling algorithm $alg")
end
Expand Down
25 changes: 13 additions & 12 deletions src/inference/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,12 @@ end
DEFAULT_ADAPT_CONF_TYPE = Nothing
STAN_DEFAULT_ADAPT_CONF = nothing

Sampler(alg::Hamiltonian) = Sampler(alg, nothing)
function Sampler(alg::Hamiltonian, adapt_conf::Nothing)
return _sampler(alg::Hamiltonian, adapt_conf)
Sampler(alg::Hamiltonian, s::Selector) = Sampler(alg, nothing, s)
Sampler(alg::Hamiltonian, adapt_conf::Nothing) = Sampler(alg, adapt_conf, Selector())
function Sampler(alg::Hamiltonian, adapt_conf::Nothing, s::Selector)
return _sampler(alg::Hamiltonian, adapt_conf, s)
end
function _sampler(alg::Hamiltonian, adapt_conf)
function _sampler(alg::Hamiltonian, adapt_conf, s::Selector)
info=Dict{Symbol, Any}()

# For state infomation
Expand All @@ -88,7 +89,7 @@ function _sampler(alg::Hamiltonian, adapt_conf)
# Adapt configuration
info[:adapt_conf] = adapt_conf

Sampler(alg, info)
Sampler(alg, info, s)
end

function sample(model::Model, alg::Hamiltonian;
Expand Down Expand Up @@ -133,7 +134,7 @@ function sample(model::Model, alg::Hamiltonian;
deepcopy(resume_from.info[:vi])
end

if spl.selector.tag[] == :default
if spl.selector.tag == :default
link!(vi, spl)
runmodel!(model, vi, spl)
end
Expand Down Expand Up @@ -185,7 +186,7 @@ function sample(model::Model, alg::Hamiltonian;
c = Chain(0.0, samples) # wrap the result by Chain
if save_state # save state
# Convert vi back to X if vi is required to be saved
spl.selector.tag[] == :default && invlink!(vi, spl)
spl.selector.tag == :default && invlink!(vi, spl)
c = save(c, spl, model, vi, samples)
end
return c
Expand All @@ -197,11 +198,11 @@ function step(model, spl::Sampler{<:StaticHamiltonian}, vi::VarInfo, is_first::V
end

function step(model, spl::Sampler{<:AdaptiveHamiltonian}, vi::VarInfo, is_first::Val{true})
spl.selector.tag[] != :default && link!(vi, spl)
spl.selector.tag != :default && link!(vi, spl)
epsilon = find_good_eps(model, spl, vi) # heuristically find good initial epsilon
dim = length(vi[spl])
spl.info[:wum] = ThreePhaseAdapter(spl, epsilon, dim)
spl.selector.tag[] != :default && invlink!(vi, spl)
spl.selector.tag != :default && invlink!(vi, spl)
return vi, true
end

Expand All @@ -215,7 +216,7 @@ function step(model, spl::Sampler{<:Hamiltonian}, vi::VarInfo, is_first::Val{fal
spl.info[:eval_num] = 0

Turing.DEBUG && @debug "X-> R..."
if spl.selector.tag[] != :default
if spl.selector.tag != :default
link!(vi, spl)
runmodel!(model, vi, spl)
end
Expand All @@ -241,7 +242,7 @@ function step(model, spl::Sampler{<:Hamiltonian}, vi::VarInfo, is_first::Val{fal
setlogp!(vi, lj)
end

if PROGRESS[] && spl.selector.tag[] == :default
if PROGRESS[] && spl.selector.tag == :default
std_str = string(spl.info[:wum].pc)
std_str = length(std_str) >= 32 ? std_str[1:30]*"..." : std_str
haskey(spl.info, :progress) && ProgressMeter.update!(
Expand All @@ -256,7 +257,7 @@ function step(model, spl::Sampler{<:Hamiltonian}, vi::VarInfo, is_first::Val{fal
end

Turing.DEBUG && @debug "R -> X..."
spl.selector.tag[] != :default && invlink!(vi, spl)
spl.selector.tag != :default && invlink!(vi, spl)

return vi, is_accept
end
Expand Down
10 changes: 4 additions & 6 deletions src/inference/ipmcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,20 @@ function IPMCMC(n1::Int, n2::Int, n3::Int, n4::Int, space...)
IPMCMC(n1, n2, n3, n4, resample_systematic, _space)
end

function Sampler(alg::IPMCMC)
function Sampler(alg::IPMCMC, s::Selector)
info = Dict{Symbol, Any}()
spl = Sampler(alg, info)
spl = Sampler(alg, info, s)
# Create SMC and CSMC nodes
samplers = Array{Sampler}(undef, alg.n_nodes)
# Use resampler_threshold=1.0 for SMC since adaptive resampling is invalid in this setting
default_CSMC = CSMC(alg.n_particles, 1, alg.resampler, alg.space)
default_SMC = SMC(alg.n_particles, alg.resampler, 1.0, false, alg.space)

for i in 1:alg.n_csmc_nodes
samplers[i] = Sampler(default_CSMC)
samplers[i].selector.tag[] = Symbol(typeof(default_CSMC))
samplers[i] = Sampler(default_CSMC, Selector(Symbol(typeof(default_CSMC))))
end
for i in (alg.n_csmc_nodes+1):alg.n_nodes
samplers[i] = Sampler(default_SMC)
samplers[i].selector.tag[] = Symbol(typeof(default_SMC))
samplers[i] = Sampler(default_SMC, Symbol(typeof(default_SMC)))
end

info[:samplers] = samplers
Expand Down
4 changes: 2 additions & 2 deletions src/inference/is.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ mutable struct IS <: InferenceAlgorithm
n_particles :: Int
end

function Sampler(alg::IS)
function Sampler(alg::IS, s::Selector)
info = Dict{Symbol, Any}()
Sampler(alg, info)
Sampler(alg, info, s)
end

function sample(model::Model, alg::IS)
Expand Down
11 changes: 5 additions & 6 deletions src/inference/mh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,11 @@ function MH(n_iters::Int, space...)
MH{eltype(set)}(n_iters, proposals, set)
end

function Sampler(alg::MH, model::Model)
function Sampler(alg::MH, model::Model, s::Selector)
alg_str = "MH"

# Sanity check for space
# TODO: if (this_sampler.selector.tag[] == :default) && !isempty(alg.space)
if false && !isempty(alg.space)
if (s.tag == :default) && !isempty(alg.space)
@assert issubset(Set(get_pvars(model)), alg.space) "[$alg_str] symbols specified to samplers ($alg.space) doesn't cover the model parameters ($(Set(get_pvars(model))))"
if Set(get_pvars(model)) != alg.space
warn("[$alg_str] extra parameters specified by samplers don't exist in model: $(setdiff(alg.space, Set(get_pvars(model))))")
Expand All @@ -65,7 +64,7 @@ function Sampler(alg::MH, model::Model)
info[:prior_prob] = 0.0
info[:violating_support] = false

return Sampler(alg, info)
return Sampler(alg, info, s)
end

function propose(model, spl::Sampler{<:MH}, vi::VarInfo)
Expand All @@ -80,7 +79,7 @@ function step(model, spl::Sampler{<:MH}, vi::VarInfo, is_first::Val{true})
end

function step(model, spl::Sampler{<:MH}, vi::VarInfo, is_first::Val{false})
if spl.selector.tag[] != :default # Recompute joint in logp
if spl.selector.tag != :default # Recompute joint in logp
runmodel!(model, vi)
end
old_θ = copy(vi[spl])
Expand Down Expand Up @@ -137,7 +136,7 @@ function sample(model::Model, alg::MH;
resume_from.info[:vi]
end

if spl.selector.tag[] == :default
if spl.selector.tag == :default
runmodel!(model, vi, spl)
end

Expand Down
6 changes: 3 additions & 3 deletions src/inference/pgibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ end

const CSMC = PG # type alias of PG as Conditional SMC

function Sampler(alg::PG)
function Sampler(alg::PG, s::Selector)
info = Dict{Symbol, Any}()
info[:logevidence] = []
Sampler(alg, info)
Sampler(alg, info, s)
end

step(model, spl::Sampler{<:PG}, vi::VarInfo, _) = step(model, spl, vi)
Expand Down Expand Up @@ -117,7 +117,7 @@ function sample( model::Model,

time_total += time_elapsed

if PROGRESS[] && spl.selector.tag[] == :default
if PROGRESS[] && spl.selector.tag == :default
ProgressMeter.next!(spl.info[:progress])
end
end
Expand Down
7 changes: 3 additions & 4 deletions src/inference/pmmh.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ end

PIMH(n_iters::Int, smc_alg::SMC) = PMMH(n_iters, tuple(smc_alg), Set())

function Sampler(alg::PMMH, model::Model)
function Sampler(alg::PMMH, model::Model, s::Selector)
info = Dict{Symbol, Any}()
spl = Sampler(alg, info)
spl = Sampler(alg, info, s)

alg_str = "PMMH"
n_samplers = length(alg.algs)
Expand All @@ -45,8 +45,7 @@ function Sampler(alg::PMMH, model::Model)
for i in 1:n_samplers
sub_alg = alg.algs[i]
if isa(sub_alg, Union{SMC, MH})
samplers[i] = Sampler(sub_alg, model)
samplers[i].selector.tag[] = Symbol(typeof(sub_alg))
samplers[i] = Sampler(sub_alg, model, Selector(Symbol(typeof(sub_alg))))
else
error("[$alg_str] unsupport base sampling algorithm $alg")
end
Expand Down
8 changes: 4 additions & 4 deletions src/inference/sghmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ function SGHMC{AD}(n_iters, learning_rate, momentum_decay, space...) where AD
end

function step(model, spl::Sampler{<:SGHMC}, vi::VarInfo, is_first::Val{true})
spl.selector.tag[] != :default && link!(vi, spl)
spl.selector.tag != :default && link!(vi, spl)

# Initialize velocity
v = zeros(Float64, size(vi[spl]))
spl.info[:v] = v

spl.selector.tag[] != :default && invlink!(vi, spl)
spl.selector.tag != :default && invlink!(vi, spl)
return vi, true
end

Expand All @@ -60,7 +60,7 @@ function step(model, spl::Sampler{<:SGHMC}, vi::VarInfo, is_first::Val{false})
η, α = spl.alg.learning_rate, spl.alg.momentum_decay

Turing.DEBUG && @debug "X-> R..."
if spl.selector.tag[] != :default
if spl.selector.tag != :default
link!(vi, spl)
runmodel!(model, vi, spl)
end
Expand All @@ -79,7 +79,7 @@ function step(model, spl::Sampler{<:SGHMC}, vi::VarInfo, is_first::Val{false})
vi[spl] = θ

Turing.DEBUG && @debug "R -> X..."
spl.selector.tag[] != :default && invlink!(vi, spl)
spl.selector.tag != :default && invlink!(vi, spl)

Turing.DEBUG && @debug "always accept..."
return vi, true
Expand Down
8 changes: 4 additions & 4 deletions src/inference/sgld.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,14 @@ function SGLD{AD}(n_iters, epsilon, space...) where AD
end

function step(model, spl::Sampler{<:SGLD}, vi::VarInfo, is_first::Val{true})
spl.selector.tag[] != :default && link!(vi, spl)
spl.selector.tag != :default && link!(vi, spl)

spl.info[:wum] = NaiveCompAdapter(UnitPreConditioner(), ManualSSAdapter(MSSState(spl.alg.epsilon)))

# Initialize iteration counter
spl.info[:t] = 0

spl.selector.tag[] != :default && invlink!(vi, spl)
spl.selector.tag != :default && invlink!(vi, spl)
return vi, true
end

Expand All @@ -65,7 +65,7 @@ function step(model, spl::Sampler{<:SGLD}, vi::VarInfo, is_first::Val{false})
mssa.state.ϵ = ϵ_t

Turing.DEBUG && @debug "X-> R..."
if spl.selector.tag[] != :default
if spl.selector.tag != :default
link!(vi, spl)
runmodel!(model, vi, spl)
end
Expand All @@ -82,7 +82,7 @@ function step(model, spl::Sampler{<:SGLD}, vi::VarInfo, is_first::Val{false})
vi[spl] = θ

Turing.DEBUG && @debug "R -> X..."
spl.selector.tag[] != :default && invlink!(vi, spl)
spl.selector.tag != :default && invlink!(vi, spl)

return vi, true
end
Loading