@@ -106,6 +106,8 @@ struct LogDensityFunction{
     adtype::AD
     "(internal use only) gradient preparation object for the model"
     prep::Union{Nothing,DI.GradientPrep}
+    "(internal use only) the closure used for the gradient preparation"
+    closure::Union{Nothing,Function}
 
     function LogDensityFunction(
         model::Model,
@@ -114,6 +116,7 @@ struct LogDensityFunction{
         adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
     )
         if adtype === nothing
+            closure = nothing
             prep = nothing
         else
             # Make backend-specific tweaks to the adtype
@@ -124,10 +127,16 @@ struct LogDensityFunction{
             # Get a set of dummy params to use for prep
             x = map(identity, varinfo[:])
             if use_closure(adtype)
-                prep = DI.prepare_gradient(
-                    x -> logdensity_at(x, model, varinfo, context), adtype, x
-                )
+                # The closure itself has to be stored inside the
+                # LogDensityFunction to ensure that the signature of the
+                # function being differentiated is the same as that used for
+                # preparation. See
+                # https://github.com/TuringLang/DynamicPPL.jl/pull/922 for an
+                # explanation.
+                closure = x -> logdensity_at(x, model, varinfo, context)
+                prep = DI.prepare_gradient(closure, adtype, x)
             else
+                closure = nothing
                 prep = DI.prepare_gradient(
                     logdensity_at,
                     adtype,
@@ -139,7 +148,7 @@ struct LogDensityFunction{
             end
         end
         return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}(
-            model, varinfo, context, adtype, prep
+            model, varinfo, context, adtype, prep, closure
         )
     end
 end
@@ -208,9 +217,8 @@ function LogDensityProblems.logdensity_and_gradient(
     # Make branching statically inferrable, i.e. type-stable (even if the two
     # branches happen to return different types)
     return if use_closure(f.adtype)
-        DI.value_and_gradient(
-            x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x
-        )
+        f.closure === nothing && error("Closure not available; this should not happen")
+        DI.value_and_gradient(f.closure, f.prep, f.adtype, x)
     else
         DI.value_and_gradient(
             logdensity_at,
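
Note on the stored closure (a minimal sketch, not part of the diff above): the comment added in the third hunk points to TuringLang/DynamicPPL.jl#922. The underlying Julia behaviour is that every anonymous-function literal gets its own unique type, so a closure rebuilt inside logdensity_and_gradient would not be the same function that DI.prepare_gradient prepared for. The names g, c1 and c2 below are hypothetical and purely illustrative:

# Hypothetical illustration: two syntactically identical closures still have
# distinct types, so preparation done with c1 cannot safely be reused with c2.
g(x) = sum(abs2, x)           # stand-in for logdensity_at
c1 = x -> g(x)
c2 = x -> g(x)
typeof(c1) == typeof(c2)      # false: each closure literal has its own type
# Hence the closure passed to DI.prepare_gradient is kept on the struct and
# the very same object is reused in DI.value_and_gradient.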