Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using Markdown, Libtask, MacroTools
@reexport using Distributions, MCMCChains, Libtask
using Flux.Tracker: Tracker

import Base: ~, convert, promote_rule, rand, getindex, setindex!
import Base: ~, ==, convert, hash, promote_rule, rand, getindex, setindex!
import Distributions: sample
import MCMCChains: AbstractChains, Chains

Expand Down Expand Up @@ -62,6 +62,15 @@ end
(model::Model)(args...; kwargs...) = model.f(args..., model; kwargs...)
function runmodel! end

struct Selector
gid :: UInt64
tag :: Ref{Symbol} # :default, :invalid, :Gibbs, :HMC, etc.
end
Selector() = Selector(time_ns(), Ref(:default))
Selector(tag::Symbol) = Selector(time_ns(), Ref(tag))
hash(s::Selector) = hash(s.gid)
==(s1::Selector, s2::Selector) = s1.gid == s2.gid

abstract type AbstractSampler end

"""
Expand All @@ -83,10 +92,12 @@ Turing translates models to chunks that call the modelling functions at specifie
then include that file at the end of this one.
"""
mutable struct Sampler{T} <: AbstractSampler
alg :: T
info :: Dict{Symbol, Any} # sampler infomation
alg :: T
info :: Dict{Symbol, Any} # sampler infomation
selector :: Selector
end
Sampler(alg, model) = Sampler(alg)
Sampler(alg, model::Model) = Sampler(alg)
Sampler(alg, info::Dict{Symbol, Any}) = Sampler(alg, info, Selector())

include("utilities/Utilities.jl")
using .Utilities
Expand Down
48 changes: 31 additions & 17 deletions src/core/VarReplay.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
module VarReplay

using ...Turing: Turing, CACHERESET, CACHEIDCS, CACHERANGES, Model,
AbstractSampler, Sampler, SampleFromPrior
AbstractSampler, Sampler, SampleFromPrior,
Selector
using ...Utilities: vectorize, reconstruct, reconstruct!
using Bijectors: SimplexDistribution
using Distributions
Expand Down Expand Up @@ -70,7 +71,7 @@ mutable struct VarInfo
vals :: Vector{Real}
rvs :: Dict{Union{VarName,Vector{VarName}},Any}
dists :: Vector{Distributions.Distribution}
gids :: Vector{Int}
gids :: Vector{Set{Selector}}
logp :: Real
pred :: Dict{Symbol,Any}
num_produce :: Int # num of produce calls from trace, each produce corresponds to an observe.
Expand Down Expand Up @@ -139,8 +140,7 @@ getsym(vi::VarInfo, vn::VarName) = vi.vns[getidx(vi, vn)].sym
getdist(vi::VarInfo, vn::VarName) = vi.dists[getidx(vi, vn)]

getgid(vi::VarInfo, vn::VarName) = vi.gids[getidx(vi, vn)]

setgid!(vi::VarInfo, gid::Int, vn::VarName) = vi.gids[getidx(vi, vn)] = gid
setgid!(vi::VarInfo, gid::Selector, vn::VarName) = push!(vi.gids[getidx(vi, vn)], gid)

istrans(vi::VarInfo, vn::VarName) = is_flagged(vi, vn, "trans")
settrans!(vi::VarInfo, trans::Bool, vn::VarName) = trans ? set_flag!(vi, vn, "trans") : unset_flag!(vi, vn, "trans")
Expand Down Expand Up @@ -207,6 +207,9 @@ end
Base.getindex(vi::VarInfo, vview::VarView) = copy(getval(vi, vview))
Base.setindex!(vi::VarInfo, val::Any, vview::VarView) = setval!(vi, val, vview)

Base.getindex(vi::VarInfo, s::Selector) = copy(getval(vi, getranges(vi, s)))
Base.setindex!(vi::VarInfo, val::Any, s::Selector) = setval!(vi, val, getranges(vi, s))

Base.getindex(vi::VarInfo, spl::Sampler) = copy(getval(vi, getranges(vi, spl)))
Base.setindex!(vi::VarInfo, val::Any, spl::Sampler) = setval!(vi, val, getranges(vi, spl))

Expand Down Expand Up @@ -237,7 +240,9 @@ function Base.show(io::IO, vi::VarInfo)
end

# Add a new entry to VarInfo
function push!(vi::VarInfo, vn::VarName, r::Any, dist::Distributions.Distribution, gid::Int)
push!(vi::VarInfo, vn::VarName, r::Any, dist::Distributions.Distribution) = push!(vi, vn, r, dist, Set{Selector}([]))
push!(vi::VarInfo, vn::VarName, r::Any, dist::Distributions.Distribution, gid::Selector) = push!(vi, vn, r, dist, Set([gid]))
function push!(vi::VarInfo, vn::VarName, r::Any, dist::Distributions.Distribution, gidset::Set{Selector})

