From bfa0a9cb6b470c265b2b8bfba6ca7dcb83a92f01 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 16 May 2025 13:47:45 +0100 Subject: [PATCH 1/5] DI 0.7 --- Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 5bef5bcb1..c82316541 100644 --- a/Project.toml +++ b/Project.toml @@ -54,7 +54,7 @@ ChainRulesCore = "1" Chairmarks = "1.3.1" Compat = "4" ConstructionBase = "1.5.4" -DifferentiationInterface = "0.6.41" +DifferentiationInterface = "0.6.41, 0.7" Distributions = "0.25" DocStringExtensions = "0.9" EnzymeCore = "0.6 - 0.8" diff --git a/test/Project.toml b/test/Project.toml index 79e6d129b..92e81bb83 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -38,7 +38,7 @@ Aqua = "0.8" Bijectors = "0.15.1" Combinatorics = "1" Compat = "4.3.0" -DifferentiationInterface = "0.6.41" +DifferentiationInterface = "0.6.41, 0.7" Distributions = "0.25" DistributionsAD = "0.6.3" Documenter = "1" From 6ecc14608caff239e19bb1a067e581d5e6b3f50a Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 18 May 2025 01:17:38 +0100 Subject: [PATCH 2/5] Fix strictness failure with DifferentiationInterface 0.7 --- src/logdensityfunction.jl | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index a42855f05..7d8e1fa15 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -106,6 +106,8 @@ struct LogDensityFunction{ adtype::AD "(internal use only) gradient preparation object for the model" prep::Union{Nothing,DI.GradientPrep} + "(internal use only) the closure used for the gradient preparation" + closure::Union{Nothing,Function} function LogDensityFunction( model::Model, @@ -114,6 +116,7 @@ struct LogDensityFunction{ adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) if adtype === nothing + closure = nothing prep = nothing else # Make backend-specific tweaks to the adtype @@ -124,10 
+127,16 @@ struct LogDensityFunction{ # Get a set of dummy params to use for prep x = map(identity, varinfo[:]) if use_closure(adtype) - prep = DI.prepare_gradient( - x -> logdensity_at(x, model, varinfo, context), adtype, x - ) + # The closure itself has to be stored inside the + # LogDensityFunction to ensure that the signature of the + # function being differentiated is the same as that used for + # preparation. See + # https://github.com/TuringLang/DynamicPPL.jl/pull/922 for an + # explanation. + closure = x -> logdensity_at(x, model, varinfo, context) + prep = DI.prepare_gradient(closure, adtype, x) else + closure = nothing prep = DI.prepare_gradient( logdensity_at, adtype, @@ -139,7 +148,7 @@ struct LogDensityFunction{ end end return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}( - model, varinfo, context, adtype, prep + model, varinfo, context, adtype, prep, closure ) end end @@ -208,9 +217,8 @@ function LogDensityProblems.logdensity_and_gradient( # Make branching statically inferrable, i.e. type-stable (even if the two # branches happen to return different types) return if use_closure(f.adtype) - DI.value_and_gradient( - x -> logdensity_at(x, f.model, f.varinfo, f.context), f.prep, f.adtype, x - ) + f.closure === nothing && error("Closure not available; this should not happen") + DI.value_and_gradient(f.closure, f.prep, f.adtype, x) else DI.value_and_gradient( logdensity_at, From 36e5cad92dca11e675fbbab12c008ee5c7a4bf2b Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Sun, 18 May 2025 18:32:53 +0100 Subject: [PATCH 3/5] Bump patch --- HISTORY.md | 4 ++++ Project.toml | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/HISTORY.md b/HISTORY.md index 4e9bc2d42..26eaa2d39 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -1,5 +1,9 @@ # DynamicPPL Changelog +## 0.36.4 + +Added compatibility with DifferentiationInterface.jl 0.7. 
+ ## 0.36.3 Moved the `bijector(model)`, where `model` is a `DynamicPPL.Model`, function from the Turing main repo. diff --git a/Project.toml b/Project.toml index c82316541..128822367 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.36.3" +version = "0.36.4" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 6f670c42ad71cec76dfb995f3490695da7e96609 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 19 May 2025 11:16:53 +0100 Subject: [PATCH 4/5] Use `LogDensityAt` callable struct instead of closure --- src/logdensityfunction.jl | 41 +++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 15 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 7d8e1fa15..639e13081 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -106,8 +106,6 @@ struct LogDensityFunction{ adtype::AD "(internal use only) gradient preparation object for the model" prep::Union{Nothing,DI.GradientPrep} - "(internal use only) the closure used for the gradient preparation" - closure::Union{Nothing,Function} function LogDensityFunction( model::Model, @@ -116,7 +114,6 @@ struct LogDensityFunction{ adtype::Union{ADTypes.AbstractADType,Nothing}=nothing, ) if adtype === nothing - closure = nothing prep = nothing else # Make backend-specific tweaks to the adtype @@ -127,16 +124,8 @@ struct LogDensityFunction{ # Get a set of dummy params to use for prep x = map(identity, varinfo[:]) if use_closure(adtype) - # The closure itself has to be stored inside the - # LogDensityFunction to ensure that the signature of the - # function being differentiated is the same as that used for - # preparation. See - # https://github.com/TuringLang/DynamicPPL.jl/pull/922 for an - # explanation. 
- closure = x -> logdensity_at(x, model, varinfo, context) - prep = DI.prepare_gradient(closure, adtype, x) + prep = DI.prepare_gradient(LogDensityAt(model, varinfo, context), adtype, x) else - closure = nothing prep = DI.prepare_gradient( logdensity_at, adtype, @@ -148,7 +137,7 @@ struct LogDensityFunction{ end end return new{typeof(model),typeof(varinfo),typeof(context),typeof(adtype)}( - model, varinfo, context, adtype, prep, closure + model, varinfo, context, adtype, prep ) end end @@ -193,6 +182,27 @@ function logdensity_at( return getlogp(last(evaluate!!(model, varinfo_new, context))) end +""" + LogDensityAt( + x::AbstractVector, + model::Model, + varinfo::AbstractVarInfo, + context::AbstractContext + ) + +A callable struct that serves the same purpose as `x -> logdensity_at(x, model, +varinfo, context)`. +""" +struct LogDensityAt + model::Model + varinfo::AbstractVarInfo + context::AbstractContext +end +function (ld::LogDensityAt)(x::AbstractVector) + varinfo_new = unflatten(ld.varinfo, x) + return getlogp(last(evaluate!!(ld.model, varinfo_new, ld.context))) +end + ### LogDensityProblems interface function LogDensityProblems.capabilities( @@ -217,8 +227,9 @@ function LogDensityProblems.logdensity_and_gradient( # Make branching statically inferrable, i.e. 
type-stable (even if the two # branches happen to return different types) return if use_closure(f.adtype) - f.closure === nothing && error("Closure not available; this should not happen") - DI.value_and_gradient(f.closure, f.prep, f.adtype, x) + DI.value_and_gradient( + LogDensityAt(f.model, f.varinfo, f.context), f.prep, f.adtype, x + ) else DI.value_and_gradient( logdensity_at, From 7d75aa0e1d58fbc691c1f9d34150db6796fe1492 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Mon, 19 May 2025 13:03:15 +0100 Subject: [PATCH 5/5] Use type parameters --- src/logdensityfunction.jl | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 639e13081..443c435e0 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -183,20 +183,19 @@ function logdensity_at( end """ - LogDensityAt( - x::AbstractVector, - model::Model, - varinfo::AbstractVarInfo, - context::AbstractContext + LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext}( + model::M, + varinfo::V, + context::C ) A callable struct that serves the same purpose as `x -> logdensity_at(x, model, varinfo, context)`. """ -struct LogDensityAt - model::Model - varinfo::AbstractVarInfo - context::AbstractContext +struct LogDensityAt{M<:Model,V<:AbstractVarInfo,C<:AbstractContext} + model::M + varinfo::V + context::C end function (ld::LogDensityAt)(x::AbstractVector) varinfo_new = unflatten(ld.varinfo, x)