From 584d78a53004dab06d9237c0feafa2b54e78133d Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 23 Nov 2019 20:12:49 +0100 Subject: [PATCH 01/12] Add elliptical slice sampling algorithm --- src/Turing.jl | 1 + src/inference/Inference.jl | 12 ++-- src/inference/ess.jl | 132 +++++++++++++++++++++++++++++++++++++ src/inference/gibbs.jl | 4 +- test/inference/ess.jl | 44 +++++++++++++ test/inference/gibbs.jl | 9 +++ test/runtests.jl | 1 + 7 files changed, 196 insertions(+), 7 deletions(-) create mode 100644 src/inference/ess.jl create mode 100644 test/inference/ess.jl diff --git a/src/Turing.jl b/src/Turing.jl index 7c2816d6cd..67c358e006 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -158,6 +158,7 @@ export @model, # modelling @VarName, MH, # classic sampling + ESS, Gibbs, HMC, # Hamiltonian-like sampling diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index c06390fc60..c033cc2c61 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -1,7 +1,7 @@ module Inference using ..Core, ..Core.RandomVariables, ..Utilities -using ..Core.RandomVariables: Metadata, _tail, TypedVarInfo, +using ..Core.RandomVariables: Metadata, _tail, TypedVarInfo, islinked, invlink!, getlogp, tonamedtuple using Distributions, Libtask, Bijectors using ProgressMeter, LinearAlgebra @@ -11,7 +11,7 @@ using ..Turing: Model, runmodel!, get_pvars, get_dvars, Selector, AbstractSamplerState using ..Turing: in_pvars, in_dvars, Turing using StatsFuns: logsumexp -using Random: GLOBAL_RNG, AbstractRNG +using Random: GLOBAL_RNG, AbstractRNG, randexp using ..Turing.Interface import MCMCChains: Chains @@ -32,6 +32,7 @@ export InferenceAlgorithm, SampleFromUniform, SampleFromPrior, MH, + ESS, Gibbs, # classic sampling HMC, SGLD, @@ -273,8 +274,8 @@ function _params_to_array(ts::Vector{T}, spl::Sampler) where {T<:AbstractTransit end push!(dicts, d) end - - # Convert the set to an ordered vector so the parameter ordering + + # Convert the set to an ordered vector so the parameter ordering # is deterministic. ordered_names = collect(names) vals = Matrix{Union{Real, Missing}}(undef, length(ts), length(ordered_names)) @@ -514,6 +515,7 @@ end # Concrete algorithm implementations. # ####################################### +include("ess.jl") include("hmc.jl") include("mh.jl") include("is.jl") @@ -526,7 +528,7 @@ include("../contrib/inference/AdvancedSMCExtensions.jl") # Typing tools # ################ -for alg in (:SMC, :PG, :PMMH, :IPMCMC, :MH, :IS, :Gibbs) +for alg in (:SMC, :PG, :PMMH, :IPMCMC, :MH, :IS, :ESS, :Gibbs) @eval getspace(::$alg{space}) where {space} = space end for alg in (:HMC, :HMCDA, :NUTS, :SGLD, :SGHMC) diff --git a/src/inference/ess.jl b/src/inference/ess.jl new file mode 100644 index 0000000000..7251012724 --- /dev/null +++ b/src/inference/ess.jl @@ -0,0 +1,132 @@ +""" + ESS + +Elliptical slice sampling algorithm. 
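+Elliptical slice sampling (Murray, Adams & MacKay, 2010) is a gradient-free MCMC
+method for models with Gaussian priors and arbitrary likelihoods; it has no tuning
+parameters, and every transition returns an accepted sample.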
+ +# Examples +```jldoctest; setup = :(Random.seed!(1)) +julia> @model gdemo(x) = begin + m ~ Normal() + x ~ Normal(m, 0.5) + end +gdemo (generic function with 2 methods) + +julia> sample(gdemo(1.0), ESS(), 1_000) |> mean +Mean + +│ Row │ parameters │ mean │ +│ │ Symbol │ Float64 │ +├─────┼────────────┼──────────┤ +│ 1 │ m │ 0.811555 │ +``` +""" +struct ESS{space} <: InferenceAlgorithm end + +ESS() = ESS{()}() +ESS(space::Symbol) = ESS{(space,)}() + +mutable struct ESSState{V<:VarInfo} <: AbstractSamplerState + vi::V +end + +ESSState(model::Model) = ESSState(VarInfo(model)) + +function Sampler(alg::ESS, model::Model, s::Selector) + # sanity check + space = getspace(alg) + if isempty(space) + pvars = get_pvars(model) + length(pvars) == 1 || + error("[ESS] no symbol specified to sampler although there is not exactly one model parameter ($pvars)") + end + + state = ESSState(model) + info = Dict{Symbol, Any}() + + return Sampler(alg, info, s, state) +end + +# always accept in the first step +function step!(::AbstractRNG, ::Model, spl::Sampler{<:ESS}, ::Integer; kwargs...) + return Transition(spl) +end + +function step!( + rng::AbstractRNG, + model::Model, + spl::Sampler{<:ESS}, + ::Integer, + ::Transition; + kwargs... +) + # recompute log-likelihood in logp + vi = spl.state.vi + if spl.selector.tag !== :default + runmodel!(model, vi, spl) + end + + # obtain previous sample + f = copy(vi[spl]) + + # sample log-likelihood threshold for the next sample + threshold = getlogp(vi) - randexp(rng) + + # sample from the prior + runmodel!(model, vi, spl) + ν = copy(vi[spl]) + + # sample initial angle + θ = 2 * π * rand(rng) + θₘᵢₙ = θ - 2 * π + θₘₐₓ = θ + + while true + # compute proposal + sinθ, cosθ = sincos(θ) + @. vi[spl] = f * cosθ + ν * sinθ + + # recompute log-likelihood and check if threshold is reached + resetlogp!(vi) + model(vi, spl) + if getlogp(vi) > threshold + break + end + + # shrink the bracket + if θ < 0 + θₘᵢₙ = θ + else + θₘₐₓ = θ + end + + # sample new angle + θ = θₘᵢₙ + rand(rng) * (θₘₐₓ - θₘᵢₙ) + end + + return Transition(spl) +end + +isnormal(dist) = false +isnormal(::Normal) = true +isnormal(::NormalCanon) = true +isnormal(::AbstractMvNormal) = true + +function assume(spl::Sampler{<:ESS}, dist::Distribution, vn::VarName, vi::VarInfo) + space = getspace(spl) + if space === () || space === (vn.sym,) + isnormal(dist) || + error("[ESS] does only support normally distributed prior distributions") + + r = rand(dist) + vi[vn] = vectorize(dist, r) + setgid!(vi, spl.selector, vn) + return r, zero(Base.promote_eltype(dist, r)) + else + r = vi[vn] + return r, logpdf(dist, r) + end +end + +function observe(spl::Sampler{<:ESS}, dist::Distribution, value, vi::VarInfo) + return observe(dist, value, vi) +end diff --git a/src/inference/gibbs.jl b/src/inference/gibbs.jl index 155492327b..05bc202297 100644 --- a/src/inference/gibbs.jl +++ b/src/inference/gibbs.jl @@ -2,7 +2,7 @@ ### Gibbs samplers / compositional samplers. ### -const GibbsComponent = Union{Hamiltonian,MH,PG} +const GibbsComponent = Union{Hamiltonian,MH,ESS,PG} """ Gibbs(algs...) @@ -157,7 +157,7 @@ function step!( # Uncomment when developing thinning functionality. # Retrieve symbol to store this subsample. # symbol_id = Symbol(local_spl.selector.gid) - + # # Store the subsample. 
# spl.state.subsamples[symbol_id][] = trans diff --git a/test/inference/ess.jl b/test/inference/ess.jl new file mode 100644 index 0000000000..ac2d95383e --- /dev/null +++ b/test/inference/ess.jl @@ -0,0 +1,44 @@ +using Turing, Random, Test + +dir = splitdir(splitdir(pathof(Turing))[1])[1] +include(dir*"/test/test_utils/AllUtils.jl") + +@testset "ESS" begin + @model demo(x) = begin + m ~ Normal() + x ~ Normal(m, 0.5) + end + demo_default = demo(1.0) + + @turing_testset "ESS constructor" begin + Random.seed!(0) + N = 500 + s1 = ESS() + s2 = ESS(:m) + s3 = Gibbs(ESS(:m), MH(:s)) + + c1 = sample(demo_default, s1, N) + c2 = sample(demo_default, s2, N) + c3 = sample(gdemo_default, s3, N) + end + + @numerical_testset "ESS inference" begin + Random.seed!(1) + alg = ESS() + chain = sample(demo_default, alg, 5_000) + check_numerical(chain, [:m], [0.8], atol = 0.1) + + Random.seed!(100) + alg = Gibbs(CSMC(15, :s), ESS(:m)) + chain = sample(gdemo_default, alg, 5_000) + @test_broken check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) + + # MoGtest + Random.seed!(125) + gibbs = Gibbs( + CSMC(15, :z1, :z2, :z3, :z4), + ESS(:mu1), ESS(:mu2)) + chain = sample(MoGtest_default, gibbs, 6000) + check_MoGtest_default(chain, atol = 0.1) + end +end diff --git a/test/inference/gibbs.jl b/test/inference/gibbs.jl index b63287767d..e908cc2742 100644 --- a/test/inference/gibbs.jl +++ b/test/inference/gibbs.jl @@ -11,6 +11,7 @@ include(dir*"/test/test_utils/AllUtils.jl") s3 = Gibbs(PG(3, :s), HMC( 0.4, 8, :m)) s4 = Gibbs(PG(3, :s), HMC(0.4, 8, :m)) s5 = Gibbs(CSMC(3, :s), HMC(0.4, 8, :m)) + s6 = Gibbs(HMC(0.1, 5, :s), ESS(:m)) c1 = sample(gdemo_default, s1, N) @@ -18,6 +19,7 @@ include(dir*"/test/test_utils/AllUtils.jl") c3 = sample(gdemo_default, s3, N) c4 = sample(gdemo_default, s4, N) c5 = sample(gdemo_default, s5, N) + c6 = sample(gdemo_default, s6, N) # Test gid of each samplers g = Turing.Sampler(s3, gdemo_default) @@ -48,5 +50,12 @@ include(dir*"/test/test_utils/AllUtils.jl") check_MoGtest_default(chain, atol = 0.1) setadsafe(false) + + Random.seed!(200) + gibbs = Gibbs( + PG(10, :z1, :z2, :z3, :z4), + ESS(:mu1), ESS(:mu2)) + chain = sample(MoGtest_default, gibbs, 1500) + check_MoGtest_default(chain, atol = 0.1) end end diff --git a/test/runtests.jl b/test/runtests.jl index 876290f7b6..0c5a437852 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,6 +28,7 @@ include("test_utils/AllUtils.jl") include("inference/hmc.jl") include("inference/is.jl") include("inference/mh.jl") + include("inference/ess.jl") include("inference/AdvancedSMC.jl") include("inference/Inference.jl") end From 2de7d8b0f1b41b397022c6784a852a202b744d72 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sat, 7 Dec 2019 07:14:41 +0100 Subject: [PATCH 02/12] Allow nonzero mean and update according to comments --- src/inference/Inference.jl | 2 +- src/inference/ess.jl | 13 +++++++++++-- test/inference/ess.jl | 4 ++-- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index c033cc2c61..4acf42f255 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -2,7 +2,7 @@ module Inference using ..Core, ..Core.RandomVariables, ..Utilities using ..Core.RandomVariables: Metadata, _tail, TypedVarInfo, - islinked, invlink!, getlogp, tonamedtuple + islinked, invlink!, getlogp, tonamedtuple, _getvns, getdist using Distributions, Libtask, Bijectors using ProgressMeter, LinearAlgebra using ..Turing: PROGRESS, CACHERESET, AbstractSampler diff --git 
a/src/inference/ess.jl b/src/inference/ess.jl index 7251012724..5de51dc96a 100644 --- a/src/inference/ess.jl +++ b/src/inference/ess.jl @@ -65,6 +65,12 @@ function step!( runmodel!(model, vi, spl) end + # obtain mean of distribution + vns = _getvns(vi, spl) + length(vns) == 1 || error("[ESS] does only support one parameter") + dist = getdist(vi, vns[1][1]) + μ = vectorize(dist, mean(dist)) + # obtain previous sample f = copy(vi[spl]) @@ -85,9 +91,12 @@ function step!( sinθ, cosθ = sincos(θ) @. vi[spl] = f * cosθ + ν * sinθ + # apply correction for distributions with nonzero mean + a = 1 - (sinθ + cosθ) + @. vi[spl] += μ * a + # recompute log-likelihood and check if threshold is reached - resetlogp!(vi) - model(vi, spl) + runmodel!(model, vi, spl) if getlogp(vi) > threshold break end diff --git a/test/inference/ess.jl b/test/inference/ess.jl index ac2d95383e..c20c82ed13 100644 --- a/test/inference/ess.jl +++ b/test/inference/ess.jl @@ -29,9 +29,9 @@ include(dir*"/test/test_utils/AllUtils.jl") check_numerical(chain, [:m], [0.8], atol = 0.1) Random.seed!(100) - alg = Gibbs(CSMC(15, :s), ESS(:m)) + alg = Gibbs(CSMC(50, :s), ESS(:m)) chain = sample(gdemo_default, alg, 5_000) - @test_broken check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) + check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) # MoGtest Random.seed!(125) From 82a04ffabcc8b7de6aee79358d734d78b9b2f733 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 6 Dec 2019 23:05:27 -0800 Subject: [PATCH 03/12] Update error message Co-Authored-By: Cameron Pfiffer --- src/inference/ess.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/inference/ess.jl b/src/inference/ess.jl index 5de51dc96a..c706f1818a 100644 --- a/src/inference/ess.jl +++ b/src/inference/ess.jl @@ -124,7 +124,7 @@ function assume(spl::Sampler{<:ESS}, dist::Distribution, vn::VarName, vi::VarInf space = getspace(spl) if space === () || space === (vn.sym,) isnormal(dist) || - error("[ESS] does only support normally distributed prior distributions") + error("[ESS] only supports normally distributed prior distributions") r = rand(dist) vi[vn] = vectorize(dist, r) From bb4d1ea24453d63bc221e4f84d6d7066df4e17c2 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 19 Dec 2019 18:59:41 +0100 Subject: [PATCH 04/12] Add _getvns methods and remove static parameters --- src/core/RandomVariables.jl | 35 +++++++++++++++-------------------- 1 file changed, 15 insertions(+), 20 deletions(-) diff --git a/src/core/RandomVariables.jl b/src/core/RandomVariables.jl index 513dd2bffb..2a4f0c1f33 100644 --- a/src/core/RandomVariables.jl +++ b/src/core/RandomVariables.jl @@ -186,7 +186,7 @@ function VarInfo(old_vi::UntypedVarInfo, spl, x::AbstractVector) return new_vi end function VarInfo(old_vi::TypedVarInfo, spl, x::AbstractVector) - md = newmetadata(old_vi.metadata, getspaceval(spl), x) + md = newmetadata(old_vi.metadata, Val(getspace(spl)), x) VarInfo(md, Base.RefValue{eltype(x)}(old_vi.logp), Ref(old_vi.num_produce)) end @generated function newmetadata(metadata::NamedTuple{names}, ::Val{space}, x) where {names, space} @@ -446,8 +446,6 @@ Returns a tuple of the unique symbols of random variables sampled in `vi`. 
syms(vi::UntypedVarInfo) = Tuple(unique!(map(vn -> vn.sym, vi.vns))) # get all symbols syms(vi::TypedVarInfo) = keys(vi.metadata) -getspaceval(alg) = Val(getspace(alg)) - # Get all indices of variables belonging to SampleFromPrior: # if the gid/selector of a var is an empty Set, then that var is assumed to be assigned to # the SampleFromPrior sampler @@ -479,16 +477,14 @@ end # spl.info[:idcs] #else #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHEIDCS - idcs = _getidcs(vi, spl.selector, getspaceval(spl.alg)) + idcs = _getidcs(vi, spl.selector, Val(getspace(spl))) #spl.info[:idcs] = idcs #end return idcs end -@inline function _getidcs(vi::UntypedVarInfo, s::Selector, ::Val{space}) where {space} - findinds(vi, s, Val(space)) -end -@inline function _getidcs(vi::TypedVarInfo, s::Selector, ::Val{space}) where {space} - return _getidcs(vi.metadata, s, Val(space)) +@inline _getidcs(vi::UntypedVarInfo, s::Selector, space::Val) = findinds(vi, s, space) +@inline function _getidcs(vi::TypedVarInfo, s::Selector, space::Val) + return _getidcs(vi.metadata, s, space) end # Get a NamedTuple for all the indices belonging to a given selector for each symbol @generated function _getidcs(metadata::NamedTuple{names}, s::Selector, ::Val{space}) where {names, space} @@ -517,11 +513,10 @@ end end # Get all vns of variables belonging to spl -_getvns(vi::UntypedVarInfo, spl::AbstractSampler) = view(vi.vns, _getidcs(vi, spl)) -function _getvns(vi::TypedVarInfo, spl::AbstractSampler) - # Get a NamedTuple of the indices of variables belonging to `spl`, one entry for each symbol - idcs = _getidcs(vi, spl) - return _getvns(vi.metadata, idcs) +_getvns(vi::AbstractVarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl))) +_getvns(vi::UntypedVarInfo, s::Selector, space::Val) = view(vi.vns, _getidcs(vi, s, space)) +function _getvns(vi::TypedVarInfo, s::Selector, space::Val) + return _getvns(vi.metadata, _getidcs(vi, s, space)) end # Get a NamedTuple for all the `vns` of indices `idcs`, one entry for each symbol @generated function _getvns(metadata, idcs::NamedTuple{names}) where {names} @@ -541,14 +536,14 @@ end # spl.info[:ranges] #else #spl.info[:cache_updated] = spl.info[:cache_updated] | CACHERANGES - ranges = _getranges(vi, spl.selector, getspaceval(spl.alg)) + ranges = _getranges(vi, spl.selector, Val(getspace(spl))) #spl.info[:ranges] = ranges return ranges #end end # Get the index (in vals) ranges of all the vns of variables belonging to selector `s` in `space` -@inline function _getranges(vi::AbstractVarInfo, s::Selector, ::Val{space}=Val(())) where {space} - _getranges(vi, _getidcs(vi, s, Val(space))) +@inline function _getranges(vi::AbstractVarInfo, s::Selector, space::Val) + _getranges(vi, _getidcs(vi, s, space)) end @inline function _getranges(vi::UntypedVarInfo, idcs::Vector{Int}) mapreduce(i -> vi.ranges[i], vcat, idcs, init=Int[]) @@ -878,7 +873,7 @@ function link!(vi::UntypedVarInfo, spl::Sampler) end function link!(vi::TypedVarInfo, spl::Sampler) vns = _getvns(vi, spl) - return _link!(vi.metadata, vi, vns, getspaceval(spl)) + return _link!(vi.metadata, vi, vns, Val(getspace(spl))) end @generated function _link!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} expr = Expr(:block) @@ -924,7 +919,7 @@ function invlink!(vi::UntypedVarInfo, spl::Sampler) end function invlink!(vi::TypedVarInfo, spl::Sampler) vns = _getvns(vi, spl) - return _invlink!(vi.metadata, vi, vns, getspaceval(spl)) + return _invlink!(vi.metadata, vi, vns, Val(getspace(spl))) end @generated 
function _invlink!(metadata::NamedTuple{names}, vi, vns, ::Val{space}) where {names, space} expr = Expr(:block) @@ -1310,7 +1305,7 @@ If `vn` doesn't have a sampler selector linked and `vn`'s symbol is in the space `spl`, this function will set `vn`'s `gid` to `Set([spl.selector])`. """ function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler) - if vn in getspace(spl.alg) + if vn in getspace(spl) setgid!(vi, spl.selector, vn) end end From 2d9bc42265073676f68f896825aec0bd2e7790f1 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 19 Dec 2019 19:13:20 +0100 Subject: [PATCH 05/12] Remove Nothing sampler again --- src/inference/Inference.jl | 5 +---- src/inference/hmc.jl | 4 ++-- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 2378815953..3856537918 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -612,9 +612,8 @@ function observe(spl::Sampler, weight) error("Turing.observe: unmanaged inference algorithm: $(typeof(spl))") end -## Default definitions for assume, observe, when sampler = nothing. +## Default definitions for assume and observe without sampler. function assume( - ::Nothing, dist::Distribution, vn::VarName, vi::VarInfo, @@ -641,7 +640,6 @@ function assume( end function observe( - ::Nothing, dist::Distribution, value, vi::VarInfo, @@ -801,7 +799,6 @@ function _dot_tilde(sampler, right, left::AbstractArray, vi) end function dot_observe( - ::Nothing, dist::Union{Distribution, AbstractArray{<:Distribution}}, value, vi::VarInfo, diff --git a/src/inference/hmc.jl b/src/inference/hmc.jl index 76a3b53e16..0cec61a81f 100644 --- a/src/inference/hmc.jl +++ b/src/inference/hmc.jl @@ -488,7 +488,7 @@ function observe( value, vi::VarInfo, ) - return observe(nothing, d, value, vi) + return observe(d, value, vi) end function dot_observe( @@ -497,7 +497,7 @@ function dot_observe( value::AbstractArray, vi::VarInfo, ) - return dot_observe(nothing, ds, value, vi) + return dot_observe(ds, value, vi) end #### From 249d505c95bc48aff79a9cbc06524ddee6faee5b Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 19 Dec 2019 19:14:34 +0100 Subject: [PATCH 06/12] Update implementation of elliptical slice sampling --- src/inference/Inference.jl | 2 +- src/inference/ess.jl | 71 +++++++++++++++++--------------------- 2 files changed, 33 insertions(+), 40 deletions(-) diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 3856537918..710cbd27f3 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -2,7 +2,7 @@ module Inference using ..Core, ..Core.RandomVariables, ..Utilities using ..Core.RandomVariables: Metadata, _tail, VarInfo, TypedVarInfo, - islinked, invlink!, getlogp, tonamedtuple, VarName + islinked, invlink!, getlogp, tonamedtuple, VarName, _getvns, getdist using ..Core: split_var_str using Distributions, Libtask, Bijectors using ProgressMeter, LinearAlgebra diff --git a/src/inference/ess.jl b/src/inference/ess.jl index c706f1818a..8d92806125 100644 --- a/src/inference/ess.jl +++ b/src/inference/ess.jl @@ -29,25 +29,30 @@ mutable struct ESSState{V<:VarInfo} <: AbstractSamplerState vi::V end -ESSState(model::Model) = ESSState(VarInfo(model)) - function Sampler(alg::ESS, model::Model, s::Selector) # sanity check + vi = VarInfo(model) space = getspace(alg) - if isempty(space) - pvars = get_pvars(model) - length(pvars) == 1 || - error("[ESS] no symbol specified to sampler although there is not exactly one model parameter ($pvars)") - end + vns = 
_getvns(vi, s, Val(space)) + length(vns) == 1 || + error("[ESS] does only support one variable ($(length(vns)) variables specified)") + dist = getdist(vi, vns[1][1]) + isgaussian(dist) || + error("[ESS] only supports Gaussian prior distributions") - state = ESSState(model) + state = ESSState(vi) info = Dict{Symbol, Any}() return Sampler(alg, info, s, state) end +isgaussian(dist) = false +isgaussian(::Normal) = true +isgaussian(::NormalCanon) = true +isgaussian(::AbstractMvNormal) = true + # always accept in the first step -function step!(::AbstractRNG, ::Model, spl::Sampler{<:ESS}, ::Integer; kwargs...) +function step!(::AbstractRNG, model::Model, spl::Sampler{<:ESS}, ::Integer; kwargs...) return Transition(spl) end @@ -59,27 +64,26 @@ function step!( ::Transition; kwargs... ) - # recompute log-likelihood in logp - vi = spl.state.vi - if spl.selector.tag !== :default - runmodel!(model, vi, spl) - end - # obtain mean of distribution - vns = _getvns(vi, spl) - length(vns) == 1 || error("[ESS] does only support one parameter") - dist = getdist(vi, vns[1][1]) + vi = spl.state.vi + vn = _getvns(vi, spl)[1][1] + dist = getdist(vi, vn) μ = vectorize(dist, mean(dist)) # obtain previous sample - f = copy(vi[spl]) + f = vi[vn] + + # recompute log-likelihood in logp + if spl.selector.tag !== :default + runmodel!(model, vi, spl) + end + setgid!(vi, spl.selector, vn) # sample log-likelihood threshold for the next sample threshold = getlogp(vi) - randexp(rng) # sample from the prior - runmodel!(model, vi, spl) - ν = copy(vi[spl]) + ν = vectorize(dist, rand(rng, dist)) # sample initial angle θ = 2 * π * rand(rng) @@ -87,13 +91,10 @@ function step!( θₘₐₓ = θ while true - # compute proposal + # compute proposal and apply correction for distributions with nonzero mean sinθ, cosθ = sincos(θ) - @. vi[spl] = f * cosθ + ν * sinθ - - # apply correction for distributions with nonzero mean a = 1 - (sinθ + cosθ) - @. vi[spl] += μ * a + vi[vn] = @. 
f * cosθ + ν * sinθ + μ * a # recompute log-likelihood and check if threshold is reached runmodel!(model, vi, spl) @@ -115,24 +116,16 @@ function step!( return Transition(spl) end -isnormal(dist) = false -isnormal(::Normal) = true -isnormal(::NormalCanon) = true -isnormal(::AbstractMvNormal) = true - function assume(spl::Sampler{<:ESS}, dist::Distribution, vn::VarName, vi::VarInfo) + # don't sample + r = vi[vn] + + # avoid possibly costly computation of the prior probability space = getspace(spl) if space === () || space === (vn.sym,) - isnormal(dist) || - error("[ESS] only supports normally distributed prior distributions") - - r = rand(dist) - vi[vn] = vectorize(dist, r) - setgid!(vi, spl.selector, vn) return r, zero(Base.promote_eltype(dist, r)) else - r = vi[vn] - return r, logpdf(dist, r) + return r, logpdf_with_trans(dist, r, istrans(vi, vn)) end end From 946465400a44e187981f446b7625890637976936 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 19 Dec 2019 23:34:25 +0100 Subject: [PATCH 07/12] Remove more nothing --- src/inference/mh.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/inference/mh.jl b/src/inference/mh.jl index 9dcfb7c140..4d8257c223 100644 --- a/src/inference/mh.jl +++ b/src/inference/mh.jl @@ -159,7 +159,7 @@ function assume(spl::Sampler{<:MH}, dist::Distribution, vn::VarName, vi::VarInfo end function observe(spl::Sampler{<:MH}, d::Distribution, value, vi::VarInfo) - return observe(nothing, d, value, vi) # accumulate pdf of likelihood + return observe(d, value, vi) # accumulate pdf of likelihood end function dot_observe( @@ -168,5 +168,5 @@ function dot_observe( value, vi::VarInfo, ) - return dot_observe(nothing, ds, value, vi) # accumulate pdf of likelihood + return dot_observe(ds, value, vi) # accumulate pdf of likelihood end From 695dc1cec527d36822a04bb601e608d4cb7e5938 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 19 Dec 2019 23:34:42 +0100 Subject: [PATCH 08/12] Fix tests --- src/inference/ess.jl | 2 +- test/inference/ess.jl | 6 ++++-- test/inference/gibbs.jl | 7 +++++++ 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/inference/ess.jl b/src/inference/ess.jl index 8d92806125..45bb89c128 100644 --- a/src/inference/ess.jl +++ b/src/inference/ess.jl @@ -17,7 +17,7 @@ Mean │ Row │ parameters │ mean │ │ │ Symbol │ Float64 │ ├─────┼────────────┼──────────┤ -│ 1 │ m │ 0.811555 │ +│ 1 │ m │ 0.824853 │ ``` """ struct ESS{space} <: InferenceAlgorithm end diff --git a/test/inference/ess.jl b/test/inference/ess.jl index c20c82ed13..5076bf224f 100644 --- a/test/inference/ess.jl +++ b/test/inference/ess.jl @@ -29,8 +29,10 @@ include(dir*"/test/test_utils/AllUtils.jl") check_numerical(chain, [:m], [0.8], atol = 0.1) Random.seed!(100) - alg = Gibbs(CSMC(50, :s), ESS(:m)) - chain = sample(gdemo_default, alg, 5_000) + alg = Gibbs( + CSMC(15, :s), + ESS(:m)) + chain = sample(gdemo(1.5, 2.0), alg, 10_000) check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) # MoGtest diff --git a/test/inference/gibbs.jl b/test/inference/gibbs.jl index e908cc2742..cb03ef50b8 100644 --- a/test/inference/gibbs.jl +++ b/test/inference/gibbs.jl @@ -36,6 +36,13 @@ include(dir*"/test/test_utils/AllUtils.jl") chain = sample(gdemo(1.5, 2.0), alg, 3000) check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) + Random.seed!(100) + alg = Gibbs( + CSMC(15, :s), + ESS(:m)) + chain = sample(gdemo(1.5, 2.0), alg, 10_000) + check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) + alg = CSMC(10) chain = sample(gdemo(1.5, 2.0), alg, 5000) 
check_numerical(chain, [:s, :m], [49/24, 7/6], atol=0.1) From 4225e154fd2c48c4207c03e6fcd076316b1e72fe Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 23 Dec 2019 17:56:47 +0100 Subject: [PATCH 09/12] Remove some Unicode characters --- src/inference/ess.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/inference/ess.jl b/src/inference/ess.jl index 1e389298a1..2882a9ef2e 100644 --- a/src/inference/ess.jl +++ b/src/inference/ess.jl @@ -87,8 +87,8 @@ function step!( # sample initial angle θ = 2 * π * rand(rng) - θₘᵢₙ = θ - 2 * π - θₘₐₓ = θ + θmin = θ - 2 * π + θmax = θ while true # compute proposal and apply correction for distributions with nonzero mean @@ -104,13 +104,13 @@ function step!( # shrink the bracket if θ < 0 - θₘᵢₙ = θ + θmin = θ else - θₘₐₓ = θ + θmax = θ end # sample new angle - θ = θₘᵢₙ + rand(rng) * (θₘₐₓ - θₘᵢₙ) + θ = θmin + rand(rng) * (θmax - θmin) end return Transition(spl) From 946a6f5aeb601b13f7a74a962bdf3ba933e604e3 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 26 Dec 2019 11:27:30 +0100 Subject: [PATCH 10/12] Overload tilde and dot_tilde and test dot notation --- src/core/RandomVariables.jl | 1 + src/inference/Inference.jl | 33 +++++++++++++++++++++-- src/inference/ess.jl | 54 ++++++++++++++++++++++--------------- test/inference/ess.jl | 22 +++++++++++---- 4 files changed, 82 insertions(+), 28 deletions(-) diff --git a/src/core/RandomVariables.jl b/src/core/RandomVariables.jl index 8732d6e285..8140bc8c80 100644 --- a/src/core/RandomVariables.jl +++ b/src/core/RandomVariables.jl @@ -34,6 +34,7 @@ export VarName, resetlogp!, set_retained_vns_del_by_spl!, is_flagged, + set_flag!, unset_flag!, setgid!, updategid!, diff --git a/src/inference/Inference.jl b/src/inference/Inference.jl index 816447c070..52a0e08b8b 100644 --- a/src/inference/Inference.jl +++ b/src/inference/Inference.jl @@ -637,7 +637,14 @@ function assume( vi::VarInfo, ) if haskey(vi, vn) + if is_flagged(vi, vn, "del") + unset_flag!(vi, vn, "del") + r = spl isa SampleFromUniform ? init(dist) : rand(dist) + vi[vn] = vectorize(dist, r) + setorder!(vi, vn, vi.num_produce) + else r = vi[vn] + end else r = isa(spl, SampleFromUniform) ? init(dist) : rand(dist) push!(vi, vn, r, dist, spl) @@ -794,9 +801,19 @@ function get_and_set_val!( ) n = length(vns) if haskey(vi, vns[1]) + if is_flagged(vi, vns[1], "del") + unset_flag!(vi, vns[1], "del") + r = spl isa SampleFromUniform ? init(dist, n) : rand(dist, n) + for i in 1:n + vn = vns[i] + vi[vn] = vectorize(dist, r[:, i]) + setorder!(vi, vn, vi.num_produce) + end + else r = vi[vns] + end else - r = isa(spl, SampleFromUniform) ? init(dist, n) : rand(dist, n) + r = spl isa SampleFromUniform ? init(dist, n) : rand(dist, n) for i in 1:n push!(vi, vns[i], r[:,i], dist, spl) end @@ -810,9 +827,21 @@ function get_and_set_val!( spl::AbstractSampler, ) if haskey(vi, vns[1]) + if is_flagged(vi, vns[1], "del") + unset_flag!(vi, vns[1], "del") + f = (vn, dist) -> spl isa SampleFromUniform ? init(dist) : rand(dist) + r = f.(vns, dists) + for i in eachindex(vns) + vn = vns[i] + dist = dists isa AbstractArray ? dists[i] : dists + vi[vn] = vectorize(dist, r[i]) + setorder!(vi, vn, vi.num_produce) + end + else r = reshape(vi[vec(vns)], size(vns)) + end else - f(vn, dist) = isa(spl, SampleFromUniform) ? init(dist) : rand(dist) + f = (vn, dist) -> spl isa SampleFromUniform ? 
init(dist) : rand(dist) r = f.(vns, dists) push!.(Ref(vi), vns, r, dists, Ref(spl)) end diff --git a/src/inference/ess.jl b/src/inference/ess.jl index 2882a9ef2e..307b5e2d90 100644 --- a/src/inference/ess.jl +++ b/src/inference/ess.jl @@ -36,9 +36,11 @@ function Sampler(alg::ESS, model::Model, s::Selector) vns = _getvns(vi, s, Val(space)) length(vns) == 1 || error("[ESS] does only support one variable ($(length(vns)) variables specified)") - dist = getdist(vi, vns[1][1]) - isgaussian(dist) || - error("[ESS] only supports Gaussian prior distributions") + for vn in vns[1] + dist = getdist(vi, vn) + isgaussian(dist) || + error("[ESS] only supports Gaussian prior distributions") + end state = ESSState(vi) info = Dict{Symbol, Any}() @@ -66,24 +68,27 @@ function step!( ) # obtain mean of distribution vi = spl.state.vi - vn = _getvns(vi, spl)[1][1] - dist = getdist(vi, vn) - μ = vectorize(dist, mean(dist)) + vns = _getvns(vi, spl) + μ = mapreduce(vcat, vns[1]) do vn + dist = getdist(vi, vn) + vectorize(dist, mean(dist)) + end # obtain previous sample - f = vi[vn] + f = vi[spl] # recompute log-likelihood in logp if spl.selector.tag !== :default runmodel!(model, vi, spl) end - setgid!(vi, spl.selector, vn) # sample log-likelihood threshold for the next sample threshold = getlogp(vi) - randexp(rng) # sample from the prior - ν = vectorize(dist, rand(rng, dist)) + set_flag!(vi, vns[1][1], "del") + runmodel!(model, vi, spl) + ν = vi[spl] # sample initial angle θ = 2 * π * rand(rng) @@ -94,7 +99,7 @@ function step!( # compute proposal and apply correction for distributions with nonzero mean sinθ, cosθ = sincos(θ) a = 1 - (sinθ + cosθ) - vi[vn] = @. f * cosθ + ν * sinθ + μ * a + vi[spl] = @. f * cosθ + ν * sinθ + μ * a # recompute log-likelihood and check if threshold is reached runmodel!(model, vi, spl) @@ -116,19 +121,26 @@ function step!( return Transition(spl) end -function assume(spl::Sampler{<:ESS}, dist::Distribution, vn::VarName, vi::VarInfo) - # don't sample - r = vi[vn] - - # avoid possibly costly computation of the prior probability - space = getspace(spl) - if space === () || space === (vn.sym,) - return r, zero(Base.promote_eltype(dist, r)) +function tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn::VarName, inds, vi) + if vn in getspace(sampler) + return tilde(LikelihoodContext(), SampleFromPrior(), right, vn, inds, vi) else - return r, logpdf_with_trans(dist, r, istrans(vi, vn)) + return tilde(ctx, SampleFromPrior(), right, vn, inds, vi) end end -function observe(spl::Sampler{<:ESS}, dist::Distribution, value, vi::VarInfo) - return observe(SampleFromPrior(), dist, value, vi) +function tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) + return tilde(ctx, SampleFromPrior(), right, left, vi) +end + +function dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vn::VarName, inds, vi) + if vn in getspace(sampler) + return dot_tilde(LikelihoodContext(), SampleFromPrior(), right, left, vn, inds, vi) + else + return dot_tilde(ctx, SampleFromPrior(), right, left, vn, inds, vi) + end end + +function dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) + return dot_tilde(ctx, SampleFromPrior(), right, left, vi) +end \ No newline at end of file diff --git a/test/inference/ess.jl b/test/inference/ess.jl index 5076bf224f..18da9eb76d 100644 --- a/test/inference/ess.jl +++ b/test/inference/ess.jl @@ -10,6 +10,13 @@ include(dir*"/test/test_utils/AllUtils.jl") end demo_default = demo(1.0) + @model demodot(x) = begin + m = 
Vector{Float64}(undef, 2) + m .~ Normal() + x ~ Normal(m[2], 0.5) + end + demodot_default = demodot(1.0) + @turing_testset "ESS constructor" begin Random.seed!(0) N = 500 @@ -19,15 +26,20 @@ include(dir*"/test/test_utils/AllUtils.jl") c1 = sample(demo_default, s1, N) c2 = sample(demo_default, s2, N) - c3 = sample(gdemo_default, s3, N) + c3 = sample(demodot_default, s1, N) + c4 = sample(demodot_default, s2, N) + c5 = sample(gdemo_default, s3, N) end @numerical_testset "ESS inference" begin Random.seed!(1) - alg = ESS() - chain = sample(demo_default, alg, 5_000) + chain = sample(demo_default, ESS(), 5_000) check_numerical(chain, [:m], [0.8], atol = 0.1) + Random.seed!(1) + chain = sample(demodot_default, ESS(), 5_000) + check_numerical(chain, ["m[1]", "m[2]"], [0.0, 0.8], atol = 0.1) + Random.seed!(100) alg = Gibbs( CSMC(15, :s), @@ -37,10 +49,10 @@ include(dir*"/test/test_utils/AllUtils.jl") # MoGtest Random.seed!(125) - gibbs = Gibbs( + alg = Gibbs( CSMC(15, :z1, :z2, :z3, :z4), ESS(:mu1), ESS(:mu2)) - chain = sample(MoGtest_default, gibbs, 6000) + chain = sample(MoGtest_default, alg, 6000) check_MoGtest_default(chain, atol = 0.1) end end From e2f9d4447931dfde2383d96392367876be216ded Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 26 Dec 2019 14:20:47 +0100 Subject: [PATCH 11/12] Fix error --- src/core/RandomVariables.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/core/RandomVariables.jl b/src/core/RandomVariables.jl index 8140bc8c80..af4fdae6fa 100644 --- a/src/core/RandomVariables.jl +++ b/src/core/RandomVariables.jl @@ -497,7 +497,7 @@ end # Get all vns of variables belonging to spl _getvns(vi::AbstractVarInfo, spl::Sampler) = _getvns(vi, spl.selector, Val(getspace(spl))) -_getvns(vi::UntypedVarInfo, s::Selector, space) = view(vi.vns, _getidcs(vi, s, space)) +_getvns(vi::UntypedVarInfo, s::Selector, space) = view(vi.metadata.vns, _getidcs(vi, s, space)) function _getvns(vi::TypedVarInfo, s::Selector, space) return _getvns(vi.metadata, _getidcs(vi, s, space)) end From 94e1e6f45e2b53b41cb76d67ecd2011d9e4be4b0 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 26 Dec 2019 16:16:27 +0100 Subject: [PATCH 12/12] Fix test errors on Julia 1.0 --- test/inference/ess.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/inference/ess.jl b/test/inference/ess.jl index 18da9eb76d..c61cf5c6be 100644 --- a/test/inference/ess.jl +++ b/test/inference/ess.jl @@ -12,7 +12,7 @@ include(dir*"/test/test_utils/AllUtils.jl") @model demodot(x) = begin m = Vector{Float64}(undef, 2) - m .~ Normal() + @. m ~ Normal() x ~ Normal(m[2], 0.5) end demodot_default = demodot(1.0)