@assert ~(vn in vns(vi)) "[push!] attempt to add an exisitng variable $(sym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gid"

Expand All @@ -249,7 +254,7 @@ function push!(vi::VarInfo, vn::VarName, r::Any, dist::Distributions.Distributio
push!(vi.ranges, l+1:l+n)
append!(vi.vals, val)
push!(vi.dists, dist)
push!(vi.gids, gid)
push!(vi.gids, gidset)
push!(vi.orders, vi.num_produce)
push!(vi.flags["del"], false)
push!(vi.flags["trans"], false)
Expand Down Expand Up @@ -296,9 +301,10 @@ end
# vi.logp = vi.logp[end:end]
# end

# Get all indices of variables belonging to gid or 0
getidcs(vi::VarInfo) = getidcs(vi, nothing)
getidcs(vi::VarInfo, ::SampleFromPrior) = filter(i -> vi.gids[i] == 0, 1:length(vi.gids))
# Get all indices of variables belonging to SampleFromPrior:
# if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to
# the SampleFromPrior sampler
getidcs(vi::VarInfo, ::SampleFromPrior) = filter(i -> isempty(vi.gids[i]) , 1:length(vi.gids))
function getidcs(vi::VarInfo, spl::Sampler)
# NOTE: 0b00 is the sanity flag for
# |\____ getidcs (mask = 0b10)
Expand All @@ -309,12 +315,18 @@ function getidcs(vi::VarInfo, spl::Sampler)
else
spl.info[:cache_updated] = spl.info[:cache_updated] | CACHEIDCS
spl.info[:idcs] = filter(i ->
(vi.gids[i] == spl.alg.gid || vi.gids[i] == 0) && (isempty(spl.alg.space) || is_inside(vi.vns[i], spl.alg.space)),
(spl.selector in vi.gids[i] || isempty(vi.gids[i])) && (isempty(spl.alg.space) || is_inside(vi.vns[i], spl.alg.space)),
1:length(vi.gids)
)
end
end

# Get all indices of variables belonging to a given selector
function getidcs(vi::VarInfo, s::Selector, space::Set=Set())
filter(i -> (s in vi.gids[i] || isempty(vi.gids[i])) && (isempty(space) || is_inside(vi.vns[i], space)),
1:length(vi.gids))
end

function is_inside(vn::VarName, space::Set)::Bool
if vn.sym in space
return true
Expand All @@ -327,15 +339,13 @@ function is_inside(vn::VarName, space::Set)::Bool
end
end

# Get all values of variables belonging to gid or 0
getvals(vi::VarInfo) = getvals(vi, nothing)
# Get all values of variables belonging to spl.selector
getvals(vi::VarInfo, spl::AbstractSampler) = view(vi.vals, getidcs(vi, spl))

# Get all vns of variables belonging to gid or 0
getvns(vi::VarInfo) = getvns(vi, nothing)
# Get all vns of variables belonging to spl.selector
getvns(vi::VarInfo, spl::AbstractSampler) = view(vi.vns, getidcs(vi, spl))

# Get all vns of variables belonging to gid or 0
# Get all vns of variables belonging to spl.selector
function getranges(vi::VarInfo, spl::Sampler)
if ~haskey(spl.info, :cache_updated) spl.info[:cache_updated] = CACHERESET end
if haskey(spl.info, :ranges) && (spl.info[:cache_updated] & CACHERANGES) > 0
Expand All @@ -346,6 +356,10 @@ function getranges(vi::VarInfo, spl::Sampler)
end
end

function getranges(vi::VarInfo, s::Selector)
union(map(i -> vi.ranges[i], getidcs(vi, s))...)
end

# NOTE: this function below is not used anywhere but test files.
# we can safely remove it if we want.
function getretain(vi::VarInfo, spl::AbstractSampler)
Expand Down Expand Up @@ -381,8 +395,8 @@ function set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler)
end

function updategid!(vi::VarInfo, vn::VarName, spl::Sampler)
if ~isempty(spl.alg.space) && getgid(vi, vn) == 0 && getsym(vi, vn) in spl.alg.space
setgid!(vi, spl.alg.gid, vn)
if ~isempty(spl.alg.space) && isempty(getgid(vi, vn)) && getsym(vi, vn) in spl.alg.space
setgid!(vi, spl.selector, vn)
end
end

