|
563 | 563 | function AHMCAdaptor(::Hamiltonian, ::AHMC.AbstractMetric; kwargs...) |
564 | 564 | return AHMC.Adaptation.NoAdaptation() |
565 | 565 | end |
566 | | - |
567 | | -########################## |
568 | | -# HMC State Constructors # |
569 | | -########################## |
570 | | - |
571 | | -function HMCState( |
572 | | - rng::AbstractRNG, |
573 | | - model::Model, |
574 | | - spl::Sampler{<:Hamiltonian}, |
575 | | - vi::AbstractVarInfo; |
576 | | - kwargs..., |
577 | | -) |
578 | | - # Link everything if needed. |
579 | | - waslinked = islinked(vi, spl) |
580 | | - if !waslinked |
581 | | - vi = link!!(vi, spl, model) |
582 | | - end |
583 | | - |
584 | | - # Get the initial log pdf and gradient functions. |
585 | | - ∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model) |
586 | | - logπ = Turing.LogDensityFunction( |
587 | | - vi, |
588 | | - model, |
589 | | - DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)), |
590 | | - ) |
591 | | - |
592 | | - # Get the metric type. |
593 | | - metricT = getmetricT(spl.alg) |
594 | | - |
595 | | - # Create a Hamiltonian. |
596 | | - θ_init = Vector{Float64}(spl.state.vi[spl]) |
597 | | - metric = metricT(length(θ_init)) |
598 | | - h = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ) |
599 | | - |
600 | | - # Find good eps if not provided one |
601 | | - if iszero(spl.alg.ϵ) |
602 | | - ϵ = AHMC.find_good_stepsize(rng, h, θ_init) |
603 | | - @info "Found initial step size" ϵ |
604 | | - else |
605 | | - ϵ = spl.alg.ϵ |
606 | | - end |
607 | | - |
608 | | - # Generate a kernel. |
609 | | - kernel = make_ahmc_kernel(spl.alg, ϵ) |
610 | | - |
611 | | - # Generate a phasepoint. Replaced during sample_init! |
612 | | - h, t = AHMC.sample_init(rng, h, θ_init) # this also ensure AHMC has the same dim as θ. |
613 | | - |
614 | | - # Unlink everything, if it was indeed linked before. |
615 | | - if waslinked |
616 | | - vi = invlink!!(vi, spl, model) |
617 | | - end |
618 | | - |
619 | | - return HMCState(vi, 0, 0, kernel.τ, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z) |
620 | | -end |
0 commit comments