11# ##
22# ## DynamicHMC backend - https://github.com/tpapp/DynamicHMC.jl
33# ##
4- struct DynamicNUTS{AD, space} <: Hamiltonian{AD} end
54
6- using LogDensityProblems: LogDensityProblems
5+ """
6+ DynamicNUTS
77
8- struct FunctionLogDensity{F}
9- dimension:: Int
10- f:: F
11- end
8+ Dynamic No U-Turn Sampling algorithm provided by the DynamicHMC package.
9+
10+ To use it, make sure you have DynamicHMC package (version >= 2) loaded:
11+ ```julia
12+ using DynamicHMC
13+ ```
14+ """
15+ struct DynamicNUTS{AD,space} <: Hamiltonian{AD} end
1216
13- LogDensityProblems. dimension (ℓ:: FunctionLogDensity ) = ℓ. dimension
17+ DynamicNUTS (args... ) = DynamicNUTS {ADBackend()} (args... )
18+ DynamicNUTS {AD} (space:: Symbol... ) where AD = DynamicNUTS {AD, space} ()
1419
15- function LogDensityProblems. capabilities (:: Type{<:FunctionLogDensity} )
16- LogDensityProblems. LogDensityOrder {1} ()
20+ DynamicPPL. getspace (:: DynamicNUTS{<:Any, space} ) where {space} = space
21+
22+ struct DynamicHMCLogDensity{M<: Model ,S<: Sampler{<:DynamicNUTS} ,V<: AbstractVarInfo }
23+ model:: M
24+ sampler:: S
25+ varinfo:: V
1726end
1827
19- function LogDensityProblems . logdensity (ℓ:: FunctionLogDensity , x :: AbstractVector )
20- first (ℓ. f (x) )
28+ function DynamicHMC . dimension (ℓ:: DynamicHMCLogDensity )
29+ return length (ℓ. varinfo[ℓ . sampler] )
2130end
2231
23- function LogDensityProblems. logdensity_and_gradient (ℓ:: FunctionLogDensity ,
24- x:: AbstractVector )
25- ℓ. f (x)
32+ function DynamicHMC. capabilities (:: Type{<:DynamicHMCLogDensity} )
33+ return DynamicHMC. LogDensityOrder {1} ()
34+ end
35+
36+ function DynamicHMC. logdensity_and_gradient (
37+ ℓ:: DynamicHMCLogDensity ,
38+ x:: AbstractVector ,
39+ )
40+ return gradient_logp (x, ℓ. varinfo, ℓ. model, ℓ. sampler)
2641end
2742
2843"""
29- DynamicNUTS()
44+ DynamicNUTSState
3045
31- Dynamic No U-Turn Sampling algorithm provided by the DynamicHMC package. To use it, make
32- sure you have the DynamicHMC package (version `2.*`) loaded:
46+ State of the [`DynamicNUTS`](@ref) sampler.
3347
34- ```julia
35- using DynamicHMC
36- ``
48+ # Fields
49+ $(TYPEDFIELDS)
3750"""
38- DynamicNUTS (args... ) = DynamicNUTS {ADBackend()} (args... )
39- DynamicNUTS {AD} () where AD = DynamicNUTS {AD, ()} ()
40- function DynamicNUTS {AD} (space:: Symbol... ) where AD
41- DynamicNUTS {AD, space} ()
42- end
43-
44- struct DynamicNUTSState{V<: AbstractVarInfo ,D}
51+ struct DynamicNUTSState{V<: AbstractVarInfo ,C,M,S}
4552 vi:: V
46- draws:: Vector{D}
53+ " Cache of sample, log density, and gradient of log density."
54+ cache:: C
55+ metric:: M
56+ stepsize:: S
4757end
4858
49- DynamicPPL. getspace (:: DynamicNUTS{<:Any, space} ) where {space} = space
59+ function gibbs_update_state (state:: DynamicNUTSState , varinfo:: AbstractVarInfo )
60+ return DynamicNUTSState (varinfo, state. cache, state. metric, state. stepsize)
61+ end
5062
5163DynamicPPL. initialsampler (:: Sampler{<:DynamicNUTS} ) = SampleFromUniform ()
5264
@@ -55,44 +67,39 @@ function DynamicPPL.initialstep(
5567 model:: Model ,
5668 spl:: Sampler{<:DynamicNUTS} ,
5769 vi:: AbstractVarInfo ;
58- N:: Int ,
5970 kwargs...
6071)
61- # Set up lp function.
62- function _lp (x)
63- gradient_logp (x, vi, model, spl)
64- end
65-
66- link! (vi, spl)
67- l, dl = _lp (vi[spl])
68- while ! isfinite (l) || ! isfinite (dl)
69- model (vi, SampleFromUniform ())
70- link! (vi, spl)
71- l, dl = _lp (vi[spl])
72- end
73-
74- if spl. selector. tag == :default && ! islinked (vi, spl)
75- link! (vi, spl)
76- model (vi, spl)
72+ # Ensure that initial sample is in unconstrained space.
73+ if ! DynamicPPL. islinked (vi, spl)
74+ DynamicPPL. link! (vi, spl)
75+ model (rng, vi, spl)
7776 end
7877
79- results = mcmc_with_warmup (
78+ # Perform initial step.
79+ results = DynamicHMC. mcmc_keep_warmup (
8080 rng,
81- FunctionLogDensity (
82- length (vi[spl]),
83- _lp
84- ),
85- N
81+ DynamicHMCLogDensity (model, spl, vi),
82+ 0 ;
83+ initialization = (q = vi[spl],),
84+ reporter = DynamicHMC. NoProgressReport (),
8685 )
87- draws = results. chain
86+ steps = DynamicHMC. mcmc_steps (results. sampling_logdensity, results. final_warmup_state)
87+ Q, _ = DynamicHMC. mcmc_next_step (steps, results. final_warmup_state. Q)
8888
89- # Compute first transition and state.
90- draw = popfirst! (draws)
91- vi[spl] = draw
92- transition = Transition (vi)
93- state = DynamicNUTSState (vi, draws)
89+ # Update the variables.
90+ vi[spl] = Q. q
91+ DynamicPPL. setlogp! (vi, Q. ℓq)
9492
95- return transition, state
93+ # If a Gibbs component, transform the values back to the constrained space.
94+ if spl. selector. tag != = :default
95+ DynamicPPL. invlink! (vi, spl)
96+ end
97+
98+ # Create first sample and state.
99+ sample = Transition (vi)
100+ state = DynamicNUTSState (vi, Q, steps. H. κ, steps. ϵ)
101+
102+ return sample, state
96103end
97104
98105function AbstractMCMC. step (
@@ -102,55 +109,38 @@ function AbstractMCMC.step(
102109 state:: DynamicNUTSState ;
103110 kwargs...
104111)
105- # Extract VarInfo object .
112+ # Compute next sample .
106113 vi = state. vi
107-
108- # Pop the next draw off the vector.
109- draw = popfirst! (state. draws)
110- vi[spl] = draw
111-
112- # Compute next transition.
113- transition = Transition (vi)
114-
115- return transition, state
116- end
117-
118- # Disable the progress logging for DynamicHMC, since it has its own progress meter.
119- function AbstractMCMC. sample (
120- rng:: AbstractRNG ,
121- model:: AbstractModel ,
122- alg:: DynamicNUTS ,
123- N:: Integer ;
124- chain_type= MCMCChains. Chains,
125- resume_from= nothing ,
126- progress= PROGRESS[],
127- kwargs...
128- )
129- if progress
130- @warn " [HMC] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
131- end
132- if resume_from === nothing
133- return AbstractMCMC. sample (rng, model, Sampler (alg, model), N;
134- chain_type= chain_type, progress= false , N= N, kwargs... )
114+ ℓ = DynamicHMCLogDensity (model, spl, vi)
115+ steps = DynamicHMC. mcmc_steps (
116+ rng,
117+ DynamicHMC. NUTS (),
118+ state. metric,
119+ ℓ,
120+ state. stepsize,
121+ )
122+ Q = if spl. selector. tag != = :default
123+ # When a Gibbs component, transform values to the unconstrained space
124+ # and update the previous evaluation.
125+ DynamicPPL. link! (vi, spl)
126+ DynamicHMC. evaluate_ℓ (ℓ, vi[spl])
135127 else
136- return resume (resume_from, N; chain_type = chain_type, progress = false , N = N, kwargs ... )
128+ state . cache
137129 end
138- end
130+ newQ, _ = DynamicHMC . mcmc_next_step (steps, Q)
139131
140- function AbstractMCMC. sample (
141- rng:: AbstractRNG ,
142- model:: AbstractModel ,
143- alg:: DynamicNUTS ,
144- parallel:: AbstractMCMC.AbstractMCMCParallel ,
145- N:: Integer ,
146- n_chains:: Integer ;
147- chain_type= MCMCChains. Chains,
148- progress= PROGRESS[],
149- kwargs...
150- )
151- if progress
152- @warn " [HMC] Progress logging in Turing is disabled since DynamicHMC provides its own progress meter"
132+ # Update the variables.
133+ vi[spl] = newQ. q
134+ DynamicPPL. setlogp! (vi, newQ. ℓq)
135+
136+ # If a Gibbs component, transform the values back to the constrained space.
137+ if spl. selector. tag != = :default
138+ DynamicPPL. invlink! (vi, spl)
153139 end
154- return AbstractMCMC. sample (rng, model, Sampler (alg, model), parallel, N, n_chains;
155- chain_type= chain_type, progress= false , N= N, kwargs... )
140+
141+ # Create next sample and state.
142+ sample = Transition (vi)
143+ newstate = DynamicNUTSState (vi, newQ, state. metric, state. stepsize)
144+
145+ return sample, newstate
156146end
0 commit comments