Expand Down
6 changes: 3 additions & 3 deletions src/inference/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ function assume(spl::A,
r = vi[vn]
else
r = isa(spl, SampleFromUniform) ? init(dist) : rand(dist)
push!(vi, vn, r, dist, 0)
push!(vi, vn, r, dist)
end
# NOTE: The importance weight is not correctly computed here because
# r is genereated from some uniform distribution which is different from the prior
Expand Down Expand Up @@ -144,13 +144,13 @@ function assume(spl::A,

if isa(dist, UnivariateDistribution) || isa(dist, MatrixDistribution)
for i = 1:n
push!(vi, vns[i], rs[i], dist, 0)
push!(vi, vns[i], rs[i], dist)
end
@assert size(var) == size(rs) "Turing.assume: variable and random number dimension unmatched"
var = rs
elseif isa(dist, MultivariateDistribution)
for i = 1:n
push!(vi, vns[i], rs[:,i], dist, 0)
push!(vi, vns[i], rs[:,i], dist)
end
if isa(var, Vector)
@assert length(var) == size(rs)[2] "Turing.assume: variable and random number dimension unmatched"
Expand Down
5 changes: 2 additions & 3 deletions src/inference/dynamichmc.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
struct DynamicNUTS{AD, T} <: Hamiltonian{AD}
n_iters :: Integer # number of samples
space :: Set{T} # sampling space, emtpy means all
gid :: Integer # group ID
end

"""
Expand Down Expand Up @@ -30,7 +29,7 @@ chn = sample(gdemo(1.5, 2.0), DynamicNUTS(2000))
DynamicNUTS(args...) = DynamicNUTS{ADBackend()}(args...)
function DynamicNUTS{AD}(n_iters::Integer, space...) where AD
_space = isa(space, Symbol) ? Set([space]) : Set(space)
DynamicNUTS{AD, eltype(_space)}(n_iters, _space, 0)
DynamicNUTS{AD, eltype(_space)}(n_iters, _space)
end

function Sampler(alg::DynamicNUTS{T}) where T <: Hamiltonian
Expand All @@ -53,7 +52,7 @@ function sample(model::Model,
vi = VarInfo()
model(vi, SampleFromUniform())

if spl.alg.gid == 0
if spl.selector.tag[] == :default
link!(vi, spl)
runmodel!(model, vi, spl)
end
Expand Down
27 changes: 18 additions & 9 deletions src/inference/gibbs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,24 @@ mutable struct Gibbs{A} <: InferenceAlgorithm
n_iters :: Int # number of Gibbs iterations
algs :: A # component sampling algorithms
thin :: Bool # if thinning to output only after a whole Gibbs sweep
gid :: Int
end
Gibbs(n_iters::Int, algs...; thin=true) = Gibbs(n_iters, algs, thin, 0)
Gibbs(alg::Gibbs, new_gid) = Gibbs(alg.n_iters, alg.algs, alg.thin, new_gid)
Gibbs(n_iters::Int, algs...; thin=true) = Gibbs(n_iters, algs, thin)

const GibbsComponent = Union{Hamiltonian,MH,PG}

function Sampler(alg::Gibbs, model::Model)
info = Dict{Symbol, Any}()
spl = Sampler(alg, info)

n_samplers = length(alg.algs)
samplers = Array{Sampler}(undef, n_samplers)

space = Set{Symbol}()

for i in 1:n_samplers
sub_alg = alg.algs[i]
if isa(sub_alg, GibbsComponent)
samplers[i] = Sampler(typeof(sub_alg)(sub_alg, i), model)
samplers[i] = Sampler(sub_alg, model)
samplers[i].selector.tag[] = Symbol(typeof(sub_alg))
else
@error("[Gibbs] unsupport base sampling algorithm $alg")
end
Expand All @@ -58,10 +59,9 @@ function Sampler(alg::Gibbs, model::Model)
@warn("[Gibbs] extra parameters specified by samplers don't exist in model: $(setdiff(space, Set(get_pvars(model))))")
end

info = Dict{Symbol, Any}()
info[:samplers] = samplers

Sampler(alg, info)
return spl
end

function sample(
Expand All @@ -73,8 +73,17 @@ function sample(
)

# Init the (master) Gibbs sampler
spl = reuse_spl_n > 0 ? resume_from.info[:spl] : Sampler(alg, model)

if reuse_spl_n > 0
spl = resume_from.info[:spl]
else
spl = Sampler(alg, model)
if resume_from != nothing
spl.selector = resume_from.info[:spl].selector
for i in 1:length(spl.info[:samplers])
spl.info[:samplers][i].selector = resume_from.info[:spl].info[:samplers][i].selector
end
end
end
@assert typeof(spl.alg) == typeof(alg) "[Turing] alg type mismatch; please use resume() to re-use spl"

# Initialize samples
Expand Down
30 changes: 13 additions & 17 deletions src/inference/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,25 +48,18 @@ mutable struct HMC{AD, T} <: StaticHamiltonian{AD}
epsilon :: Float64 # leapfrog step size
tau :: Int # leapfrog step number
space :: Set{T} # sampling space, emtpy means all
gid :: Int # group ID
end
HMC(args...) = HMC{ADBackend()}(args...)
function HMC{AD}(epsilon::Float64, tau::Int, space...) where AD
_space = isa(space, Symbol) ? Set([space]) : Set(space)
return HMC{AD, eltype(_space)}(1, epsilon, tau, _space, 0)
return HMC{AD, eltype(_space)}(1, epsilon, tau, _space)
end
function HMC{AD}(n_iters::Int, epsilon::Float64, tau::Int) where AD
return HMC{AD, Any}(n_iters, epsilon, tau, Set(), 0)
return HMC{AD, Any}(n_iters, epsilon, tau, Set())
end
function HMC{AD}(n_iters::Int, epsilon::Float64, tau::Int, space...) where AD
_space = isa(space, Symbol) ? Set([space]) : Set(space)
return HMC{AD, eltype(_space)}(n_iters, epsilon, tau, _space, 0)
end
function HMC{AD1}(alg::HMC{AD2, T}, new_gid::Int) where {AD1, AD2, T}
return HMC{AD1, T}(alg.n_iters, alg.epsilon, alg.tau, alg.space, new_gid)
end
function HMC{AD, T}(alg::HMC, new_gid::Int) where {AD, T}
return HMC{AD, T}(alg.n_iters, alg.epsilon, alg.tau, alg.space, new_gid)
return HMC{AD, eltype(_space)}(n_iters, epsilon, tau, _space)
end

function hmc_step(θ, lj, lj_func, grad_func, H_func, ϵ, alg::HMC, momentum_sampler::Function;
Expand Down Expand Up @@ -108,6 +101,9 @@ function sample(model::Model, alg::Hamiltonian;
spl = reuse_spl_n > 0 ?
resume_from.info[:spl] :
Sampler(alg, adapt_conf)
if resume_from != nothing
spl.selector = resume_from.info[:spl].selector
end

@assert isa(spl.alg, Hamiltonian) "[Turing] alg type mismatch; please use resume() to re-use spl"

Expand Down Expand Up @@ -137,7 +133,7 @@ function sample(model::Model, alg::Hamiltonian;
deepcopy(resume_from.info[:vi])
end

if spl.alg.gid == 0
if spl.selector.tag[] == :default
link!(vi, spl)
runmodel!(model, vi, spl)
end
Expand Down Expand Up @@ -189,7 +185,7 @@ function sample(model::Model, alg::Hamiltonian;
c = Chain(0.0, samples) # wrap the result by Chain
if save_state # save state
# Convert vi back to X if vi is required to be saved
if spl.alg.gid == 0 invlink!(vi, spl) end
spl.selector.tag[] == :default && invlink!(vi, spl)
c = save(c, spl, model, vi, samples)
end
return c
Expand All @@ -201,11 +197,11 @@ function step(model, spl::Sampler{<:StaticHamiltonian}, vi::VarInfo, is_first::V
end

function step(model, spl::Sampler{<:AdaptiveHamiltonian}, vi::VarInfo, is_first::Val{true})
spl.alg.gid != 0 && link!(vi, spl)
spl.selector.tag[] != :default && link!(vi, spl)
epsilon = find_good_eps(model, spl, vi) # heuristically find good initial epsilon
dim = length(vi[spl])
spl.info[:wum] = ThreePhaseAdapter(spl, epsilon, dim)
spl.alg.gid != 0 && invlink!(vi, spl)
spl.selector.tag[] != :default && invlink!(vi, spl)
return vi, true
end

Expand All @@ -219,7 +215,7 @@ function step(model, spl::Sampler{<:Hamiltonian}, vi::VarInfo, is_first::Val{fal
spl.info[:eval_num] = 0

Turing.DEBUG && @debug "X-> R..."
if spl.alg.gid != 0
if spl.selector.tag[] != :default
link!(vi, spl)
runmodel!(model, vi, spl)
end
Expand All @@ -245,7 +241,7 @@ function step(model, spl::Sampler{<:Hamiltonian}, vi::VarInfo, is_first::Val{fal
setlogp!(vi, lj)
end

if PROGRESS[] && spl.alg.gid == 0
if PROGRESS[] && spl.selector.tag[] == :default
std_str = string(spl.info[:wum].pc)
std_str = length(std_str) >= 32 ? std_str[1:30]*"..." : std_str
haskey(spl.info, :progress) && ProgressMeter.update!(
Expand All @@ -260,7 +256,7 @@ function step(model, spl::Sampler{<:Hamiltonian}, vi::VarInfo, is_first::Val{fal
end

Turing.DEBUG && @debug "R -> X..."
spl.alg.gid != 0 && invlink!(vi, spl)
spl.selector.tag[] != :default && invlink!(vi, spl)

return vi, is_accept
end
Expand Down
Loading