@@ -11,47 +11,50 @@ using LogDensityProblems
1111include (joinpath (@__DIR__ , " .." , " common.jl" ))
1212
1313@testset " AdvancedHMC GPU" begin
14- n_chains = 1000
15- n_samples = 1000
16- dim = 5
17-
18- T = Float32
19- m, s, θ₀ = zeros (T, dim), ones (T, dim), rand (T, dim, n_chains)
20- m, s, θ₀ = CuArray (m), CuArray (s), CuArray (θ₀)
21-
22- target = Gaussian (m, s)
23- metric = UnitEuclideanMetric (T, size (θ₀))
24- ℓπ, ∇ℓπ = get_ℓπ (target), get_∇ℓπ (target)
25- hamiltonian = Hamiltonian (metric, ℓπ, ∇ℓπ)
26- integrator = Leapfrog (one (T) / 5 )
27- proposal = HMCKernel (Trajectory {EndPointTS} (integrator, FixedNSteps (5 )))
28-
29- samples, stats = sample (hamiltonian, proposal, θ₀, n_samples)
14+ if CUDA. functional ()
15+ n_chains = 1000
16+ n_samples = 1000
17+ dim = 5
18+ T = Float32
19+ m, s, θ₀ = zeros (T, dim), ones (T, dim), rand (T, dim, n_chains)
20+ m, s, θ₀ = CuArray (m), CuArray (s), CuArray (θ₀)
21+ target = Gaussian (m, s)
22+ metric = UnitEuclideanMetric (T, size (θ₀))
23+ ℓπ, ∇ℓπ = get_ℓπ (target), get_∇ℓπ (target)
24+ hamiltonian = Hamiltonian (metric, ℓπ, ∇ℓπ)
25+ integrator = Leapfrog (one (T) / 5 )
26+ proposal = HMCKernel (Trajectory {EndPointTS} (integrator, FixedNSteps (5 )))
27+ samples, stats = sample (hamiltonian, proposal, θ₀, n_samples)
28+ else
29+ println (" GPU tests are skipped because no CUDA devices are found." )
30+ end
3031end
3132
3233@testset " PhasePoint GPU" begin
33- for T in [Float32, Float64]
34- function init_z1 ()
35- return PhasePoint (
36- CuArray ([T (NaN ) T (NaN )]),
37- CuArray ([T (NaN ) T (NaN )]),
38- DualValue (CuArray (zeros (T, 2 )), CuArray (zeros (T, 1 , 2 ))),
39- DualValue (CuArray (zeros (T, 2 )), CuArray (zeros (T, 1 , 2 ))),
40- )
34+ if CUDA. functional ()
35+ for T in [Float32, Float64]
36+ function init_z1 ()
37+ return PhasePoint (
38+ CuArray ([T (NaN ) T (NaN )]),
39+ CuArray ([T (NaN ) T (NaN )]),
40+ DualValue (CuArray (zeros (T, 2 )), CuArray (zeros (T, 1 , 2 ))),
41+ DualValue (CuArray (zeros (T, 2 )), CuArray (zeros (T, 1 , 2 ))),
42+ )
43+ end
44+ function init_z2 ()
45+ return PhasePoint (
46+ CuArray ([T (Inf ) T (Inf )]),
47+ CuArray ([T (Inf ) T (Inf )]),
48+ DualValue (CuArray (zeros (T, 2 )), CuArray (zeros (T, 1 , 2 ))),
49+ DualValue (CuArray (zeros (T, 2 )), CuArray (zeros (T, 1 , 2 ))),
50+ )
51+ end
52+ z1 = init_z1 ()
53+ z2 = init_z2 ()
54+ @test z1. ℓπ. value == z2. ℓπ. value
55+ @test z1. ℓκ. value == z2. ℓκ. value
4156 end
42- function init_z2 ()
43- return PhasePoint (
44- CuArray ([T (Inf ) T (Inf )]),
45- CuArray ([T (Inf ) T (Inf )]),
46- DualValue (CuArray (zeros (T, 2 )), CuArray (zeros (T, 1 , 2 ))),
47- DualValue (CuArray (zeros (T, 2 )), CuArray (zeros (T, 1 , 2 ))),
48- )
49- end
50-
51- z1 = init_z1 ()
52- z2 = init_z2 ()
53-
54- @test z1. ℓπ. value == z2. ℓπ. value
55- @test z1. ℓκ. value == z2. ℓκ. value
57+ else
58+ println (" GPU tests are skipped because no CUDA devices are found." )
5659 end
5760end
0 commit comments