Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.15.5"
version = "0.15.6"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -18,7 +18,6 @@ EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
NamedArrays = "86f7a689-2022-50b4-a561-43c23ac3c673"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand Down Expand Up @@ -47,7 +46,6 @@ DynamicPPL = "0.10.2"
EllipticalSliceSampling = "0.4"
ForwardDiff = "0.10.3"
Libtask = "0.4, 0.5"
LogDensityProblems = "^0.9, 0.10"
MCMCChains = "4"
NamedArrays = "0.9"
Reexport = "0.2.0"
Expand Down
6 changes: 1 addition & 5 deletions docs/src/using-turing/dynamichmc.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,14 @@ title: Using DynamicHMC

Turing supports the use of [DynamicHMC](https://github.com/tpapp/DynamicHMC.jl) as a sampler through the `DynamicNUTS` function.


`DynamicNUTS` is not appropriate for use in compositional inference. If you intend to use [Gibbs]({{site.baseurl}}/docs/library/#Turing.Inference.Gibbs) sampling, you must use Turing's native `NUTS` function.


To use the `DynamicNUTS` function, you must import the `DynamicHMC` package as well as Turing. Turing does not formally require `DynamicHMC` but will include additional functionality if both packages are present.

Here is a brief example of how to apply `DynamicNUTS`:


```julia
# Import Turing and DynamicHMC.
using LogDensityProblems, DynamicHMC, Turing
using DynamicHMC, Turing

# Model definition.
@model function gdemo(x, y)
Expand Down
17 changes: 9 additions & 8 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,15 @@ using .Variational
# end
# end

@init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" @eval Inference begin
import ..DynamicHMC

if isdefined(DynamicHMC, :mcmc_with_warmup)
using ..DynamicHMC: mcmc_with_warmup
include("contrib/inference/dynamichmc.jl")
else
error("Please update DynamicHMC, v1.x is no longer supported")
@init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" begin
@eval Inference begin
import ..DynamicHMC

if isdefined(DynamicHMC, :mcmc_with_warmup)
include("contrib/inference/dynamichmc.jl")
else
error("Please update DynamicHMC, v1.x is no longer supported")
end
end
end

Expand Down
200 changes: 95 additions & 105 deletions src/contrib/inference/dynamichmc.jl
Original file line number Diff line number Diff line change
@@ -1,52 +1,64 @@
###
### DynamicHMC backend - https://github.com/tpapp/DynamicHMC.jl
###
struct DynamicNUTS{AD, space} <: Hamiltonian{AD} end

using LogDensityProblems: LogDensityProblems
"""
DynamicNUTS

struct FunctionLogDensity{F}
dimension::Int
f::F
end
Dynamic No U-Turn Sampling algorithm provided by the DynamicHMC package.

To use it, make sure you have DynamicHMC package (version >= 2) loaded:
```julia
using DynamicHMC
```
"""
struct DynamicNUTS{AD,space} <: Hamiltonian{AD} end

LogDensityProblems.dimension(ℓ::FunctionLogDensity) = ℓ.dimension
DynamicNUTS(args...) = DynamicNUTS{ADBackend()}(args...)
DynamicNUTS{AD}(space::Symbol...) where AD = DynamicNUTS{AD, space}()

function LogDensityProblems.capabilities(::Type{<:FunctionLogDensity})
LogDensityProblems.LogDensityOrder{1}()
DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space

struct DynamicHMCLogDensity{M<:Model,S<:Sampler{<:DynamicNUTS},V<:AbstractVarInfo}
model::M
sampler::S
varinfo::V
end

function LogDensityProblems.logdensity(ℓ::FunctionLogDensity, x::AbstractVector)
first(ℓ.f(x))
function DynamicHMC.dimension(ℓ::DynamicHMCLogDensity)
return length(ℓ.varinfo[ℓ.sampler])
end

function LogDensityProblems.logdensity_and_gradient(ℓ::FunctionLogDensity,
x::AbstractVector)
ℓ.f(x)
function DynamicHMC.capabilities(::Type{<:DynamicHMCLogDensity})
return DynamicHMC.LogDensityOrder{1}()
end

function DynamicHMC.logdensity_and_gradient(
ℓ::DynamicHMCLogDensity,
x::AbstractVector,
)
return gradient_logp(x, ℓ.varinfo, ℓ.model, ℓ.sampler)
end

"""
DynamicNUTS()
DynamicNUTSState

Dynamic No U-Turn Sampling algorithm provided by the DynamicHMC package. To use it, make
sure you have the DynamicHMC package (version `2.*`) loaded:
State of the [`DynamicNUTS`](@ref) sampler.

```julia
using DynamicHMC
``
# Fields
$(TYPEDFIELDS)
"""
DynamicNUTS(args...) = DynamicNUTS{ADBackend()}(args...)
DynamicNUTS{AD}() where AD = DynamicNUTS{AD, ()}()
function DynamicNUTS{AD}(space::Symbol...) where AD
DynamicNUTS{AD, space}()
end

struct DynamicNUTSState{V<:AbstractVarInfo,D}
struct DynamicNUTSState{V<:AbstractVarInfo,C,M,S}
vi::V
draws::Vector{D}
"Cache of sample, log density, and gradient of log density."
cache::C
metric::M
stepsize::S
end

DynamicPPL.getspace(::DynamicNUTS{<:Any, space}) where {space} = space
function gibbs_update_state(state::DynamicNUTSState, varinfo::AbstractVarInfo)
return DynamicNUTSState(varinfo, state.cache, state.metric, state.stepsize)
end

DynamicPPL.initialsampler(::Sampler{<:DynamicNUTS}) = SampleFromUniform()

Expand All @@ -55,44 +67,39 @@ function DynamicPPL.initialstep(
model::Model,
spl::Sampler{<:DynamicNUTS},
vi::AbstractVarInfo;
N::Int,
kwargs...
)
# Set up lp function.
function _lp(x)
gradient_logp(x, vi, model, spl)
end

link!(vi, spl)
l, dl = _lp(vi[spl])
while !isfinite(l) || !isfinite(dl)
model(vi, SampleFromUniform())
link!(vi, spl)
l, dl = _lp(vi[spl])
end

if spl.selector.tag == :default && !islinked(vi, spl)
link!(vi, spl)
model(vi, spl)
# Ensure that initial sample is in unconstrained space.
if !DynamicPPL.islinked(vi, spl)
DynamicPPL.link!(vi, spl)
model(rng, vi, spl)
end

results = mcmc_with_warmup(
# Perform initial step.
results = DynamicHMC.mcmc_keep_warmup(
rng,
FunctionLogDensity(
length(vi[spl]),
_lp
),
N
DynamicHMCLogDensity(model, spl, vi),
0;
initialization = (q = vi[spl],),
reporter = DynamicHMC.NoProgressReport(),
)
draws = results.chain
steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
Q, _ = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)

# Compute first transition and state.
draw = popfirst!(draws)
vi[spl] = draw
transition = Transition(vi)
state = DynamicNUTSState(vi, draws)
# Update the variables.
vi[spl] = Q.q
DynamicPPL.setlogp!(vi, Q.ℓq)

return transition, state
# If a Gibbs component, transform the values back to the constrained space.
if spl.selector.tag !== :default
DynamicPPL.invlink!(vi, spl)
end

# Create first sample and state.
sample = Transition(vi)
state = DynamicNUTSState(vi, Q, steps.H.κ, steps.ϵ)

return sample, state
end

function AbstractMCMC.step(
Expand All @@ -102,55 +109,38 @@ function AbstractMCMC.step(
state::DynamicNUTSState;
kwargs...
)
# Extract VarInfo object.
# Compute next sample.
vi = state.vi

# Pop the next draw off the vector.
draw = popfirst!(state.draws)
vi[spl] = draw

# Compute next transition.
transition = Transition(vi)

return transition, state
end

# Disable the progress logging for DynamicHMC, since it has its own progress meter.
function AbstractMCMC.sample(
rng::AbstractRNG,
model::AbstractModel,
alg::DynamicNUTS,
N::Integer;
chain_type=MCMCChains.Chains,
resume_from=nothing,
progress=PROGRESS[],
kwargs...
)
if progress
@warn "[HMC] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
end
if resume_from === nothing
return AbstractMCMC.sample(rng, model, Sampler(alg, model), N;
chain_type=chain_type, progress=false, N=N, kwargs...)
ℓ = DynamicHMCLogDensity(model, spl, vi)
steps = DynamicHMC.mcmc_steps(
rng,
DynamicHMC.NUTS(),
state.metric,
ℓ,
state.stepsize,
)
Q = if spl.selector.tag !== :default
# When a Gibbs component, transform values to the unconstrained space
# and update the previous evaluation.
DynamicPPL.link!(vi, spl)
DynamicHMC.evaluate_ℓ(ℓ, vi[spl])
else
return resume(resume_from, N; chain_type=chain_type, progress=false, N=N, kwargs...)
state.cache
end
end
newQ, _ = DynamicHMC.mcmc_next_step(steps, Q)

function AbstractMCMC.sample(
rng::AbstractRNG,
model::AbstractModel,
alg::DynamicNUTS,
parallel::AbstractMCMC.AbstractMCMCParallel,
N::Integer,
n_chains::Integer;
chain_type=MCMCChains.Chains,
progress=PROGRESS[],
kwargs...
)
if progress
@warn "[HMC] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
# Update the variables.
vi[spl] = newQ.q
DynamicPPL.setlogp!(vi, newQ.ℓq)

# If a Gibbs component, transform the values back to the constrained space.
if spl.selector.tag !== :default
DynamicPPL.invlink!(vi, spl)
end
return AbstractMCMC.sample(rng, model, Sampler(alg, model), parallel, N, n_chains;
chain_type=chain_type, progress=false, N=N, kwargs...)

# Create next sample and state.
sample = Transition(vi)
newstate = DynamicNUTSState(vi, newQ, state.metric, state.stepsize)

return sample, newstate
end
10 changes: 8 additions & 2 deletions test/contrib/inference/dynamichmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ include(dir*"/test/test_utils/AllUtils.jl")

@test DynamicPPL.alg_str(Sampler(DynamicNUTS(), gdemo_default)) == "DynamicNUTS"

chn = sample(gdemo_default, DynamicNUTS(), 5000)
check_numerical(chn, [:s, :m], [49/24, 7/6], atol=0.2)
chn = sample(gdemo_default, DynamicNUTS(), 10_000)
check_gdemo(chn)

chn2 = sample(gdemo_default, Gibbs(PG(15, :s), DynamicNUTS(:m)), 10_000)
check_gdemo(chn2)

chn3 = sample(gdemo_default, Gibbs(DynamicNUTS(:s), ESS(:m)), 10_000)
check_gdemo(chn3)
end