Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ EllipticalSliceSampling = "cad2338a-1db2-11e9-3401-43bc07c9ede2"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -39,7 +38,6 @@ DynamicPPL = "0.5"
EllipticalSliceSampling = "0.2"
ForwardDiff = "0.10.3"
Libtask = "0.3.1"
LogDensityProblems = "^0.9, 0.10"
MCMCChains = "3.0.7"
ProgressLogging = "0.1"
Reexport = "0.2.0"
Expand All @@ -55,6 +53,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CmdStan = "593b3428-ca2f-500c-ae53-031589ec8ddd"
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand All @@ -66,4 +65,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Pkg", "PDMats", "TerminalLoggers", "Test", "UnicodePlots", "StatsBase", "FiniteDifferences", "DynamicHMC", "CmdStan", "BenchmarkTools", "Zygote", "ReverseDiff", "Memoization"]
test = ["Pkg", "PDMats", "TerminalLoggers", "Test", "UnicodePlots", "StatsBase", "FiniteDifferences", "DynamicHMC", "LogDensityProblems", "CmdStan", "BenchmarkTools", "Zygote", "ReverseDiff", "Memoization"]
8 changes: 6 additions & 2 deletions docs/src/using-turing/dynamichmc.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,18 @@ Turing supports the use of [DynamicHMC](https://github.com/tpapp/DynamicHMC.jl)
`DynamicNUTS` is not appropriate for use in compositional inference. If you intend to use [Gibbs]({{site.baseurl}}/docs/library/#Turing.Inference.Gibbs) sampling, you must use Turing's native `NUTS` function.


To use the `DynamicNUTS` function, you must import the `DynamicHMC` package as well as Turing. Turing does not formally require `DynamicHMC` but will include additional functionality if both packages are present.
To use the `DynamicNUTS` function, you must import the `DynamicHMC` and
`LogDensityProblems` packages as well as Turing. Turing does not formally require
`DynamicHMC` and `LogDensityProblems` but will include additional functionality if both
packages are present.

Here is a brief example of how to apply `DynamicNUTS`:


```julia
# Import Turing and DynamicHMC.
using LogDensityProblems, DynamicHMC, Turing
using Turing
using LogDensityProblems, DynamicHMC

# Model definition.
@model gdemo(x, y) = begin
Expand Down
17 changes: 9 additions & 8 deletions src/Turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,15 @@ using .Variational
# end
# end

@init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" @eval Inference begin
import ..DynamicHMC

if isdefined(DynamicHMC, :mcmc_with_warmup)
using ..DynamicHMC: mcmc_with_warmup
include("contrib/inference/dynamichmc.jl")
else
error("Please update DynamicHMC, v1.x is no longer supported")
@init @require DynamicHMC="bbc10e6e-7c05-544b-b16e-64fede858acb" begin
@require LogDensityProblems="6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" @eval Inference begin
import ..DynamicHMC, ..LogDensityProblems

if isdefined(DynamicHMC, :mcmc_with_warmup)
include("contrib/inference/dynamichmc.jl")
else
error("Please update DynamicHMC, v1.x is no longer supported")
end
end
end

Expand Down
220 changes: 122 additions & 98 deletions src/contrib/inference/dynamichmc.jl
Original file line number Diff line number Diff line change
@@ -1,148 +1,172 @@
###
### DynamicHMC backend - https://github.com/tpapp/DynamicHMC.jl
###

"""
DynamicNUTS

Dynamic No U-Turn Sampling algorithm provided by the DynamicHMC package. To use it, make
sure you have the LogDensityProblems package and DynamicHMC package (version >= 2) loaded:

```julia
using LogDensityProblems, DynamicHMC
```
"""
struct DynamicNUTS{AD, space} <: Hamiltonian{AD} end

using LogDensityProblems: LogDensityProblems
DynamicNUTS(args...) = DynamicNUTS{ADBackend()}(args...)
DynamicNUTS{AD}(space::Symbol...) where AD = DynamicNUTS{AD, space}()

getspace(::DynamicNUTS{<:Any, space}) where {space} = space

struct FunctionLogDensity{F}
dimension::Int
f::F
mutable struct DynamicNUTSState{V<:VarInfo} <: AbstractSamplerState
vi::V
end

LogDensityProblems.dimension(ℓ::FunctionLogDensity) = ℓ.dimension
function Sampler(
alg::DynamicNUTS,
model::Model,
s::Selector=Selector()
)
# Construct a state, using a default function.
state = DynamicNUTSState(VarInfo(model))

function LogDensityProblems.capabilities(::Type{<:FunctionLogDensity})
LogDensityProblems.LogDensityOrder{1}()
# Return a new sampler.
return Sampler(alg, Dict{Symbol,Any}(), s, state)
end

function LogDensityProblems.logdensity(ℓ::FunctionLogDensity, x::AbstractVector)
first(ℓ.f(x))
"""
DynamicNUTSTransition

Transition for the `DynamicNUTS` sampler.
"""
struct DynamicNUTSTransition{T,F<:AbstractFloat,QType,H,S}
θ::T
lp::F
Q::QType
hamiltonian::H
stepsize::S
end

function LogDensityProblems.logdensity_and_gradient(ℓ::FunctionLogDensity,
x::AbstractVector)
ℓ.f(x)
function additional_parameters(::Type{<:DynamicNUTSTransition})
return [:lp]
end

"""
DynamicNUTS()
# Wrapper for the log density function
struct LogDensity{M<:Model,S<:Sampler}
model::M
spl::S
end

Dynamic No U-Turn Sampling algorithm provided by the DynamicHMC package. To use it, make
sure you have the DynamicHMC package (version `2.*`) loaded:
function LogDensityProblems.dimension(ℓ::LogDensity)
spl = ℓ.spl
return length(spl.state.vi[spl])
end

```julia
using DynamicHMC
``
"""
DynamicNUTS(args...) = DynamicNUTS{ADBackend()}(args...)
DynamicNUTS{AD}() where AD = DynamicNUTS{AD, ()}()
function DynamicNUTS{AD}(space::Symbol...) where AD
DynamicNUTS{AD, space}()
function LogDensityProblems.capabilities(::Type{<:LogDensity})
LogDensityProblems.LogDensityOrder{1}()
end

mutable struct DynamicNUTSState{V<:VarInfo, D} <: AbstractSamplerState
vi::V
draws::Vector{D}
function LogDensityProblems.logdensity(ℓ::LogDensity, x::AbstractVector)
sampler = ℓ.sampler
vi = sampler.state.vi

x_old = vi[sampler]
lj_old = getlogp(vi)

vi[sampler] = x
runmodel!(ℓ.model, vi, sampler)
lj = getlogp(vi)

vi[sampler] = x_old
setlogp!(vi, lj_old)

return lj
end

getspace(::DynamicNUTS{<:Any, space}) where {space} = space
function LogDensityProblems.logdensity_and_gradient(ℓ::LogDensity,
x::AbstractVector)
spl = ℓ.spl
return gradient_logp(x, spl.state.vi, ℓ.model, spl)
end

function AbstractMCMC.sample_init!(
function AbstractMCMC.step!(
rng::AbstractRNG,
model::Model,
spl::Sampler{<:DynamicNUTS},
N::Integer;
::Integer,
::Nothing;
kwargs...
)
# Set up lp function.
function _lp(x)
gradient_logp(x, spl.state.vi, model, spl)
# Convert to transformed space.
vi = spl.state.vi
if !islinked(vi, spl)
Turing.DEBUG && @debug "X-> R..."
link!(vi, spl)
runmodel!(model, vi, spl)
end

runmodel!(model, spl.state.vi, SampleFromUniform())

if spl.selector.tag == :default
link!(spl.state.vi, spl)
runmodel!(model, spl.state.vi, spl)
end

# Set the parameters to a starting value.
initialize_parameters!(spl; kwargs...)

results = mcmc_with_warmup(
# Initial step
results = DynamicHMC.mcmc_keep_warmup(
rng,
FunctionLogDensity(
length(spl.state.vi[spl]),
_lp
),
N
LogDensity(model, spl),
0;
reporter = DynamicHMC.NoProgressReport()
)
steps = DynamicHMC.mcmc_steps(results.sampling_logdensity, results.final_warmup_state)
Q, stats = DynamicHMC.mcmc_next_step(steps, results.final_warmup_state.Q)

# Update the sample.
vi[spl] = Q.q
logp = stats.π
setlogp!(vi, logp)

spl.state.draws = results.chain
return DynamicNUTSTransition(tonamedtuple(vi), logp, Q, steps.H, steps.ϵ)
end

function AbstractMCMC.step!(
rng::AbstractRNG,
model::Model,
spl::Sampler{<:DynamicNUTS},
N::Integer,
transition;
::Integer,
transition::DynamicNUTSTransition;
kwargs...
)
# Pop the next draw off the vector.
draw = popfirst!(spl.state.draws)
spl.state.vi[spl] = draw
return Transition(spl)
end

function Sampler(
alg::DynamicNUTS,
model::Model,
s::Selector=Selector()
)
# Construct a state, using a default function.
state = DynamicNUTSState(VarInfo(model), [])

# Return a new sampler.
return Sampler(alg, Dict{Symbol,Any}(), s, state)
# Compute next sample.
hamiltonian = transition.hamiltonian
stepsize = transition.stepsize
steps = DynamicHMC.MCMCSteps(rng, DynamicHMC.NUTS(), hamiltonian, stepsize)
Q, stats = DynamicHMC.mcmc_next_step(steps, transition.Q)

# Update the sample.
vi = spl.state.vi
vi[spl] = Q.q
logp = stats.π
setlogp!(vi, logp)

return DynamicNUTSTransition(tonamedtuple(vi), logp, Q, hamiltonian, stepsize)
end

# Disable the progress logging for DynamicHMC, since it has its own progress meter.
function AbstractMCMC.sample(
rng::AbstractRNG,
model::AbstractModel,
alg::DynamicNUTS,
# Do not store fields specific to DynamicHMC.
function AbstractMCMC.transitions_init(
transition::DynamicNUTSTransition,
::Model,
::Sampler{<:DynamicNUTS},
N::Integer;
chain_type=MCMCChains.Chains,
resume_from=nothing,
progress=PROGRESS[],
kwargs...
)
if progress
@warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
end
if resume_from === nothing
return AbstractMCMC.sample(rng, model, Sampler(alg, model), N;
chain_type=chain_type, progress=false, kwargs...)
else
return resume(resume_from, N; chain_type=chain_type, progress=false, kwargs...)
end
return Vector{Transition{typeof(transition.θ),typeof(transition.lp)}}(undef, N)
end

function AbstractMCMC.psample(
rng::AbstractRNG,
model::AbstractModel,
alg::DynamicNUTS,
N::Integer,
n_chains::Integer;
chain_type=MCMCChains.Chains,
progress=PROGRESS[],
function AbstractMCMC.transitions_save!(
transitions::Vector{<:Transition},
iteration::Integer,
transition::DynamicNUTSTransition,
::Model,
::Sampler{<:DynamicNUTS},
::Integer;
kwargs...
)
if progress
@warn "[$(alg_str(alg))] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
end
return AbstractMCMC.psample(rng, model, Sampler(alg, model), N, n_chains;
chain_type=chain_type, progress=false, kwargs...)
transitions[iteration] = Transition(transition.θ, transition.lp)
return
end
2 changes: 1 addition & 1 deletion test/contrib/inference/dynamichmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ dir = splitdir(splitdir(pathof(Turing))[1])[1]
include(dir*"/test/test_utils/AllUtils.jl")

@stage_testset "dynamichmc" "dynamichmc.jl" begin
import DynamicHMC
import LogDensityProblems, DynamicHMC
Random.seed!(100)
chn = sample(gdemo_default, DynamicNUTS(), 5000);
check_numerical(chn, [:s, :m], [49/24, 7/6], atol=0.2)
Expand Down
6 changes: 1 addition & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,7 @@ include("test_utils/AllUtils.jl")
Turing.setadbackend(adbackend)
@testset "inference: $adbackend" begin
@testset "samplers" begin
# FIXME: DynamicHMC version 1 has (??) a bug on 32bit platforms (but we were too
# lazy to open an issue so Tamas doesn't know about it), retest with 2.0
if Int === Int64 && Pkg.installed()["DynamicHMC"].major == 2
include("contrib/inference/dynamichmc.jl")
end
include("contrib/inference/dynamichmc.jl")
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pkg.installed is deprecated on Julia 1.4. I assume that it is unlikely that anyone will run the tests with DynamicHMC < 2 and according to the comment the bug (?) was observed on 32bit with DynamicHMC < 2, so I guess it should be safe to remove the check completely.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to remove this check.

include("inference/gibbs.jl")
include("inference/hmc.jl")
include("inference/is.jl")
Expand Down