@@ -18,31 +18,17 @@ is_supported(::ADTypes.AutoReverseDiff) = true
1818 LogDensityFunction(
1919 model::Model,
2020 varinfo::AbstractVarInfo=VarInfo(model),
21- context::AbstractContext=DefaultContext();
22- adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
21+ context::AbstractContext=DefaultContext()
2322 )
2423
25- A struct which contains a model, along with all the information necessary to:
26-
27- - calculate its log density at a given point;
28- - and if `adtype` is provided, calculate the gradient of the log density at
29- that point.
24+ A struct which contains a model, along with all the information necessary to
25+ calculate its log density at a given point.
3026
3127At its most basic level, a LogDensityFunction wraps the model together with
3228the type of varinfo to be used, as well as the evaluation context. These must
3329be known in order to calculate the log density (using
3430[`DynamicPPL.evaluate!!`](@ref)).
3531
36- If the `adtype` keyword argument is provided, then this struct will also store
37- the adtype along with other information for efficient calculation of the
38- gradient of the log density. Note that preparing a `LogDensityFunction` with an
39- AD type `AutoBackend()` requires the AD backend itself to have been loaded
40- (e.g. with `import Backend`).
41-
42- `DynamicPPL.LogDensityFunction` implements the LogDensityProblems.jl interface.
43- If `adtype` is nothing, then only `logdensity` is implemented. If `adtype` is a
44- concrete AD backend type, then `logdensity_and_gradient` is also implemented.
45-
4632# Fields
4733$(FIELDS)
4834
@@ -84,40 +70,42 @@ julia> # This also respects the context in `model`.
8470julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
8571true
8672
87- julia> # If we also need to calculate the gradient, we can specify an AD backend.
73+ julia> # If we also need to calculate the gradient, an AD backend must be specified as part of the model.
8874 import ForwardDiff, ADTypes
8975
90- julia> f = LogDensityFunction(model, adtype=ADTypes.AutoForwardDiff());
76+ julia> model_with_ad = Model(model, ADTypes.AutoForwardDiff());
77+
78+ julia> f = LogDensityFunction(model_with_ad);
9179
9280julia> LogDensityProblems.logdensity_and_gradient(f, [0.0])
9381(-2.3378770664093453, [1.0])
9482```
9583"""
96- struct LogDensityFunction{
97- M<: Model ,V<: AbstractVarInfo ,C<: AbstractContext ,AD<: Union{Nothing,ADTypes.AbstractADType}
98- }
84+ struct LogDensityFunction{M<: Model ,V<: AbstractVarInfo ,C<: AbstractContext }
9985 " model used for evaluation"
10086 model:: M
10187 " varinfo used for evaluation"
10288 varinfo:: V
10389 " context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
10490 context:: C
105- " AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
106- adtype:: AD
10791 " (internal use only) gradient preparation object for the model"
10892 prep:: Union{Nothing,DI.GradientPrep}
10993
11094 function LogDensityFunction (
11195 model:: Model ,
11296 varinfo:: AbstractVarInfo = VarInfo (model),
113- context:: AbstractContext = leafcontext (model. context);
114- adtype:: Union{ADTypes.AbstractADType,Nothing} = model. adtype,
97+ context:: AbstractContext = leafcontext (model. context),
11598 )
99+ adtype = model. adtype
116100 if adtype === nothing
117101 prep = nothing
118102 else
119103 # Make backend-specific tweaks to the adtype
104+ # This should arguably be done in the model constructor, but it needs the
105+ # varinfo and context to do so, and it seems excessive to construct a
106+ # varinfo at the point of calling Model().
120107 adtype = tweak_adtype (adtype, model, varinfo, context)
108+ model = Model (model, adtype)
121109 # Check whether it is supported
122110 is_supported (adtype) ||
123111 @warn " The AD backend $adtype is not officially supported by DynamicPPL. Gradient calculations may still work, but compatibility is not guaranteed."
@@ -138,8 +126,8 @@ struct LogDensityFunction{
138126 )
139127 end
140128 end
141- return new {typeof(model),typeof(varinfo),typeof(context),typeof(adtype) } (
142- model, varinfo, context, adtype, prep
129+ return new {typeof(model),typeof(varinfo),typeof(context)} (
130+ model, varinfo, context, prep
143131 )
144132 end
145133end
@@ -157,10 +145,10 @@ Create a new LogDensityFunction using the model, varinfo, and context from the g
157145function LogDensityFunction (
158146 f:: LogDensityFunction , adtype:: Union{Nothing,ADTypes.AbstractADType}
159147)
160- return if adtype === f. adtype
148+ return if adtype === f. model . adtype
161149 f # Avoid recomputing prep if not needed
162150 else
163- LogDensityFunction (f. model, f. varinfo, f. context; adtype = adtype )
151+ LogDensityFunction (Model ( f. model, adtype), f. varinfo, f. context)
164152 end
165153end
166154
@@ -187,35 +175,46 @@ end
187175# ## LogDensityProblems interface
188176
189177function LogDensityProblems. capabilities (
190- :: Type{<:LogDensityFunction{M,V,C,Nothing}}
191- ) where {M,V,C}
178+ :: Type {
179+ <: LogDensityFunction {
180+ Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,Nothing},V,C
181+ },
182+ },
183+ ) where {F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,V,C}
192184 return LogDensityProblems. LogDensityOrder {0} ()
193185end
194186function LogDensityProblems. capabilities (
195- :: Type{<:LogDensityFunction{M,V,C,AD}}
196- ) where {M,V,C,AD<: ADTypes.AbstractADType }
187+ :: Type {
188+ <: LogDensityFunction {
189+ Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,TAD},V,C
190+ },
191+ },
192+ ) where {
193+ F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx,V,C,TAD<: ADTypes.AbstractADType
194+ }
197195 return LogDensityProblems. LogDensityOrder {1} ()
198196end
199197function LogDensityProblems. logdensity (f:: LogDensityFunction , x:: AbstractVector )
200198 return logdensity_at (x, f. model, f. varinfo, f. context)
201199end
202200function LogDensityProblems. logdensity_and_gradient (
203- f:: LogDensityFunction{M,V,C,AD} , x:: AbstractVector
204- ) where {M,V,C,AD<: ADTypes.AbstractADType }
205- f. prep === nothing &&
206- error (" Gradient preparation not available; this should not happen" )
201+ f:: LogDensityFunction{M,V,C} , x:: AbstractVector
202+ ) where {M,V,C}
203+ f. prep === nothing && error (
204+ " Attempted to call logdensity_and_gradient on a LogDensityFunction without an AD backend. You need to set an AD backend in the model before calculating the gradient of logp." ,
205+ )
207206 x = map (identity, x) # Concretise type
208207 # Make branching statically inferrable, i.e. type-stable (even if the two
209208 # branches happen to return different types)
210- return if use_closure (f. adtype)
209+ return if use_closure (f. model . adtype)
211210 DI. value_and_gradient (
212- x -> logdensity_at (x, f. model, f. varinfo, f. context), f. prep, f. adtype, x
211+ x -> logdensity_at (x, f. model, f. varinfo, f. context), f. prep, f. model . adtype, x
213212 )
214213 else
215214 DI. value_and_gradient (
216215 logdensity_at,
217216 f. prep,
218- f. adtype,
217+ f. model . adtype,
219218 x,
220219 DI. Constant (f. model),
221220 DI. Constant (f. varinfo),
@@ -292,7 +291,7 @@ getmodel(f::DynamicPPL.LogDensityFunction) = f.model
292291Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
293292"""
294293function setmodel (f:: DynamicPPL.LogDensityFunction , model:: DynamicPPL.Model )
295- return LogDensityFunction (model, f. varinfo, f. context; adtype = f . adtype )
294+ return LogDensityFunction (model, f. varinfo, f. context)
296295end
297296
298297"""
0 commit comments