Skip to content

Commit aa6ab07

Browse files
committed
Check ADType use in optimisation
1 parent 547eaee commit aa6ab07

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

test/optimisation/Optimisation.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module OptimisationTests
22

33
using ..Models: gdemo, gdemo_default
4+
using ..ADUtils: ADTypeCheckContext
45
using Distributions
56
using Distributions.FillArrays: Zeros
67
using DynamicPPL: DynamicPPL
@@ -140,7 +141,6 @@ using Turing
140141
gdemo_default, OptimizationOptimJL.LBFGS(); initial_params=true_value
141142
)
142143
m3 = maximum_likelihood(gdemo_default, OptimizationOptimJL.Newton())
143-
# TODO(mhauru) How can we check that the adtype is actually AutoReverseDiff?
144144
m4 = maximum_likelihood(
145145
gdemo_default, OptimizationOptimJL.BFGS(); adtype=AutoReverseDiff()
146146
)
@@ -616,6 +616,17 @@ using Turing
616616
@assert vcat(get_a[:a], get_b[:b]) == result.values.array
617617
@assert get(result, :c) == (; :c => Array{Float64}[])
618618
end
619+
620+
@testset "ADType" begin
621+
Random.seed!(222)
622+
for adbackend in (AutoReverseDiff(), AutoForwardDiff(), AutoTracker())
623+
m = DynamicPPL.contextualize(
624+
gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context)
625+
)
626+
maximum_likelihood(m; adtype=adbackend)
627+
maximum_a_posteriori(m; adtype=adbackend)
628+
end
629+
end
619630
end
620631

621632
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import Turing
77

88
include(pkgdir(Turing) * "/test/test_utils/models.jl")
99
include(pkgdir(Turing) * "/test/test_utils/numerical_tests.jl")
10+
include(pkgdir(Turing) * "/test/test_utils/ad_utils.jl")
1011

1112
Turing.setprogress!(false)
1213

test/test_utils/ad_utils.jl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ end
9696
"""
9797
valid_eltypes(context::ADTypeCheckContext)
9898
99-
Return the element types that are valid for the ADType of `context`.
99+
Return the element types that are valid for the ADType of `context` as a tuple.
100100
"""
101101
function valid_eltypes(context::ADTypeCheckContext)
102102
context_at = adtype(context)
@@ -116,7 +116,7 @@ Check that the element types in `vi` are compatible with the ADType of `context`
116116
117117
Throw an `IncompatibleADTypeError` if an incompatible element type is encountered.
118118
"""
119-
function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.VarInfo)
119+
function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo)
120120
valids = valid_eltypes(context)
121121
for val in vi[:]
122122
valtype = typeof(val)
@@ -187,6 +187,14 @@ function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, right, left,
187187
return logp, vi
188188
end
189189

190+
function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi)
191+
logp, vi = DynamicPPL.dot_tilde_observe(
192+
DynamicPPL.childcontext(context), sampler, right, left, vi
193+
)
194+
check_adtype(context, vi)
195+
return logp, vi
196+
end
197+
190198
# Check that the ADTypeCheckContext works as expected.
191199
Test.@testset "ADTypeCheckContext" begin
192200
Turing.@model test_model() = x ~ Turing.Normal(0, 1)

0 commit comments

Comments
 (0)