Skip to content

Commit c2561b7

Browse files
committed
Use ADTypeCheckContext with hmc tests
1 parent aa6ab07 commit c2561b7

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

test/mcmc/hmc.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module HMCTests
22

33
using ..Models: gdemo_default
4+
using ..ADUtils: ADTypeCheckContext
45
#using ..Models: gdemo
56
using ..NumericalTests: check_gdemo, check_numerical
67
using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample
@@ -321,6 +322,15 @@ using Turing
321322
# KS will compare the empirical CDFs, which seems like a reasonable thing to do here.
322323
@test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.001
323324
end
325+
326+
@testset "Check ADType" begin
327+
alg = HMC(0.1, 10; adtype=adbackend)
328+
m = DynamicPPL.contextualize(
329+
gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context)
330+
)
331+
# These will error if the adbackend being used is not the one set.
332+
sample(rng, m, alg, 10)
333+
end
324334
end
325335

326336
end

test/optimisation/Optimisation.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -623,6 +623,7 @@ using Turing
623623
m = DynamicPPL.contextualize(
624624
gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context)
625625
)
626+
# These will error if the adbackend being used is not the one set.
626627
maximum_likelihood(m; adtype=adbackend)
627628
maximum_a_posteriori(m; adtype=adbackend)
628629
end

0 commit comments

Comments
 (0)