Merged
1 change: 1 addition & 0 deletions src/Turing.jl
@@ -214,6 +214,7 @@ export @model, # modelling
@sampler,

MH, # classic sampling
ESS,
Gibbs,

HMC, # Hamiltonian-like sampling
12 changes: 6 additions & 6 deletions src/core/RandomVariables.jl
@@ -34,6 +34,7 @@ export VarName,
resetlogp!,
set_retained_vns_del_by_spl!,
is_flagged,
set_flag!,
unset_flag!,
setgid!,
updategid!,
@@ -495,11 +496,10 @@ end
end

# Get all vns of variables belonging to spl
_getvns(vi::UntypedVarInfo, spl::AbstractSampler) = view(vi.metadata.vns, _getidcs(vi, spl))
function _getvns(vi::TypedVarInfo, spl::AbstractSampler)
# Get a NamedTuple of the indices of variables belonging to `spl`, one entry for each symbol
idcs = _getidcs(vi, spl)
return _getvns(vi.metadata, idcs)
_getvns(vi::AbstractVarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl)))
_getvns(vi::UntypedVarInfo, s::Selector, space) = view(vi.metadata.vns, _getidcs(vi, s, space))
function _getvns(vi::TypedVarInfo, s::Selector, space)
return _getvns(vi.metadata, _getidcs(vi, s, space))
end
# Get a NamedTuple for all the `vns` of indices `idcs`, one entry for each symbol
@generated function _getvns(metadata, idcs::NamedTuple{names}) where {names}
@@ -525,7 +525,7 @@ end
#end
end
# Get the index (in vals) ranges of all the vns of variables belonging to selector `s` in `space`
@inline function _getranges(vi::AbstractVarInfo, s::Selector, space = Val(()))
@inline function _getranges(vi::AbstractVarInfo, s::Selector, space)
return _getranges(vi, _getidcs(vi, s, space))
end
@inline function _getranges(vi::UntypedVarInfo, idcs::Vector{Int})
45 changes: 38 additions & 7 deletions src/inference/Inference.jl
@@ -3,7 +3,7 @@ module Inference
using ..Core, ..Core.RandomVariables, ..Utilities
using ..Core.RandomVariables: Metadata, _tail, VarInfo, TypedVarInfo,
islinked, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize,
settrans!
settrans!, _getvns, getdist
using ..Core: split_var_str
using Distributions, Libtask, Bijectors
using ProgressMeter, LinearAlgebra
@@ -13,7 +13,7 @@ using ..Turing: Model, runmodel!, Turing,
Selector, AbstractSamplerState, DefaultContext, PriorContext,
LikelihoodContext, MiniBatchContext, NamedDist, NoDist
using StatsFuns: logsumexp
using Random: GLOBAL_RNG, AbstractRNG
using Random: GLOBAL_RNG, AbstractRNG, randexp
using AbstractMCMC

import MCMCChains: Chains
@@ -33,6 +33,7 @@ export InferenceAlgorithm,
SampleFromUniform,
SampleFromPrior,
MH,
ESS,
Gibbs, # classic sampling
HMC,
SGLD,
@@ -274,8 +275,8 @@ function _params_to_array(ts::Vector{T}, spl::Sampler) where {T<:AbstractTransit
end
push!(dicts, d)
end
# Convert the set to an ordered vector so the parameter ordering

# Convert the set to an ordered vector so the parameter ordering
# is deterministic.
ordered_names = collect(names)
vals = Matrix{Union{Real, Missing}}(undef, length(ts), length(ordered_names))
@@ -486,6 +487,7 @@ end
# Concrete algorithm implementations. #
#######################################

include("ess.jl")
include("hmc.jl")
include("mh.jl")
include("is.jl")
@@ -498,7 +500,7 @@ include("../contrib/inference/AdvancedSMCExtensions.jl")
# Typing tools #
################

for alg in (:SMC, :PG, :PMMH, :IPMCMC, :MH, :IS, :Gibbs)
for alg in (:SMC, :PG, :PMMH, :IPMCMC, :MH, :IS, :ESS, :Gibbs)
@eval getspace(::$alg{space}) where {space} = space
end
for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC)
@@ -635,7 +637,14 @@ function assume(
vi::VarInfo,
)
if haskey(vi, vn)
if is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = spl isa SampleFromUniform ? init(dist) : rand(dist)
vi[vn] = vectorize(dist, r)
setorder!(vi, vn, vi.num_produce)
else
r = vi[vn]
end
else
r = isa(spl, SampleFromUniform) ? init(dist) : rand(dist)
push!(vi, vn, r, dist, spl)
@@ -792,9 +801,19 @@ function get_and_set_val!(
)
n = length(vns)
if haskey(vi, vns[1])
if is_flagged(vi, vns[1], "del")
unset_flag!(vi, vns[1], "del")
r = spl isa SampleFromUniform ? init(dist, n) : rand(dist, n)
for i in 1:n
vn = vns[i]
vi[vn] = vectorize(dist, r[:, i])
setorder!(vi, vn, vi.num_produce)
end
else
r = vi[vns]
end
else
r = isa(spl, SampleFromUniform) ? init(dist, n) : rand(dist, n)
r = spl isa SampleFromUniform ? init(dist, n) : rand(dist, n)
for i in 1:n
push!(vi, vns[i], r[:,i], dist, spl)
end
@@ -808,9 +827,21 @@
spl::AbstractSampler,
)
if haskey(vi, vns[1])
if is_flagged(vi, vns[1], "del")
unset_flag!(vi, vns[1], "del")
f = (vn, dist) -> spl isa SampleFromUniform ? init(dist) : rand(dist)
r = f.(vns, dists)
for i in eachindex(vns)
vn = vns[i]
dist = dists isa AbstractArray ? dists[i] : dists
vi[vn] = vectorize(dist, r[i])
setorder!(vi, vn, vi.num_produce)
end
else
r = reshape(vi[vec(vns)], size(vns))
end
else
f(vn, dist) = isa(spl, SampleFromUniform) ? init(dist) : rand(dist)
f = (vn, dist) -> spl isa SampleFromUniform ? init(dist) : rand(dist)
r = f.(vns, dists)
push!.(Ref(vi), vns, r, dists, Ref(spl))
end
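The `assume` and `get_and_set_val!` branches added above give samplers a way to force a fresh prior draw for a variable that already has a value: when the `"del"` flag is set, the stored value is discarded, the flag is cleared, and a new draw from the prior (or from `init` under `SampleFromUniform`) is written back. A minimal standalone sketch of that control flow, using a hypothetical `Dict`-based store in place of Turing's `VarInfo` (the `Toy*` names and helper signatures are illustrative, not Turing's API):

```julia
using Distributions, Random

# Hypothetical stand-ins for VarInfo and its flag machinery (not Turing's API).
struct ToyVarInfo
    vals::Dict{Symbol,Float64}
    flags::Dict{Symbol,Set{String}}
end
ToyVarInfo() = ToyVarInfo(Dict{Symbol,Float64}(), Dict{Symbol,Set{String}}())

is_flagged(vi::ToyVarInfo, vn::Symbol, flag::String) = flag in get(vi.flags, vn, Set{String}())
set_flag!(vi::ToyVarInfo, vn::Symbol, flag::String) = push!(get!(vi.flags, vn, Set{String}()), flag)
unset_flag!(vi::ToyVarInfo, vn::Symbol, flag::String) = delete!(get!(vi.flags, vn, Set{String}()), flag)

# Mirrors the branch added to `assume`: reuse the stored value unless "del" is set,
# in which case clear the flag and resample from the prior.
function toy_assume!(rng::AbstractRNG, vi::ToyVarInfo, vn::Symbol, dist::Distribution)
    if haskey(vi.vals, vn)
        if is_flagged(vi, vn, "del")
            unset_flag!(vi, vn, "del")
            vi.vals[vn] = rand(rng, dist)   # fresh draw from the prior
        end
    else
        vi.vals[vn] = rand(rng, dist)       # first visit: sample and store
    end
    return vi.vals[vn]
end

rng = MersenneTwister(1)
vi = ToyVarInfo()
toy_assume!(rng, vi, :m, Normal())   # initial draw, stored
toy_assume!(rng, vi, :m, Normal())   # reuses the stored value
set_flag!(vi, :m, "del")
toy_assume!(rng, vi, :m, Normal())   # "del" forces a new prior draw
```

This is the hook the new ESS sampler relies on: its `step!` below sets the flag on the sampled variable and re-runs the model to obtain the auxiliary prior draw ν.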
146 changes: 146 additions & 0 deletions src/inference/ess.jl
@@ -0,0 +1,146 @@
"""
ESS

Elliptical slice sampling algorithm.

# Examples
```jldoctest; setup = :(Random.seed!(1))
julia> @model gdemo(x) = begin
m ~ Normal()
x ~ Normal(m, 0.5)
end
gdemo (generic function with 2 methods)

julia> sample(gdemo(1.0), ESS(), 1_000) |> mean
Mean

│ Row │ parameters │ mean │
│ │ Symbol │ Float64 │
├─────┼────────────┼──────────┤
│ 1 │ m │ 0.824853 │
```
"""
struct ESS{space} <: InferenceAlgorithm end

ESS() = ESS{()}()
ESS(space::Symbol) = ESS{(space,)}()

mutable struct ESSState{V<:VarInfo} <: AbstractSamplerState
vi::V
end

function Sampler(alg::ESS, model::Model, s::Selector)
# sanity check
vi = VarInfo(model)
space = getspace(alg)
vns = _getvns(vi, s, Val(space))
length(vns) == 1 ||
error("[ESS] does only support one variable ($(length(vns)) variables specified)")
for vn in vns[1]
dist = getdist(vi, vn)
isgaussian(dist) ||
error("[ESS] only supports Gaussian prior distributions")
end

state = ESSState(vi)
info = Dict{Symbol, Any}()

return Sampler(alg, info, s, state)
end

isgaussian(dist) = false
isgaussian(::Normal) = true
isgaussian(::NormalCanon) = true
isgaussian(::AbstractMvNormal) = true

# always accept in the first step
function step!(::AbstractRNG, model::Model, spl::Sampler{<:ESS}, ::Integer; kwargs...)
return Transition(spl)
end

function step!(
rng::AbstractRNG,
model::Model,
spl::Sampler{<:ESS},
::Integer,
::Transition;
kwargs...
)
# obtain mean of distribution
vi = spl.state.vi
vns = _getvns(vi, spl)
μ = mapreduce(vcat, vns[1]) do vn
dist = getdist(vi, vn)
vectorize(dist, mean(dist))
end

# obtain previous sample
f = vi[spl]

# recompute log-likelihood in logp
if spl.selector.tag !== :default
runmodel!(model, vi, spl)
end

# sample log-likelihood threshold for the next sample
threshold = getlogp(vi) - randexp(rng)

# sample from the prior
set_flag!(vi, vns[1][1], "del")
runmodel!(model, vi, spl)
ν = vi[spl]

# sample initial angle
θ = 2 * π * rand(rng)
θmin = θ - 2 * π
θmax = θ

while true
# compute proposal and apply correction for distributions with nonzero mean
sinθ, cosθ = sincos(θ)
a = 1 - (sinθ + cosθ)
vi[spl] = @. f * cosθ + ν * sinθ + μ * a

# recompute log-likelihood and check if threshold is reached
runmodel!(model, vi, spl)
if getlogp(vi) > threshold
break
end

# shrink the bracket
if θ < 0
θmin = θ
else
θmax = θ
end

# sample new angle
θ = θmin + rand(rng) * (θmax - θmin)
end

return Transition(spl)
end

function tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn::VarName, inds, vi)
if vn in getspace(sampler)
return tilde(LikelihoodContext(), SampleFromPrior(), right, vn, inds, vi)
else
return tilde(ctx, SampleFromPrior(), right, vn, inds, vi)
end
end

function tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
return tilde(ctx, SampleFromPrior(), right, left, vi)
end

function dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vn::VarName, inds, vi)
if vn in getspace(sampler)
return dot_tilde(LikelihoodContext(), SampleFromPrior(), right, left, vn, inds, vi)
else
return dot_tilde(ctx, SampleFromPrior(), right, left, vn, inds, vi)
end
end

function dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
return dot_tilde(ctx, SampleFromPrior(), right, left, vi)
end
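For readers unfamiliar with the algorithm, the `step!` loop above follows standard elliptical slice sampling (Murray, Adams & MacKay, 2010): draw an auxiliary point ν from the Gaussian prior, fix a log-likelihood threshold, and shrink an angle bracket on the ellipse through the current state and ν until a proposal beats the threshold. A self-contained sketch for a zero-mean Gaussian prior (it omits the nonzero-mean correction `μ * a` used above; all names here are illustrative and independent of Turing's internals):

```julia
using Random

# Elliptical slice sampling step for a prior N(0, Σ) with Σ = L * L' and an
# arbitrary log-likelihood `loglik`; mirrors the bracket-shrinking loop in `step!`.
function ess_step(rng::AbstractRNG, f::AbstractVector, loglik, L::AbstractMatrix)
    ν = L * randn(rng, length(f))             # auxiliary draw from the prior
    threshold = loglik(f) - randexp(rng)      # as in `getlogp(vi) - randexp(rng)`

    θ = 2π * rand(rng)                        # initial angle and bracket
    θmin, θmax = θ - 2π, θ

    while true
        candidate = f .* cos(θ) .+ ν .* sin(θ)
        loglik(candidate) > threshold && return candidate

        # shrink the bracket towards the current state and resample the angle
        θ < 0 ? (θmin = θ) : (θmax = θ)
        θ = θmin + rand(rng) * (θmax - θmin)
    end
end

# Example: prior m ~ N(0, 1), observation 1.0 ~ Normal(m, 0.5).
loglik(m) = -2 * (1.0 - m[1])^2               # log-likelihood up to a constant
rng = MersenneTwister(1)
samples = [zeros(1)]
for _ in 1:1_000
    push!(samples, ess_step(rng, samples[end], loglik, ones(1, 1)))
end
```

For this conjugate example the posterior mean is 0.8, so the average of `samples` should land close to it, consistent with the docstring example and the numerical test below.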
4 changes: 2 additions & 2 deletions src/inference/gibbs.jl
@@ -2,7 +2,7 @@
### Gibbs samplers / compositional samplers.
###

const GibbsComponent = Union{Hamiltonian,MH,PG}
const GibbsComponent = Union{Hamiltonian,MH,ESS,PG}

"""
Gibbs(algs...)
@@ -150,7 +150,7 @@ function step!(
# Uncomment when developing thinning functionality.
# Retrieve symbol to store this subsample.
# symbol_id = Symbol(local_spl.selector.gid)

# # Store the subsample.
# spl.state.subsamples[symbol_id][] = trans

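Adding `ESS` to `GibbsComponent` lets it be combined with other samplers inside `Gibbs`, so ESS can handle the Gaussian-prior variables while another component updates the rest. A minimal sketch of the combination exercised in the tests below, assuming a `gdemo`-style model (hypothetical here) with an InverseGamma prior on `s` and a Normal prior on `m`:

```julia
using Turing

@model gdemo(x, y) = begin
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
    y ~ Normal(m, sqrt(s))
end

# ESS updates `m` (Gaussian prior); MH updates `s` (non-Gaussian prior).
chain = sample(gdemo(1.5, 2.0), Gibbs(ESS(:m), MH(:s)), 2_000)
```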
58 changes: 58 additions & 0 deletions test/inference/ess.jl
@@ -0,0 +1,58 @@
using Turing, Random, Test

dir = splitdir(splitdir(pathof(Turing))[1])[1]
include(dir*"/test/test_utils/AllUtils.jl")

@testset "ESS" begin
@model demo(x) = begin
m ~ Normal()
x ~ Normal(m, 0.5)
end
demo_default = demo(1.0)

@model demodot(x) = begin
m = Vector{Float64}(undef, 2)
@. m ~ Normal()
x ~ Normal(m[2], 0.5)
end
demodot_default = demodot(1.0)

@turing_testset "ESS constructor" begin
Random.seed!(0)
N = 500
s1 = ESS()
s2 = ESS(:m)
s3 = Gibbs(ESS(:m), MH(:s))

c1 = sample(demo_default, s1, N)
c2 = sample(demo_default, s2, N)
c3 = sample(demodot_default, s1, N)
c4 = sample(demodot_default, s2, N)
c5 = sample(gdemo_default, s3, N)
end

@numerical_testset "ESS inference" begin
Random.seed!(1)
chain = sample(demo_default, ESS(), 5_000)
check_numerical(chain, [:m], [0.8], atol = 0.1)

Random.seed!(1)
chain = sample(demodot_default, ESS(), 5_000)
check_numerical(chain, ["m[1]", "m[2]"], [0.0, 0.8], atol = 0.1)

Random.seed!(100)
alg = Gibbs(
CSMC(15, :s),
ESS(:m))
chain = sample(gdemo(1.5, 2.0), alg, 10_000)
check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1)

# MoGtest
Random.seed!(125)
alg = Gibbs(
CSMC(15, :z1, :z2, :z3, :z4),
ESS(:mu1), ESS(:mu2))
chain = sample(MoGtest_default, alg, 6000)
check_MoGtest_default(chain, atol = 0.1)
end
end