@@ -4,9 +4,11 @@ using Turing
44using DynamicPPL
55using DynamicPPL. TestUtils: DEMO_MODELS
66using DynamicPPL. TestUtils. AD: run_ad
7+ using Random: Random
78using StableRNGs: StableRNG
89using Test
910using .. Models: gdemo_default
11+ import ForwardDiff, ReverseDiff, Mooncake
1012
1113""" Element types that are always valid for a VarInfo regardless of ADType."""
1214const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)
@@ -181,17 +183,49 @@ ADTYPES = [
181183 Turing. AutoMooncake (; config= nothing ),
182184]
183185
186+ # Check that ADTypeCheckContext itself works as expected.
187+ @testset " ADTypeCheckContext" begin
188+ @model test_model () = x ~ Normal (0 , 1 )
189+ tm = test_model ()
190+ adtypes = (
191+ Turing. AutoForwardDiff (),
192+ Turing. AutoReverseDiff (),
193+ # TODO : Mooncake
194+ # Turing.AutoMooncake(config=nothing),
195+ )
196+ for actual_adtype in adtypes
197+ sampler = Turing. HMC (0.1 , 5 ; adtype= actual_adtype)
198+ for expected_adtype in adtypes
199+ contextualised_tm = DynamicPPL. contextualize (
200+ tm, ADTypeCheckContext (expected_adtype, tm. context)
201+ )
202+ @testset " Expected: $expected_adtype , Actual: $actual_adtype " begin
203+ if actual_adtype == expected_adtype
204+ # Check that this does not throw an error.
205+ Turing. sample (contextualised_tm, sampler, 2 )
206+ else
207+ @test_throws AbstractWrongADBackendError Turing. sample (
208+ contextualised_tm, sampler, 2
209+ )
210+ end
211+ end
212+ end
213+ end
214+ end
215+
184216@testset verbose = true " AD / ADTypeCheckContext" begin
185- # This testset ensures that samplers don't accidentally override the AD
186- # backend set in it.
187- @testset " Check ADType " begin
217+ # This testset ensures that samplers or optimisers don't accidentally
218+ # override the AD backend set in it.
219+ @testset " adtype= $adtype " for adtype in ADTYPES
188220 seed = 123
189221 alg = HMC (0.1 , 10 ; adtype= adtype)
190222 m = DynamicPPL. contextualize (
191223 gdemo_default, ADTypeCheckContext (adtype, gdemo_default. context)
192224 )
193225 # These will error if the adbackend being used is not the one set.
194226 sample (StableRNG (seed), m, alg, 10 )
227+ maximum_likelihood (m; adtype= adtype)
228+ maximum_a_posteriori (m; adtype= adtype)
195229 end
196230end
197231
0 commit comments