diff --git a/Project.toml b/Project.toml index 4ff052da2f..e779b32519 100644 --- a/Project.toml +++ b/Project.toml @@ -15,7 +15,6 @@ EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -39,7 +38,6 @@ DynamicPPL = "0.5" EllipticalSliceSampling = "0.2" ForwardDiff = "0.10.3" Libtask = "0.3.1" -LogDensityProblems = "^0.9, 0.10" MCMCChains = "3.0.7" ProgressLogging = "0.1" Reexport = "0.2.0" @@ -55,6 +53,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CmdStan = "593b3428-ca2f-500c-ae53-031589ec8ddd" DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" @@ -66,4 +65,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Pkg", "PDMats", "TerminalLoggers", "Test", "UnicodePlots", "StatsBase", "FiniteDifferences", "DynamicHMC", "CmdStan", "BenchmarkTools", "Zygote", "ReverseDiff", "Memoization"] +test = ["Pkg", "PDMats", "TerminalLoggers", "Test", "UnicodePlots", "StatsBase", "FiniteDifferences", "DynamicHMC", "LogDensityProblems", "CmdStan", "BenchmarkTools", "Zygote", "ReverseDiff", "Memoization"] diff --git a/docs/src/using-turing/dynamichmc.md b/docs/src/using-turing/dynamichmc.md index 4153a49962..6f21b3d225 100644 --- a/docs/src/using-turing/dynamichmc.md +++ b/docs/src/using-turing/dynamichmc.md @@ -10,14 +10,18 @@ Turing supports the use of [DynamicHMC](https://github.com/tpapp/DynamicHMC.jl) `DynamicNUTS` is not appropriate for use in compositional inference. If you intend to use [Gibbs]({{site.baseurl}}/docs/library/#Turing.Inference.Gibbs) sampling, you must use Turing's native `NUTS` function. -To use the `DynamicNUTS` function, you must import the `DynamicHMC` package as well as Turing. Turing does not formally require `DynamicHMC` but will include additional functionality if both packages are present. +To use the `DynamicNUTS` function, you must import the `DynamicHMC` and +`LogDensityProblems` packages as well as Turing. Turing does not formally require +`DynamicHMC` and `LogDensityProblems` but will include additional functionality if both +packages are present. Here is a brief example of how to apply `DynamicNUTS`: ```julia # Import Turing and DynamicHMC. -using LogDensityProblems, DynamicHMC, Turing +using Turing +using LogDensityProblems, DynamicHMC # Model definition. @model gdemo(x, y) = begin diff --git a/src/Turing.jl b/src/Turing.jl index f19eba5662..03e27f7750 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -47,14 +47,15 @@ using .Variational # end # end -@init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" @eval Inference begin - import ..DynamicHMC - - if isdefined(DynamicHMC, :mcmc_with_warmup) - using ..DynamicHMC: mcmc_with_warmup - include("contrib/inference/dynamichmc.jl") - else - error("Please update DynamicHMC, v1.x is no longer supported") +@init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" begin + @require LogDensityProblems="6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" @eval Inference begin + import ..DynamicHMC, ..LogDensityProblems + + if isdefined(DynamicHMC, :mcmc_with_warmup) + include("contrib/inference/dynamichmc.jl") + else + error("Please update DynamicHMC, v1.x is no longer supported") + end end end diff --git a/src/contrib/inference/dynamichmc.jl b/src/contrib/inference/dynamichmc.jl index e11e981bd1..477f5293cb 100644 --- a/src/contrib/inference/dynamichmc.jl +++ b/src/contrib/inference/dynamichmc.jl @@ -1,148 +1,172 @@ ### ### DynamicHMC backend - https://github.com/tpapp/DynamicHMC.jl ### + +""" + DynamicNUTS + +Dynamic No U-Turn Sampling algorithm provided by the DynamicHMC package. To use it, make +sure you have the LogDensityProblems package and DynamicHMC package (version >= 2) loaded: + +```julia +using LogDensityProblems, DynamicHMC +``` +""" struct DynamicNUTS{AD, space} <: Hamiltonian{AD} end -using LogDensityProblems: LogDensityProblems +DynamicNUTS(args...) = DynamicNUTS{ADBackend()}(args...) +DynamicNUTS{AD}(space::Symbol...) where AD = DynamicNUTS{AD, space}() + +getspace(::DynamicNUTS{<:Any, space}) where {space} = space -struct FunctionLogDensity{F} - dimension::Int - f::F +mutable struct DynamicNUTSState{V<:VarInfo} <: AbstractSamplerState + vi::V end -LogDensityProblems.dimension(ℓ::FunctionLogDensity) = ℓ.dimension +function Sampler( + alg::DynamicNUTS, + model::Model, + s::Selector=Selector() +) + # Construct a state, using a default function. + state = DynamicNUTSState(VarInfo(model)) -function LogDensityProblems.capabilities(::Type{<:FunctionLogDensity}) - LogDensityProblems.LogDensityOrder{1}() + # Return a new sampler. + return Sampler(alg, Dict{Symbol,Any}(), s, state) end -function LogDensityProblems.logdensity(ℓ::FunctionLogDensity, x::AbstractVector) - first(ℓ.f(x)) +""" + DynamicNUTSTransition + +Transition for the `DynamicNUTS` sampler. +""" +struct DynamicNUTSTransition{T,F<:AbstractFloat,QType,H,S} + θ::T + lp::F + Q::QType + hamiltonian::H + stepsize::S end -function LogDensityProblems.logdensity_and_gradient(ℓ::FunctionLogDensity, - x::AbstractVector) - ℓ.f(x) +function additional_parameters(::Type{<:DynamicNUTSTransition}) + return [:lp] end -""" - DynamicNUTS() +# Wrapper for the log density function +struct LogDensity{M<:Model,S<:Sampler} + model::M + spl::S +end -Dynamic No U-Turn Sampling algorithm provided by the DynamicHMC package. To use it, make -sure you have the DynamicHMC package (version `2.*`) loaded: +function LogDensityProblems.dimension(ℓ::LogDensity) + spl = ℓ.spl + return length(spl.state.vi[spl]) +end -```julia -using DynamicHMC -`` -""" -DynamicNUTS(args...) = DynamicNUTS{ADBackend()}(args...) -DynamicNUTS{AD}() where AD = DynamicNUTS{AD, ()}() -function DynamicNUTS{AD}(space::Symbol...) where AD - DynamicNUTS{AD, space}() +function LogDensityProblems.capabilities(::Type{<:LogDensity}) + LogDensityProblems.LogDensityOrder{1}() end -mutable struct DynamicNUTSState{V<:VarInfo, D} <: AbstractSamplerState - vi::V - draws::Vector{D} +function LogDensityProblems.logdensity(ℓ::LogDensity, x::AbstractVector) + sampler = ℓ.sampler + vi = sampler.state.vi + + x_old = vi[sampler] + lj_old = getlogp(vi) + + vi[sampler] = x + runmodel!(ℓ.model, vi, sampler) + lj = getlogp(vi) + + vi[sampler] = x_old + setlogp!(vi, lj_old) + + return lj end -getspace(::DynamicNUTS{<:Any, space}) where {space} = space +function LogDensityProblems.logdensity_and_gradient(ℓ::LogDensity, + x::AbstractVector) + spl = ℓ.spl + return gradient_logp(x, spl.state.vi, ℓ.model, spl) +end -function AbstractMCMC.sample_init!( +function AbstractMCMC.step!( rng::AbstractRNG, model::Model, spl::Sampler{<:DynamicNUTS}, - N::Integer; + ::Integer, + ::Nothing; kwargs... ) - # Set up lp function. - function _lp(x) - gradient_logp(x, spl.state.vi, model, spl) + # Convert to transformed space. + vi = spl.state.vi + if !islinked(vi, spl) + Turing.DEBUG && @debug "X-> R..." + link!(vi, spl) + runmodel!(model, vi, spl) end - runmodel!(model, spl.state.vi, SampleFromUniform()) - - if spl.selector.tag == :default - link!(spl.state.vi, spl) - runmodel!(model, spl.state.vi, spl) - end - - # Set the parameters to a starting value. - initialize_parameters!(spl; kwargs...) - - results = mcmc_with_warmup( + # Initial step + results = DynamicHMC.mcmc_keep_warmup( rng, - FunctionLogDensity( - length(spl.state.vi[spl]), - _lp - ), - N + LogDensity(model, spl), + 0; + reporter = DynamicHMC.NoProgressReport() ) + steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state) + Q, stats = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q) + + # Update the sample. + vi[spl] = Q.q + logp = stats.π + setlogp!(vi, logp) - spl.state.draws = results.chain + return DynamicNUTSTransition(tonamedtuple(vi), logp, Q, steps.H, steps.ϵ) end function AbstractMCMC.step!( rng::AbstractRNG, model::Model, spl::Sampler{<:DynamicNUTS}, - N::Integer, - transition; + ::Integer, + transition::DynamicNUTSTransition; kwargs... ) - # Pop the next draw off the vector. - draw = popfirst!(spl.state.draws) - spl.state.vi[spl] = draw - return Transition(spl) -end - -function Sampler( - alg::DynamicNUTS, - model::Model, - s::Selector=Selector() -) - # Construct a state, using a default function. - state = DynamicNUTSState(VarInfo(model), []) - - # Return a new sampler. - return Sampler(alg, Dict{Symbol,Any}(), s, state) + # Compute next sample. + hamiltonian = transition.hamiltonian + stepsize = transition.stepsize + steps = DynamicHMC.MCMCSteps(rng, DynamicHMC.NUTS(), hamiltonian, stepsize) + Q, stats = DynamicHMC.mcmc_next_step(steps, transition.Q) + + # Update the sample. + vi = spl.state.vi + vi[spl] = Q.q + logp = stats.π + setlogp!(vi, logp) + + return DynamicNUTSTransition(tonamedtuple(vi), logp, Q, hamiltonian, stepsize) end - # Disable the progress logging for DynamicHMC, since it has its own progress meter. - function AbstractMCMC.sample( - rng::AbstractRNG, - model::AbstractModel, - alg::DynamicNUTS, +# Do not store fields specific to DynamicHMC. +function AbstractMCMC.transitions_init( + transition::DynamicNUTSTransition, + ::Model, + ::Sampler{<:DynamicNUTS}, N::Integer; - chain_type=MCMCChains.Chains, - resume_from=nothing, - progress=PROGRESS[], kwargs... ) - if progress - @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" - end - if resume_from === nothing - return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; - chain_type=chain_type, progress=false, kwargs...) - else - return resume(resume_from, N; chain_type=chain_type, progress=false, kwargs...) - end + return Vector{Transition{typeof(transition.θ),typeof(transition.lp)}}(undef, N) end -function AbstractMCMC.psample( - rng::AbstractRNG, - model::AbstractModel, - alg::DynamicNUTS, - N::Integer, - n_chains::Integer; - chain_type=MCMCChains.Chains, - progress=PROGRESS[], +function AbstractMCMC.transitions_save!( + transitions::Vector{<:Transition}, + iteration::Integer, + transition::DynamicNUTSTransition, + ::Model, + ::Sampler{<:DynamicNUTS}, + ::Integer; kwargs... ) - if progress - @warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" - end - return AbstractMCMC.psample(rng, model, Sampler(alg, model), N, n_chains; - chain_type=chain_type, progress=false, kwargs...) + transitions[iteration] = Transition(transition.θ, transition.lp) + return end \ No newline at end of file diff --git a/test/contrib/inference/dynamichmc.jl b/test/contrib/inference/dynamichmc.jl index c3072910dd..194b887db6 100644 --- a/test/contrib/inference/dynamichmc.jl +++ b/test/contrib/inference/dynamichmc.jl @@ -5,7 +5,7 @@ dir = splitdir(splitdir(pathof(Turing))[1])[1] include(dir*"/test/test_utils/AllUtils.jl") @stage_testset "dynamichmc" "dynamichmc.jl" begin - import DynamicHMC + import LogDensityProblems, DynamicHMC Random.seed!(100) chn = sample(gdemo_default, DynamicNUTS(), 5000); check_numerical(chn, [:s, :m], [49/24, 7/6], atol=0.2) diff --git a/test/runtests.jl b/test/runtests.jl index bdcfb5a472..d7d1399fab 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,11 +25,7 @@ include("test_utils/AllUtils.jl") Turing.setadbackend(adbackend) @testset "inference: $adbackend" begin @testset "samplers" begin - # FIXME: DynamicHMC version 1 has (??) a bug on 32bit platforms (but we were too - # lazy to open an issue so Tamas doesn't know about it), retest with 2.0 - if Int === Int64 && Pkg.installed()["DynamicHMC"].major == 2 - include("contrib/inference/dynamichmc.jl") - end + include("contrib/inference/dynamichmc.jl") include("inference/gibbs.jl") include("inference/hmc.jl") include("inference/is.jl")