diff --git a/src/Turing.jl b/src/Turing.jl index 2876a29488..0652ff8920 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -214,6 +214,7 @@ export @model, # modelling @sampler, MH, # classic sampling + ESS, Gibbs, HMC, # Hamiltonian-like sampling diff --git a/src/core/RandomVariables.jl b/src/core/RandomVariables.jl index e3efe74030..af4fdae6fa 100644 --- a/src/core/RandomVariables.jl +++ b/src/core/RandomVariables.jl @@ -34,6 +34,7 @@ export VarName, resetlogp!, set_retained_vns_del_by_spl!, is_flagged, + set_flag!, unset_flag!, setgid!, updategid!, @@ -495,11 +496,10 @@ end end # Get all vns of variables belonging to spl -_getvns(vi::UntypedVarInfo, spl::AbstractSampler) = view(vi.metadata.vns, _getidcs(vi, spl)) -function _getvns(vi::TypedVarInfo, spl::AbstractSampler) - # Get a NamedTuple of the indices of variables belonging to `spl`, one entry for each symbol - idcs = _getidcs(vi, spl) - return _getvns(vi.metadata, idcs) +_getvns(vi::AbstractVarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl))) +_getvns(vi::UntypedVarInfo, s::Selector, space) = view(vi.metadata.vns, _getidcs(vi, s, space)) +function _getvns(vi::TypedVarInfo, s::Selector, space) + return _getvns(vi.metadata, _getidcs(vi, s, space)) end # Get a NamedTuple for all the `vns` of indices `idcs`, one entry for each symbol @generated function _getvns(metadata, idcs::NamedTuple{names}) where {names} @@ -525,7 +525,7 @@ end #end end # Get the index (in vals) ranges of all the vns of variables belonging to selector `s` in `space` -@inline function _getranges(vi::AbstractVarInfo, s::Selector, space = Val(())) +@inline function _getranges(vi::AbstractVarInfo, s::Selector, space) return _getranges(vi, _getidcs(vi, s, space)) end @inline function _getranges(vi::UntypedVarInfo, idcs::Vector{Int}) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 39d5fefae8..52a0e08b8b 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -3,7 +3,7 @@ module Inference using ..Core, ..Core.RandomVariables, ..Utilities using ..Core.RandomVariables: Metadata, _tail, VarInfo, TypedVarInfo, islinked, invlink!, getlogp, tonamedtuple, VarName, getsym, vectorize, - settrans! + settrans!, _getvns, getdist using ..Core: split_var_str using Distributions, Libtask, Bijectors using ProgressMeter, LinearAlgebra @@ -13,7 +13,7 @@ using ..Turing: Model, runmodel!, Turing, Selector, AbstractSamplerState, DefaultContext, PriorContext, LikelihoodContext, MiniBatchContext, NamedDist, NoDist using StatsFuns: logsumexp -using Random: GLOBAL_RNG, AbstractRNG +using Random: GLOBAL_RNG, AbstractRNG, randexp using AbstractMCMC import MCMCChains: Chains @@ -33,6 +33,7 @@ export InferenceAlgorithm, SampleFromUniform, SampleFromPrior, MH, + ESS, Gibbs, # classic sampling HMC, SGLD, @@ -274,8 +275,8 @@ function _params_to_array(ts::Vector{T}, spl::Sampler) where {T<:AbstractTransit end push!(dicts, d) end - - # Convert the set to an ordered vector so the parameter ordering + + # Convert the set to an ordered vector so the parameter ordering # is deterministic. ordered_names = collect(names) vals = Matrix{Union{Real, Missing}}(undef, length(ts), length(ordered_names)) @@ -486,6 +487,7 @@ end # Concrete algorithm implementations. # ####################################### +include("ess.jl") include("hmc.jl") include("mh.jl") include("is.jl") @@ -498,7 +500,7 @@ include("../contrib/inference/AdvancedSMCExtensions.jl") # Typing tools # ################ -for alg in (:SMC, :PG, :PMMH, :IPMCMC, :MH, :IS, :Gibbs) +for alg in (:SMC, :PG, :PMMH, :IPMCMC, :MH, :IS, :ESS, :Gibbs) @eval getspace(::$alg{space}) where {space} = space end for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC) @@ -635,7 +637,14 @@ function assume( vi::VarInfo, ) if haskey(vi, vn) + if is_flagged(vi, vn, "del") + unset_flag!(vi, vn, "del") + r = spl isa SampleFromUniform ? init(dist) : rand(dist) + vi[vn] = vectorize(dist, r) + setorder!(vi, vn, vi.num_produce) + else r = vi[vn] + end else r = isa(spl, SampleFromUniform) ? init(dist) : rand(dist) push!(vi, vn, r, dist, spl) @@ -792,9 +801,19 @@ function get_and_set_val!( ) n = length(vns) if haskey(vi, vns[1]) + if is_flagged(vi, vns[1], "del") + unset_flag!(vi, vns[1], "del") + r = spl isa SampleFromUniform ? init(dist, n) : rand(dist, n) + for i in 1:n + vn = vns[i] + vi[vn] = vectorize(dist, r[:, i]) + setorder!(vi, vn, vi.num_produce) + end + else r = vi[vns] + end else - r = isa(spl, SampleFromUniform) ? init(dist, n) : rand(dist, n) + r = spl isa SampleFromUniform ? init(dist, n) : rand(dist, n) for i in 1:n push!(vi, vns[i], r[:,i], dist, spl) end @@ -808,9 +827,21 @@ function get_and_set_val!( spl::AbstractSampler, ) if haskey(vi, vns[1]) + if is_flagged(vi, vns[1], "del") + unset_flag!(vi, vns[1], "del") + f = (vn, dist) -> spl isa SampleFromUniform ? init(dist) : rand(dist) + r = f.(vns, dists) + for i in eachindex(vns) + vn = vns[i] + dist = dists isa AbstractArray ? dists[i] : dists + vi[vn] = vectorize(dist, r[i]) + setorder!(vi, vn, vi.num_produce) + end + else r = reshape(vi[vec(vns)], size(vns)) + end else - f(vn, dist) = isa(spl, SampleFromUniform) ? init(dist) : rand(dist) + f = (vn, dist) -> spl isa SampleFromUniform ? init(dist) : rand(dist) r = f.(vns, dists) push!.(Ref(vi), vns, r, dists, Ref(spl)) end diff --git a/src/inference/ess.jl b/src/inference/ess.jl new file mode 100644 index 0000000000..307b5e2d90 --- /dev/null +++ b/src/inference/ess.jl @@ -0,0 +1,146 @@ +""" + ESS + +Elliptical slice sampling algorithm. + +# Examples +```jldoctest; setup = :(Random.seed!(1)) +julia> @model gdemo(x) = begin + m ~ Normal() + x ~ Normal(m, 0.5) + end +gdemo (generic function with 2 methods) + +julia> sample(gdemo(1.0), ESS(), 1_000) |> mean +Mean + +│ Row │ parameters │ mean │ +│ │ Symbol │ Float64 │ +├─────┼────────────┼──────────┤ +│ 1 │ m │ 0.824853 │ +``` +""" +struct ESS{space} <: InferenceAlgorithm end + +ESS() = ESS{()}() +ESS(space::Symbol) = ESS{(space,)}() + +mutable struct ESSState{V<:VarInfo} <: AbstractSamplerState + vi::V +end + +function Sampler(alg::ESS, model::Model, s::Selector) + # sanity check + vi = VarInfo(model) + space = getspace(alg) + vns = _getvns(vi, s, Val(space)) + length(vns) == 1 || + error("[ESS] does only support one variable ($(length(vns)) variables specified)") + for vn in vns[1] + dist = getdist(vi, vn) + isgaussian(dist) || + error("[ESS] only supports Gaussian prior distributions") + end + + state = ESSState(vi) + info = Dict{Symbol, Any}() + + return Sampler(alg, info, s, state) +end + +isgaussian(dist) = false +isgaussian(::Normal) = true +isgaussian(::NormalCanon) = true +isgaussian(::AbstractMvNormal) = true + +# always accept in the first step +function step!(::AbstractRNG, model::Model, spl::Sampler{<:ESS}, ::Integer; kwargs...) + return Transition(spl) +end + +function step!( + rng::AbstractRNG, + model::Model, + spl::Sampler{<:ESS}, + ::Integer, + ::Transition; + kwargs... +) + # obtain mean of distribution + vi = spl.state.vi + vns = _getvns(vi, spl) + μ = mapreduce(vcat, vns[1]) do vn + dist = getdist(vi, vn) + vectorize(dist, mean(dist)) + end + + # obtain previous sample + f = vi[spl] + + # recompute log-likelihood in logp + if spl.selector.tag !== :default + runmodel!(model, vi, spl) + end + + # sample log-likelihood threshold for the next sample + threshold = getlogp(vi) - randexp(rng) + + # sample from the prior + set_flag!(vi, vns[1][1], "del") + runmodel!(model, vi, spl) + ν = vi[spl] + + # sample initial angle + θ = 2 * π * rand(rng) + θmin = θ - 2 * π + θmax = θ + + while true + # compute proposal and apply correction for distributions with nonzero mean + sinθ, cosθ = sincos(θ) + a = 1 - (sinθ + cosθ) + vi[spl] = @. f * cosθ + ν * sinθ + μ * a + + # recompute log-likelihood and check if threshold is reached + runmodel!(model, vi, spl) + if getlogp(vi) > threshold + break + end + + # shrink the bracket + if θ < 0 + θmin = θ + else + θmax = θ + end + + # sample new angle + θ = θmin + rand(rng) * (θmax - θmin) + end + + return Transition(spl) +end + +function tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn::VarName, inds, vi) + if vn in getspace(sampler) + return tilde(LikelihoodContext(), SampleFromPrior(), right, vn, inds, vi) + else + return tilde(ctx, SampleFromPrior(), right, vn, inds, vi) + end +end + +function tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) + return tilde(ctx, SampleFromPrior(), right, left, vi) +end + +function dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vn::VarName, inds, vi) + if vn in getspace(sampler) + return dot_tilde(LikelihoodContext(), SampleFromPrior(), right, left, vn, inds, vi) + else + return dot_tilde(ctx, SampleFromPrior(), right, left, vn, inds, vi) + end +end + +function dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) + return dot_tilde(ctx, SampleFromPrior(), right, left, vi) +end \ No newline at end of file diff --git a/src/inference/gibbs.jl b/src/inference/gibbs.jl index 7a3b62193c..e684386149 100644 --- a/src/inference/gibbs.jl +++ b/src/inference/gibbs.jl @@ -2,7 +2,7 @@ ### Gibbs samplers / compositional samplers. ### -const GibbsComponent = Union{Hamiltonian,MH,PG} +const GibbsComponent = Union{Hamiltonian,MH,ESS,PG} """ Gibbs(algs...) @@ -150,7 +150,7 @@ function step!( # Uncomment when developing thinning functionality. # Retrieve symbol to store this subsample. # symbol_id = Symbol(local_spl.selector.gid) - + # # Store the subsample. # spl.state.subsamples[symbol_id][] = trans diff --git a/test/inference/ess.jl b/test/inference/ess.jl new file mode 100644 index 0000000000..c61cf5c6be --- /dev/null +++ b/test/inference/ess.jl @@ -0,0 +1,58 @@ +using Turing, Random, Test + +dir = splitdir(splitdir(pathof(Turing))[1])[1] +include(dir*"/test/test_utils/AllUtils.jl") + +@testset "ESS" begin + @model demo(x) = begin + m ~ Normal() + x ~ Normal(m, 0.5) + end + demo_default = demo(1.0) + + @model demodot(x) = begin + m = Vector{Float64}(undef, 2) + @. m ~ Normal() + x ~ Normal(m[2], 0.5) + end + demodot_default = demodot(1.0) + + @turing_testset "ESS constructor" begin + Random.seed!(0) + N = 500 + s1 = ESS() + s2 = ESS(:m) + s3 = Gibbs(ESS(:m), MH(:s)) + + c1 = sample(demo_default, s1, N) + c2 = sample(demo_default, s2, N) + c3 = sample(demodot_default, s1, N) + c4 = sample(demodot_default, s2, N) + c5 = sample(gdemo_default, s3, N) + end + + @numerical_testset "ESS inference" begin + Random.seed!(1) + chain = sample(demo_default, ESS(), 5_000) + check_numerical(chain, [:m], [0.8], atol = 0.1) + + Random.seed!(1) + chain = sample(demodot_default, ESS(), 5_000) + check_numerical(chain, ["m[1]", "m[2]"], [0.0, 0.8], atol = 0.1) + + Random.seed!(100) + alg = Gibbs( + CSMC(15, :s), + ESS(:m)) + chain = sample(gdemo(1.5, 2.0), alg, 10_000) + check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) + + # MoGtest + Random.seed!(125) + alg = Gibbs( + CSMC(15, :z1, :z2, :z3, :z4), + ESS(:mu1), ESS(:mu2)) + chain = sample(MoGtest_default, alg, 6000) + check_MoGtest_default(chain, atol = 0.1) + end +end diff --git a/test/inference/gibbs.jl b/test/inference/gibbs.jl index b63287767d..cb03ef50b8 100644 --- a/test/inference/gibbs.jl +++ b/test/inference/gibbs.jl @@ -11,6 +11,7 @@ include(dir*"/test/test_utils/AllUtils.jl") s3 = Gibbs(PG(3, :s), HMC( 0.4, 8, :m)) s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m)) s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m)) + s6 = Gibbs(HMC(0.1, 5, :s), ESS(:m)) c1 = sample(gdemo_default, s1, N) @@ -18,6 +19,7 @@ include(dir*"/test/test_utils/AllUtils.jl") c3 = sample(gdemo_default, s3, N) c4 = sample(gdemo_default, s4, N) c5 = sample(gdemo_default, s5, N) + c6 = sample(gdemo_default, s6, N) # Test gid of each samplers g = Turing.Sampler(s3, gdemo_default) @@ -34,6 +36,13 @@ include(dir*"/test/test_utils/AllUtils.jl") chain = sample(gdemo(1.5, 2.0), alg, 3000) check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) + Random.seed!(100) + alg = Gibbs( + CSMC(15, :s), + ESS(:m)) + chain = sample(gdemo(1.5, 2.0), alg, 10_000) + check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) + alg = CSMC(10) chain = sample(gdemo(1.5, 2.0), alg, 5000) check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) @@ -48,5 +57,12 @@ include(dir*"/test/test_utils/AllUtils.jl") check_MoGtest_default(chain, atol = 0.1) setadsafe(false) + + Random.seed!(200) + gibbs = Gibbs( + PG(10, :z1, :z2, :z3, :z4), + ESS(:mu1), ESS(:mu2)) + chain = sample(MoGtest_default, gibbs, 1500) + check_MoGtest_default(chain, atol = 0.1) end end diff --git a/test/runtests.jl b/test/runtests.jl index 8d29e8deb8..be044d0b63 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -29,6 +29,7 @@ include("test_utils/AllUtils.jl") include("inference/hmc.jl") include("inference/is.jl") include("inference/mh.jl") + include("inference/ess.jl") include("inference/AdvancedSMC.jl") include("inference/Inference.jl") end