diff --git a/Project.toml b/Project.toml index c8f9a794bb..3aa2c8bdc9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.15.5" +version = "0.15.6" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -18,7 +18,6 @@ EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" @@ -47,7 +46,6 @@ DynamicPPL = "0.10.2" EllipticalSliceSampling = "0.4" ForwardDiff = "0.10.3" Libtask = "0.4, 0.5" -LogDensityProblems = "^0.9, 0.10" MCMCChains = "4" NamedArrays = "0.9" Reexport = "0.2.0" diff --git a/docs/src/using-turing/dynamichmc.md b/docs/src/using-turing/dynamichmc.md index eb226be94f..4854512af5 100644 --- a/docs/src/using-turing/dynamichmc.md +++ b/docs/src/using-turing/dynamichmc.md @@ -6,10 +6,6 @@ title: Using DynamicHMC Turing supports the use of [DynamicHMC](https://github.com/tpapp/DynamicHMC.jl) as a sampler through the `DynamicNUTS` function. - -`DynamicNUTS` is not appropriate for use in compositional inference. If you intend to use [Gibbs]({{site.baseurl}}/docs/library/#Turing.Inference.Gibbs) sampling, you must use Turing's native `NUTS` function. - - To use the `DynamicNUTS` function, you must import the `DynamicHMC` package as well as Turing. Turing does not formally require `DynamicHMC` but will include additional functionality if both packages are present. Here is a brief example of how to apply `DynamicNUTS`: @@ -17,7 +13,7 @@ Here is a brief example of how to apply `DynamicNUTS`: ```julia # Import Turing and DynamicHMC. -using LogDensityProblems, DynamicHMC, Turing +using DynamicHMC, Turing # Model definition. @model function gdemo(x, y) diff --git a/src/Turing.jl b/src/Turing.jl index 408abf40b7..1306367ed9 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -55,14 +55,15 @@ using .Variational # end # end -@init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" @eval Inference begin - import ..DynamicHMC - - if isdefined(DynamicHMC, :mcmc_with_warmup) - using ..DynamicHMC: mcmc_with_warmup - include("contrib/inference/dynamichmc.jl") - else - error("Please update DynamicHMC, v1.x is no longer supported") +@init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" begin + @eval Inference begin + import ..DynamicHMC + + if isdefined(DynamicHMC, :mcmc_with_warmup) + include("contrib/inference/dynamichmc.jl") + else + error("Please update DynamicHMC, v1.x is no longer supported") + end end end diff --git a/src/contrib/inference/dynamichmc.jl b/src/contrib/inference/dynamichmc.jl index 1953463bbd..6f137512da 100644 --- a/src/contrib/inference/dynamichmc.jl +++ b/src/contrib/inference/dynamichmc.jl @@ -1,52 +1,64 @@ ### ### DynamicHMC backend - https://github.com/tpapp/DynamicHMC.jl ### -struct DynamicNUTS{AD, space} <: Hamiltonian{AD} end -using LogDensityProblems: LogDensityProblems +""" + DynamicNUTS -struct FunctionLogDensity{F} - dimension::Int - f::F -end +Dynamic No U-Turn Sampling algorithm provided by the DynamicHMC package. + +To use it, make sure you have DynamicHMC package (version >= 2) loaded: +```julia +using DynamicHMC +``` +""" +struct DynamicNUTS{AD,space} <: Hamiltonian{AD} end -LogDensityProblems.dimension(ℓ::FunctionLogDensity) = ℓ.dimension +DynamicNUTS(args...) = DynamicNUTS{ADBackend()}(args...) +DynamicNUTS{AD}(space::Symbol...) where AD = DynamicNUTS{AD, space}() -function LogDensityProblems.capabilities(::Type{<:FunctionLogDensity}) - LogDensityProblems.LogDensityOrder{1}() +DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space + +struct DynamicHMCLogDensity{M<:Model,S<:Sampler{<:DynamicNUTS},V<:AbstractVarInfo} + model::M + sampler::S + varinfo::V end -function LogDensityProblems.logdensity(ℓ::FunctionLogDensity, x::AbstractVector) - first(ℓ.f(x)) +function DynamicHMC.dimension(ℓ::DynamicHMCLogDensity) + return length(ℓ.varinfo[ℓ.sampler]) end -function LogDensityProblems.logdensity_and_gradient(ℓ::FunctionLogDensity, - x::AbstractVector) - ℓ.f(x) +function DynamicHMC.capabilities(::Type{<:DynamicHMCLogDensity}) + return DynamicHMC.LogDensityOrder{1}() +end + +function DynamicHMC.logdensity_and_gradient( + ℓ::DynamicHMCLogDensity, + x::AbstractVector, +) + return gradient_logp(x, ℓ.varinfo, ℓ.model, ℓ.sampler) end """ - DynamicNUTS() + DynamicNUTSState -Dynamic No U-Turn Sampling algorithm provided by the DynamicHMC package. To use it, make -sure you have the DynamicHMC package (version `2.*`) loaded: +State of the [`DynamicNUTS`](@ref) sampler. -```julia -using DynamicHMC -`` +# Fields +$(TYPEDFIELDS) """ -DynamicNUTS(args...) = DynamicNUTS{ADBackend()}(args...) -DynamicNUTS{AD}() where AD = DynamicNUTS{AD, ()}() -function DynamicNUTS{AD}(space::Symbol...) where AD - DynamicNUTS{AD, space}() -end - -struct DynamicNUTSState{V<:AbstractVarInfo,D} +struct DynamicNUTSState{V<:AbstractVarInfo,C,M,S} vi::V - draws::Vector{D} + "Cache of sample, log density, and gradient of log density." + cache::C + metric::M + stepsize::S end -DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space +function gibbs_update_state(state::DynamicNUTSState, varinfo::AbstractVarInfo) + return DynamicNUTSState(varinfo, state.cache, state.metric, state.stepsize) +end DynamicPPL.initialsampler(::Sampler{<:DynamicNUTS}) = SampleFromUniform() @@ -55,44 +67,39 @@ function DynamicPPL.initialstep( model::Model, spl::Sampler{<:DynamicNUTS}, vi::AbstractVarInfo; - N::Int, kwargs... ) - # Set up lp function. - function _lp(x) - gradient_logp(x, vi, model, spl) - end - - link!(vi, spl) - l, dl = _lp(vi[spl]) - while !isfinite(l) || !isfinite(dl) - model(vi, SampleFromUniform()) - link!(vi, spl) - l, dl = _lp(vi[spl]) - end - - if spl.selector.tag == :default && !islinked(vi, spl) - link!(vi, spl) - model(vi, spl) + # Ensure that initial sample is in unconstrained space. + if !DynamicPPL.islinked(vi, spl) + DynamicPPL.link!(vi, spl) + model(rng, vi, spl) end - results = mcmc_with_warmup( + # Perform initial step. + results = DynamicHMC.mcmc_keep_warmup( rng, - FunctionLogDensity( - length(vi[spl]), - _lp - ), - N + DynamicHMCLogDensity(model, spl, vi), + 0; + initialization = (q = vi[spl],), + reporter = DynamicHMC.NoProgressReport(), ) - draws = results.chain + steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state) + Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q) - # Compute first transition and state. - draw = popfirst!(draws) - vi[spl] = draw - transition = Transition(vi) - state = DynamicNUTSState(vi, draws) + # Update the variables. + vi[spl] = Q.q + DynamicPPL.setlogp!(vi, Q.ℓq) - return transition, state + # If a Gibbs component, transform the values back to the constrained space. + if spl.selector.tag !== :default + DynamicPPL.invlink!(vi, spl) + end + + # Create first sample and state. + sample = Transition(vi) + state = DynamicNUTSState(vi, Q, steps.H.κ, steps.ϵ) + + return sample, state end function AbstractMCMC.step( @@ -102,55 +109,38 @@ function AbstractMCMC.step( state::DynamicNUTSState; kwargs... ) - # Extract VarInfo object. + # Compute next sample. vi = state.vi - - # Pop the next draw off the vector. - draw = popfirst!(state.draws) - vi[spl] = draw - - # Compute next transition. - transition = Transition(vi) - - return transition, state -end - -# Disable the progress logging for DynamicHMC, since it has its own progress meter. -function AbstractMCMC.sample( - rng::AbstractRNG, - model::AbstractModel, - alg::DynamicNUTS, - N::Integer; - chain_type=MCMCChains.Chains, - resume_from=nothing, - progress=PROGRESS[], - kwargs... -) - if progress - @warn "[HMC] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" - end - if resume_from === nothing - return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; - chain_type=chain_type, progress=false, N=N, kwargs...) + ℓ = DynamicHMCLogDensity(model, spl, vi) + steps = DynamicHMC.mcmc_steps( + rng, + DynamicHMC.NUTS(), + state.metric, + ℓ, + state.stepsize, + ) + Q = if spl.selector.tag !== :default + # When a Gibbs component, transform values to the unconstrained space + # and update the previous evaluation. + DynamicPPL.link!(vi, spl) + DynamicHMC.evaluate_ℓ(ℓ, vi[spl]) else - return resume(resume_from, N; chain_type=chain_type, progress=false, N=N, kwargs...) + state.cache end -end + newQ, _ = DynamicHMC.mcmc_next_step(steps, Q) -function AbstractMCMC.sample( - rng::AbstractRNG, - model::AbstractModel, - alg::DynamicNUTS, - parallel::AbstractMCMC.AbstractMCMCParallel, - N::Integer, - n_chains::Integer; - chain_type=MCMCChains.Chains, - progress=PROGRESS[], - kwargs... -) - if progress - @warn "[HMC] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter" + # Update the variables. + vi[spl] = newQ.q + DynamicPPL.setlogp!(vi, newQ.ℓq) + + # If a Gibbs component, transform the values back to the constrained space. + if spl.selector.tag !== :default + DynamicPPL.invlink!(vi, spl) end - return AbstractMCMC.sample(rng, model, Sampler(alg, model), parallel, N, n_chains; - chain_type=chain_type, progress=false, N=N, kwargs...) + + # Create next sample and state. + sample = Transition(vi) + newstate = DynamicNUTSState(vi, newQ, state.metric, state.stepsize) + + return sample, newstate end diff --git a/test/contrib/inference/dynamichmc.jl b/test/contrib/inference/dynamichmc.jl index b4586b016c..1673fc1f55 100644 --- a/test/contrib/inference/dynamichmc.jl +++ b/test/contrib/inference/dynamichmc.jl @@ -10,6 +10,12 @@ include(dir*"/test/test_utils/AllUtils.jl") @test DynamicPPL.alg_str(Sampler(DynamicNUTS(), gdemo_default)) == "DynamicNUTS" - chn = sample(gdemo_default, DynamicNUTS(), 5000) - check_numerical(chn, [:s, :m], [49/24, 7/6], atol=0.2) + chn = sample(gdemo_default, DynamicNUTS(), 10_000) + check_gdemo(chn) + + chn2 = sample(gdemo_default, Gibbs(PG(15, :s), DynamicNUTS(:m)), 10_000) + check_gdemo(chn2) + + chn3 = sample(gdemo_default, Gibbs(DynamicNUTS(:s), ESS(:m)), 10_000) + check_gdemo(chn3) end