From b51e1a97133c779ce2801baee44be88938bc16bb Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 30 Apr 2025 16:26:31 +0800 Subject: [PATCH 01/12] Unify argument order in phasepoint and transition --- src/hamiltonian.jl | 12 ------------ src/sampler.jl | 5 +++-- src/trajectory.jl | 10 +++++----- test/trajectory.jl | 4 ++-- 4 files changed, 10 insertions(+), 21 deletions(-) diff --git a/src/hamiltonian.jl b/src/hamiltonian.jl index ffcaee89..b0d54e06 100644 --- a/src/hamiltonian.jl +++ b/src/hamiltonian.jl @@ -167,18 +167,6 @@ end energy(args...) = -neg_energy(args...) -#### -#### Momentum refreshment -#### - -function phasepoint( - rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, - θ::AbstractVecOrMat{T}, - h::Hamiltonian, -) where {T<:Real} - return phasepoint(h, θ, rand_momentum(rng, h.metric, h.kinetic, θ)) -end - abstract type AbstractMomentumRefreshment end "Completly resample new momentum." diff --git a/src/sampler.jl b/src/sampler.jl index 09268981..06b4bbb8 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -41,7 +41,8 @@ function sample_init( # Ensure h.metric has the same dim as θ. h = resize(h, θ) # Initial transition - t = Transition(phasepoint(rng, θ, h), NamedTuple()) + refresh_r = rand_momentum(rng, h.metric, h.kinetic, θ) # Momentum refreshment + t = Transition(phasepoint(h, θ, refresh_r), NamedTuple()) return h, t end @@ -54,7 +55,7 @@ function transition( (; refreshment, τ) = κ @set! τ.integrator = jitter(rng, τ.integrator) z = refresh(rng, refreshment, h, z) - return transition(rng, τ, h, z) + return transition(rng, h, τ, z) end function Adaptation.adapt!( diff --git a/src/trajectory.jl b/src/trajectory.jl index 74c9a7fc..a7680760 100644 --- a/src/trajectory.jl +++ b/src/trajectory.jl @@ -244,10 +244,10 @@ $(SIGNATURES) Make a MCMC transition from phase point `z` using the trajectory `τ` under Hamiltonian `h`. -NOTE: This is a RNG-implicit fallback function for `transition(Random.default_rng(), τ, h, z)` +NOTE: This is a RNG-implicit fallback function for `transition(Random.default_rng(), h, τ, z)` """ -function transition(τ::Trajectory, h::Hamiltonian, z::PhasePoint) - return transition(Random.default_rng(), τ, h, z) +function transition(h::Hamiltonian, τ::Trajectory, z::PhasePoint) + return transition(Random.default_rng(), h, τ, z) end ### @@ -256,8 +256,8 @@ end function transition( rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, - τ::Trajectory{TS,I,TC}, h::Hamiltonian, + τ::Trajectory{TS,I,TC}, z::PhasePoint, ) where {TS<:AbstractTrajectorySampler,I,TC<:StaticTerminationCriterion} H0 = energy(z) @@ -665,7 +665,7 @@ function build_tree( end function transition( - rng::AbstractRNG, τ::Trajectory{TS,I,TC}, h::Hamiltonian, z0::PhasePoint + rng::AbstractRNG, h::Hamiltonian, τ::Trajectory{TS,I,TC}, z0::PhasePoint ) where { TS<:AbstractTrajectorySampler,I<:AbstractIntegrator,TC<:DynamicTerminationCriterion } diff --git a/test/trajectory.jl b/test/trajectory.jl index 242bc4b7..403fd446 100644 --- a/test/trajectory.jl +++ b/test/trajectory.jl @@ -129,11 +129,11 @@ end for τ_test in [τ, τ_with_jittered_lf], seed in [1234, 5678, 90] rng = MersenneTwister(seed) z = AdvancedHMC.phasepoint(h, θ_init, r_init) - z1′ = AdvancedHMC.transition(rng, τ_test, h, z).z + z1′ = AdvancedHMC.transition(rng, h, τ_test, z).z rng = MersenneTwister(seed) z = AdvancedHMC.phasepoint(h, θ_init, r_init) - z2′ = AdvancedHMC.transition(rng, τ_test, h, z).z + z2′ = AdvancedHMC.transition(rng, h, τ_test, z).z @test z1′.θ == z2′.θ @test z1′.r == z2′.r From 24faf46625e9baa49eae32a0f2e745b6431252ab Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 30 Apr 2025 17:46:28 +0800 Subject: [PATCH 02/12] Use Test in quality tests --- test/quality.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/quality.jl b/test/quality.jl index 94d263d1..f3548bbf 100644 --- a/test/quality.jl +++ b/test/quality.jl @@ -1,5 +1,5 @@ using AdvancedHMC -using ReTest +using Test using Aqua: Aqua using JET using ForwardDiff From d023f9534735b900df52986556fc4bebacd0d921 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 30 Apr 2025 17:50:34 +0800 Subject: [PATCH 03/12] Revert change to Test --- test/quality.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/quality.jl b/test/quality.jl index f3548bbf..94d263d1 100644 --- a/test/quality.jl +++ b/test/quality.jl @@ -1,5 +1,5 @@ using AdvancedHMC -using Test +using ReTest using Aqua: Aqua using JET using ForwardDiff From 93c92eb74bcecdf1a67c6fa228c1bd0b27f9b59f Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 30 Apr 2025 19:27:36 +0800 Subject: [PATCH 04/12] Try with Test --- test/Project.toml | 1 + test/quality.jl | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 2eea074f..f3821481 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -20,6 +20,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReTest = "e0db7c4e-2690-44b9-bad6-7687da720f89" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/test/quality.jl b/test/quality.jl index 94d263d1..f3548bbf 100644 --- a/test/quality.jl +++ b/test/quality.jl @@ -1,5 +1,5 @@ using AdvancedHMC -using ReTest +using Test using Aqua: Aqua using JET using ForwardDiff From 6dbcda78fd7a66df2fb5ab7cf3da6438a07e11f0 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 30 Apr 2025 19:35:02 +0800 Subject: [PATCH 05/12] Try with Test --- test/quality.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/quality.jl b/test/quality.jl index f3548bbf..81a7114e 100644 --- a/test/quality.jl +++ b/test/quality.jl @@ -4,10 +4,10 @@ using Aqua: Aqua using JET using ForwardDiff -@testset "Aqua" begin +Test.@testset "Aqua" begin Aqua.test_all(AdvancedHMC) end -@testset "JET" begin +Test.@testset "JET" begin JET.test_package(AdvancedHMC; target_defined_modules=true) end From 774b36a7b44a8e10525b58c9765a8fbc00be65f7 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 30 Apr 2025 19:39:43 +0800 Subject: [PATCH 06/12] Use Test namespace --- test/quality.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/quality.jl b/test/quality.jl index 81a7114e..ebdf9ea7 100644 --- a/test/quality.jl +++ b/test/quality.jl @@ -1,5 +1,5 @@ using AdvancedHMC -using Test +using Test: Test using Aqua: Aqua using JET using ForwardDiff From 5e8326db6c316a8cdf919291cee943098f1c5f71 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Wed, 30 Apr 2025 20:32:21 +0800 Subject: [PATCH 07/12] Retrigger CUDA test --- src/sampler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sampler.jl b/src/sampler.jl index 06b4bbb8..f5c60b50 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -40,8 +40,8 @@ function sample_init( ) # Ensure h.metric has the same dim as θ. h = resize(h, θ) - # Initial transition refresh_r = rand_momentum(rng, h.metric, h.kinetic, θ) # Momentum refreshment + # Initial transition t = Transition(phasepoint(h, θ, refresh_r), NamedTuple()) return h, t end From 6ab71d6dc09daed602ffaa8b004e9492b74ba5ce Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Wed, 30 Apr 2025 20:45:36 +0100 Subject: [PATCH 08/12] Skip CUDA tests when no CUDA devices are found. (#436) * Update cuda.jl * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/CUDA/cuda.jl | 79 ++++++++++++++++++++++++----------------------- 1 file changed, 41 insertions(+), 38 deletions(-) diff --git a/test/CUDA/cuda.jl b/test/CUDA/cuda.jl index 4cce17ce..0c84232e 100644 --- a/test/CUDA/cuda.jl +++ b/test/CUDA/cuda.jl @@ -11,47 +11,50 @@ using LogDensityProblems include(joinpath(@__DIR__, "..", "common.jl")) @testset "AdvancedHMC GPU" begin - n_chains = 1000 - n_samples = 1000 - dim = 5 - - T = Float32 - m, s, θ₀ = zeros(T, dim), ones(T, dim), rand(T, dim, n_chains) - m, s, θ₀ = CuArray(m), CuArray(s), CuArray(θ₀) - - target = Gaussian(m, s) - metric = UnitEuclideanMetric(T, size(θ₀)) - ℓπ, ∇ℓπ = get_ℓπ(target), get_∇ℓπ(target) - hamiltonian = Hamiltonian(metric, ℓπ, ∇ℓπ) - integrator = Leapfrog(one(T) / 5) - proposal = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(5))) - - samples, stats = sample(hamiltonian, proposal, θ₀, n_samples) + if CUDA.functional() + n_chains = 1000 + n_samples = 1000 + dim = 5 + T = Float32 + m, s, θ₀ = zeros(T, dim), ones(T, dim), rand(T, dim, n_chains) + m, s, θ₀ = CuArray(m), CuArray(s), CuArray(θ₀) + target = Gaussian(m, s) + metric = UnitEuclideanMetric(T, size(θ₀)) + ℓπ, ∇ℓπ = get_ℓπ(target), get_∇ℓπ(target) + hamiltonian = Hamiltonian(metric, ℓπ, ∇ℓπ) + integrator = Leapfrog(one(T) / 5) + proposal = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(5))) + samples, stats = sample(hamiltonian, proposal, θ₀, n_samples) + else + println("GPU tests are skipped because no CUDA devices are found.") + end end @testset "PhasePoint GPU" begin - for T in [Float32, Float64] - function init_z1() - return PhasePoint( - CuArray([T(NaN) T(NaN)]), - CuArray([T(NaN) T(NaN)]), - DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))), - DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))), - ) + if CUDA.functional() + for T in [Float32, Float64] + function init_z1() + return PhasePoint( + CuArray([T(NaN) T(NaN)]), + CuArray([T(NaN) T(NaN)]), + DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))), + DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))), + ) + end + function init_z2() + return PhasePoint( + CuArray([T(Inf) T(Inf)]), + CuArray([T(Inf) T(Inf)]), + DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))), + DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))), + ) + end + z1 = init_z1() + z2 = init_z2() + @test z1.ℓπ.value == z2.ℓπ.value + @test z1.ℓκ.value == z2.ℓκ.value end - function init_z2() - return PhasePoint( - CuArray([T(Inf) T(Inf)]), - CuArray([T(Inf) T(Inf)]), - DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))), - DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))), - ) - end - - z1 = init_z1() - z2 = init_z2() - - @test z1.ℓπ.value == z2.ℓπ.value - @test z1.ℓκ.value == z2.ℓκ.value + else + println("GPU tests are skipped because no CUDA devices are found.") end end From 5f6671d4cde6960ec67723d6d450f3819c98d22b Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Thu, 1 May 2025 14:43:26 +0800 Subject: [PATCH 09/12] Revert phasepoint unifying --- Project.toml | 2 +- src/hamiltonian.jl | 12 ++++++++++++ src/sampler.jl | 3 +-- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index c0f46d18..bae9e5dd 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedHMC" uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d" -version = "0.7.1" +version = "0.8.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/hamiltonian.jl b/src/hamiltonian.jl index b0d54e06..ffcaee89 100644 --- a/src/hamiltonian.jl +++ b/src/hamiltonian.jl @@ -167,6 +167,18 @@ end energy(args...) = -neg_energy(args...) +#### +#### Momentum refreshment +#### + +function phasepoint( + rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}}, + θ::AbstractVecOrMat{T}, + h::Hamiltonian, +) where {T<:Real} + return phasepoint(h, θ, rand_momentum(rng, h.metric, h.kinetic, θ)) +end + abstract type AbstractMomentumRefreshment end "Completly resample new momentum." diff --git a/src/sampler.jl b/src/sampler.jl index f5c60b50..c0a42681 100644 --- a/src/sampler.jl +++ b/src/sampler.jl @@ -40,9 +40,8 @@ function sample_init( ) # Ensure h.metric has the same dim as θ. h = resize(h, θ) - refresh_r = rand_momentum(rng, h.metric, h.kinetic, θ) # Momentum refreshment # Initial transition - t = Transition(phasepoint(h, θ, refresh_r), NamedTuple()) + t = Transition(phasepoint(rng, θ, h), NamedTuple()) return h, t end From b4ed51886198089a46a6be4fc34837356a1e78eb Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Thu, 1 May 2025 14:46:51 +0800 Subject: [PATCH 10/12] Bump compat for docs --- docs/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/Project.toml b/docs/Project.toml index 7c50d939..c48bd544 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,6 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244" [compat] -AdvancedHMC = "0.7" +AdvancedHMC = "0.8" Documenter = "1" DocumenterCitations = "1" \ No newline at end of file From b90ac4005693b93b1417c46ed4696c93eb1ba360 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Thu, 1 May 2025 16:28:35 +0800 Subject: [PATCH 11/12] Update changelog --- HISTORY.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/HISTORY.md b/HISTORY.md index 8f43666a..e6fa1ed4 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,9 @@ # AdvancedHMC Changelog +## 0.8.0 + + - To make an MCMC transtion from phasepoint `z` using trajectory `τ` under Hamiltonian `h`, use `transition(h, τ, z)` or `transition(rng, h, τ, z)`. + ## v0.7.1 - README has been simplified, many docs transfered to docs: https://turinglang.org/AdvancedHMC.jl/dev/. From 8c34fdcbedf0bf2bf77c72a8d49dbedde6866cd4 Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Thu, 1 May 2025 16:39:17 +0800 Subject: [PATCH 12/12] Better changelog --- HISTORY.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index e6fa1ed4..83d8882e 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -2,7 +2,7 @@ ## 0.8.0 - - To make an MCMC transtion from phasepoint `z` using trajectory `τ` under Hamiltonian `h`, use `transition(h, τ, z)` or `transition(rng, h, τ, z)`. + - To make an MCMC transtion from phasepoint `z` using trajectory `τ`(or HMCKernel `κ`) under Hamiltonian `h`, use `transition(h, τ, z)` or `transition(rng, h, τ, z)`(if using HMCKernel, use `transition(h, κ, z)` or `transition(rng, h, κ, z)`). ## v0.7.1