@@ -4,19 +4,13 @@ using ADTypes: AbstractADType, AutoForwardDiff
 using Chairmarks: @be
 import DifferentiationInterface as DI
 using DocStringExtensions
-using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo
+using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, link
 using LogDensityProblems: logdensity, logdensity_and_gradient
 using Random: Random, Xoshiro
 using Statistics: median
 using Test: @test
 
-export ADResult, run_ad
-
-# This function needed to work around the fact that different backends can
-# return different AbstractArrays for the gradient. See
-# https://github.com/JuliaDiff/DifferentiationInterface.jl/issues/754 for more
-# context.
-_to_vec_f64(x::AbstractArray) = x isa Vector{Float64} ? x : collect(Float64, x)
+export ADResult, run_ad, ADIncorrectException
 
 """
     REFERENCE_ADTYPE
@@ -27,33 +21,50 @@ it's the default AD backend used in Turing.jl.
 const REFERENCE_ADTYPE = AutoForwardDiff()
 
 """
-    ADResult
+    ADIncorrectException{T<:AbstractFloat}
+
+Exception thrown when an AD backend returns an incorrect value or gradient.
+
+The type parameter `T` is the numeric type of the value and gradient.
+"""
+struct ADIncorrectException{T<:AbstractFloat} <: Exception
+    value_expected::T
+    value_actual::T
+    grad_expected::Vector{T}
+    grad_actual::Vector{T}
+end
+
+"""
+    ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
 
 Data structure to store the results of the AD correctness test.
+
+The type parameter `Tparams` is the numeric type of the parameters passed in;
+`Tresult` is the type of the value and the gradient.
 """
-struct ADResult
+struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
     "The DynamicPPL model that was tested"
     model::Model
     "The VarInfo that was used"
     varinfo::AbstractVarInfo
     "The values at which the model was evaluated"
-    params::Vector{<:Real}
+    params::Vector{Tparams}
     "The AD backend that was tested"
     adtype::AbstractADType
     "The absolute tolerance for the value of logp"
-    value_atol::Real
+    value_atol::Tresult
     "The absolute tolerance for the gradient of logp"
-    grad_atol::Real
+    grad_atol::Tresult
     "The expected value of logp"
-    value_expected::Union{Nothing,Float64}
+    value_expected::Union{Nothing,Tresult}
     "The expected gradient of logp"
-    grad_expected::Union{Nothing,Vector{Float64}}
+    grad_expected::Union{Nothing,Vector{Tresult}}
     "The value of logp (calculated using `adtype`)"
-    value_actual::Union{Nothing,Real}
+    value_actual::Union{Nothing,Tresult}
    "The gradient of logp (calculated using `adtype`)"
-    grad_actual::Union{Nothing,Vector{Float64}}
+    grad_actual::Union{Nothing,Vector{Tresult}}
     "If benchmarking was requested, the time taken by the AD backend to calculate the gradient of logp, divided by the time taken to evaluate logp itself"
-    time_vs_primal::Union{Nothing,Float64}
+    time_vs_primal::Union{Nothing,Tresult}
 end
 
 """
@@ -64,26 +75,27 @@
         benchmark=false,
         value_atol=1e-6,
         grad_atol=1e-6,
-        varinfo::AbstractVarInfo=VarInfo(model),
-        params::Vector{<:Real}=varinfo[:],
+        varinfo::AbstractVarInfo=link(VarInfo(model), model),
+        params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
         reference_adtype::ADTypes.AbstractADType=REFERENCE_ADTYPE,
-        expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
+        expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
         verbose=true,
     )::ADResult
 
+### Description
+
 Test the correctness and/or benchmark the AD backend `adtype` for the model
 `model`.
 
 Whether to test and benchmark is controlled by the `test` and `benchmark`
 keyword arguments. By default, `test` is `true` and `benchmark` is `false`.
 
-Returns an [`ADResult`](@ref) object, which contains the results of the
-test and/or benchmark.
-
 Note that to run AD successfully you will need to import the AD backend itself.
 For example, to test with `AutoReverseDiff()` you will need to run `import
 ReverseDiff`.
 
+### Arguments
+
 There are two positional arguments, which absolutely must be provided:
 
 1. `model` - The model being tested.
@@ -96,7 +108,9 @@ Everything else is optional, and can be categorised into several groups:
    DynamicPPL contains several different types of VarInfo objects which change
    the way model evaluation occurs. If you want to use a specific type of
    VarInfo, pass it as the `varinfo` argument. Otherwise, it will default to
-   using a `TypedVarInfo` generated from the model.
+   using a linked `TypedVarInfo` generated from the model. Here, _linked_
+   means that the parameters in the VarInfo have been transformed to
+   unconstrained Euclidean space if they aren't already in that space.
 
 2. _How to specify the parameters._
 
@@ -140,27 +154,40 @@ Everything else is optional, and can be categorised into several groups:
 
    By default, this function prints messages when it runs. To silence it, set
    `verbose=false`.
+
+### Returns / Throws
+
+Returns an [`ADResult`](@ref) object, which contains the results of the
+test and/or benchmark.
+
+If `test` is `true` and the AD backend returns an incorrect value or gradient, an
+`ADIncorrectException` is thrown. If a different error occurs, it will be
+thrown as-is.
 """
 function run_ad(
     model::Model,
     adtype::AbstractADType;
-    test=true,
-    benchmark=false,
-    value_atol=1e-6,
-    grad_atol=1e-6,
-    varinfo::AbstractVarInfo=VarInfo(model),
-    params::Vector{<:Real}=varinfo[:],
+    test::Bool=true,
+    benchmark::Bool=false,
+    value_atol::AbstractFloat=1e-6,
+    grad_atol::AbstractFloat=1e-6,
+    varinfo::AbstractVarInfo=link(VarInfo(model), model),
+    params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
     reference_adtype::AbstractADType=REFERENCE_ADTYPE,
-    expected_value_and_grad::Union{Nothing,Tuple{Real,Vector{<:Real}}}=nothing,
+    expected_value_and_grad::Union{Nothing,Tuple{AbstractFloat,Vector{<:AbstractFloat}}}=nothing,
     verbose=true,
 )::ADResult
+    if isnothing(params)
+        params = varinfo[:]
+    end
+    params = map(identity, params) # Concretise
+
     verbose && @info "Running AD on $(model.f) with $(adtype)\n"
-    params = map(identity, params)
     verbose && println("params : $(params)")
     ldf = LogDensityFunction(model, varinfo; adtype=adtype)
 
     value, grad = logdensity_and_gradient(ldf, params)
-    grad = _to_vec_f64(grad)
+    grad = collect(grad)
     verbose && println("actual : $((value, grad))")
 
     if test
@@ -172,10 +199,11 @@ function run_ad(
             expected_value_and_grad
         end
         verbose && println("expected : $((value_true, grad_true))")
-        grad_true = _to_vec_f64(grad_true)
-        # Then compare
-        @test isapprox(value, value_true; atol=value_atol)
-        @test isapprox(grad, grad_true; atol=grad_atol)
+        grad_true = collect(grad_true)
+
+        exc() = throw(ADIncorrectException(value, value_true, grad, grad_true))
+        isapprox(value, value_true; atol=value_atol) || exc()
+        isapprox(grad, grad_true; atol=grad_atol) || exc()
     else
         value_true = nothing
         grad_true = nothing
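
---

A quick usage sketch for reviewers, not part of the patch: this is how the revised `run_ad` might be exercised end-to-end. The model, the values, and the `DynamicPPL.TestUtils.AD` module path are illustrative assumptions; only the keyword names and `ADResult` fields come from the diff above.

```julia
using DynamicPPL, Distributions
using DynamicPPL.TestUtils.AD: run_ad, ADIncorrectException  # module path is a guess
using ADTypes: AutoForwardDiff
import ForwardDiff  # the backend itself must be imported, per the docstring

@model function demo(x)
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
    x ~ Normal(m, sqrt(s))
end

# With the new defaults, evaluation happens in linked (unconstrained) space,
# and params are read from the linked VarInfo unless passed explicitly.
result = run_ad(demo(1.5), AutoForwardDiff(); benchmark=true)
result.value_actual    # logp computed with the tested backend
result.grad_actual     # gradient, now a plain Vector
result.time_vs_primal  # non-nothing only because benchmark=true
```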
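Because failures now throw `ADIncorrectException` instead of failing an embedded `@test`, downstream suites can catch the exception and report the mismatch themselves. A hedged sketch under the same assumptions, reusing the `demo` model above:

```julia
import ReverseDiff
using ADTypes: AutoReverseDiff

try
    run_ad(demo(1.5), AutoReverseDiff())
catch err
    if err isa ADIncorrectException
        # The four fields share the numeric type parameter T.
        @show err.value_expected err.value_actual
        @show err.grad_expected err.grad_actual
    else
        rethrow()  # other errors propagate as-is, as the docstring states
    end
end
```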