Skip to content

Commit 6ab71d6

Browse files
Skip CUDA tests when no CUDA devices are found. (#436)
* Update cuda.jl * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 5e8326d commit 6ab71d6

File tree

1 file changed

+41
-38
lines changed

1 file changed

+41
-38
lines changed

test/CUDA/cuda.jl

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -11,47 +11,50 @@ using LogDensityProblems
1111
include(joinpath(@__DIR__, "..", "common.jl"))
1212

1313
@testset "AdvancedHMC GPU" begin
14-
n_chains = 1000
15-
n_samples = 1000
16-
dim = 5
17-
18-
T = Float32
19-
m, s, θ₀ = zeros(T, dim), ones(T, dim), rand(T, dim, n_chains)
20-
m, s, θ₀ = CuArray(m), CuArray(s), CuArray(θ₀)
21-
22-
target = Gaussian(m, s)
23-
metric = UnitEuclideanMetric(T, size(θ₀))
24-
ℓπ, ∇ℓπ = get_ℓπ(target), get_∇ℓπ(target)
25-
hamiltonian = Hamiltonian(metric, ℓπ, ∇ℓπ)
26-
integrator = Leapfrog(one(T) / 5)
27-
proposal = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(5)))
28-
29-
samples, stats = sample(hamiltonian, proposal, θ₀, n_samples)
14+
if CUDA.functional()
15+
n_chains = 1000
16+
n_samples = 1000
17+
dim = 5
18+
T = Float32
19+
m, s, θ₀ = zeros(T, dim), ones(T, dim), rand(T, dim, n_chains)
20+
m, s, θ₀ = CuArray(m), CuArray(s), CuArray(θ₀)
21+
target = Gaussian(m, s)
22+
metric = UnitEuclideanMetric(T, size(θ₀))
23+
ℓπ, ∇ℓπ = get_ℓπ(target), get_∇ℓπ(target)
24+
hamiltonian = Hamiltonian(metric, ℓπ, ∇ℓπ)
25+
integrator = Leapfrog(one(T) / 5)
26+
proposal = HMCKernel(Trajectory{EndPointTS}(integrator, FixedNSteps(5)))
27+
samples, stats = sample(hamiltonian, proposal, θ₀, n_samples)
28+
else
29+
println("GPU tests are skipped because no CUDA devices are found.")
30+
end
3031
end
3132

3233
@testset "PhasePoint GPU" begin
33-
for T in [Float32, Float64]
34-
function init_z1()
35-
return PhasePoint(
36-
CuArray([T(NaN) T(NaN)]),
37-
CuArray([T(NaN) T(NaN)]),
38-
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
39-
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
40-
)
34+
if CUDA.functional()
35+
for T in [Float32, Float64]
36+
function init_z1()
37+
return PhasePoint(
38+
CuArray([T(NaN) T(NaN)]),
39+
CuArray([T(NaN) T(NaN)]),
40+
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
41+
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
42+
)
43+
end
44+
function init_z2()
45+
return PhasePoint(
46+
CuArray([T(Inf) T(Inf)]),
47+
CuArray([T(Inf) T(Inf)]),
48+
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
49+
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
50+
)
51+
end
52+
z1 = init_z1()
53+
z2 = init_z2()
54+
@test z1.ℓπ.value == z2.ℓπ.value
55+
@test z1.ℓκ.value == z2.ℓκ.value
4156
end
42-
function init_z2()
43-
return PhasePoint(
44-
CuArray([T(Inf) T(Inf)]),
45-
CuArray([T(Inf) T(Inf)]),
46-
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
47-
DualValue(CuArray(zeros(T, 2)), CuArray(zeros(T, 1, 2))),
48-
)
49-
end
50-
51-
z1 = init_z1()
52-
z2 = init_z2()
53-
54-
@test z1.ℓπ.value == z2.ℓπ.value
55-
@test z1.ℓκ.value == z2.ℓκ.value
57+
else
58+
println("GPU tests are skipped because no CUDA devices are found.")
5659
end
5760
end

0 commit comments

Comments
 (0)