From cc984297b14fb2e1e740955e3e6ab945255e1dd4 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 28 Nov 2024 09:48:35 +0000 Subject: [PATCH 01/17] ADTypes interop --- src/logdensityfunction.jl | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 9e86590fa..eebbfad69 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -144,3 +144,17 @@ function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) end # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) + + +if isdefined(Base, :get_extension) + using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD +else + using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD +end + +# This is important for performance. +function LogDensityProblemsAD.ADgradient( + ad::ADTypes.AbstractADType, ℓ::DynamicPPL.LogDensityFunction +) + return LogDensityProblemsAD.ADgradient(ad, ℓ; x=map(identity, DynamicPPL.getparams(ℓ))) +end From 749eca9473113640bf06e22de7af158f0b1d3319 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 28 Nov 2024 09:51:55 +0000 Subject: [PATCH 02/17] Improve comment --- src/logdensityfunction.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index eebbfad69..7e0cc204a 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -152,7 +152,10 @@ else using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD end -# This is important for performance. +# This is important for performance -- one needs to provide `ADGradient` with a vector of +# parameters, or DifferentiationInterface will not have sufficient information to e.g. +# compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate +# a tape when using ReverseDiff.jl. function LogDensityProblemsAD.ADgradient( ad::ADTypes.AbstractADType, ℓ::DynamicPPL.LogDensityFunction ) From 4e6b97cff8b4eafc23484814fa2ebbb9eb551447 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 28 Nov 2024 09:52:10 +0000 Subject: [PATCH 03/17] Bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ebc70b5ab..37096b55c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.30.5" +version = "0.30.6" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From f8847dbd3152344eed73c56807a3b7d280aafcb7 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 28 Nov 2024 09:52:30 +0000 Subject: [PATCH 04/17] Formatting --- src/logdensityfunction.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 7e0cc204a..20db10ee0 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -156,8 +156,6 @@ end # parameters, or DifferentiationInterface will not have sufficient information to e.g. # compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate # a tape when using ReverseDiff.jl. -function LogDensityProblemsAD.ADgradient( - ad::ADTypes.AbstractADType, ℓ::DynamicPPL.LogDensityFunction -) +function LogDensityProblemsAD.ADgradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction) return LogDensityProblemsAD.ADgradient(ad, ℓ; x=map(identity, DynamicPPL.getparams(ℓ))) end From 1b33c75ebaf69fe34022b0112a769af89b58a1ce Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 28 Nov 2024 09:53:08 +0000 Subject: [PATCH 05/17] Formatting --- src/logdensityfunction.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 20db10ee0..f19e1dd98 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -145,17 +145,10 @@ end # TODO: should we instead implement and call on `length(f.varinfo)` (at least in the cases where no sampler is involved)? LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) - -if isdefined(Base, :get_extension) - using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD -else - using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD -end - # This is important for performance -- one needs to provide `ADGradient` with a vector of # parameters, or DifferentiationInterface will not have sufficient information to e.g. # compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate # a tape when using ReverseDiff.jl. function LogDensityProblemsAD.ADgradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction) - return LogDensityProblemsAD.ADgradient(ad, ℓ; x=map(identity, DynamicPPL.getparams(ℓ))) + return LogDensityProblemsAD.ADgradient(ad, ℓ; x=map(identity, getparams(ℓ))) end From cc17471d19ed3fbd02f5b30ef20b90a7212d9328 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 28 Nov 2024 09:53:40 +0000 Subject: [PATCH 06/17] Improve documentation --- src/logdensityfunction.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index f19e1dd98..19d87fd5b 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -150,5 +150,6 @@ LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) # compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate # a tape when using ReverseDiff.jl. function LogDensityProblemsAD.ADgradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction) - return LogDensityProblemsAD.ADgradient(ad, ℓ; x=map(identity, getparams(ℓ))) + x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params + return LogDensityProblemsAD.ADgradient(ad, ℓ; x) end From ea39bc7b1808cae4f5847d04fc08d788ccc83def Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 28 Nov 2024 11:30:06 +0000 Subject: [PATCH 07/17] Testing infrastructure --- test/Project.toml | 2 ++ test/runtests.jl | 1 + 2 files changed, 3 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index 36fcd1b69..fbf305db5 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -17,6 +17,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -43,6 +44,7 @@ LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" MCMCChains = "6.0.4" MacroTools = "0.5.6" +Mooncake = "0.4.50" ReverseDiff = "1" StableRNGs = "1" Tracker = "0.2.23" diff --git a/test/runtests.jl b/test/runtests.jl index a832a0f08..1bf30cb71 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,6 +11,7 @@ using ForwardDiff using LogDensityProblems, LogDensityProblemsAD using MacroTools using MCMCChains +using Mooncake: Mooncake using Tracker using ReverseDiff using Zygote From c8c95c534bac6207f800f3e4f804d04c899bd462 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 28 Nov 2024 11:30:17 +0000 Subject: [PATCH 08/17] Remove extras from main Project toml --- Project.toml | 8 -------- 1 file changed, 8 deletions(-) diff --git a/Project.toml b/Project.toml index 37096b55c..754f56216 100644 --- a/Project.toml +++ b/Project.toml @@ -67,11 +67,3 @@ ReverseDiff = "1" Test = "1.6" ZygoteRules = "0.2" julia = "1.10" - -[extras] -ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" From 3877cf0a1a1ae11d0e315478bef13ce49aa1a668 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 28 Nov 2024 17:31:51 +0000 Subject: [PATCH 09/17] Apply some basic tests --- test/Project.toml | 1 + test/logdensityfunction.jl | 12 ++++++++++++ test/runtests.jl | 1 + 3 files changed, 14 insertions(+) diff --git a/test/Project.toml b/test/Project.toml index fbf305db5..60dfcb70b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index beda767e6..5a2798ab8 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -31,6 +31,18 @@ end θ = varinfo[:] @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo) @test LogDensityProblems.dimension(logdensity) == length(θ) + + # Test a single backend on the generic + # ADgradient(::AbstractADType, ::LogDensityFunction) method. This really just + # checks that it runs at all. + if varinfo isa DynamicPPL.TypedVarInfo + ad = ADTypes.AutoMooncake(; config=nothing) + ∇ℓ = LogDensityProblemsAD.ADgradient(ad, logdensity) + @test isa( + LogDensityProblems.logdensity_and_gradient(∇ℓ, θ), + Tuple{Float64, Vector{Float64}}, + ) + end end end end diff --git a/test/runtests.jl b/test/runtests.jl index 1bf30cb71..dbfa319b0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,7 @@ using DynamicPPL using AbstractMCMC using AbstractPPL using Bijectors +using DifferentiationInterface using Distributions using DistributionsAD using Documenter From 46bbf06226086f2ce37ecdf62486686a63c34deb Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 28 Nov 2024 17:43:38 +0000 Subject: [PATCH 10/17] Locate tests better --- test/ad.jl | 22 ++++++++++++++++------ test/logdensityfunction.jl | 12 ------------ 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index 6046cfda4..d8064a6f1 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -1,4 +1,4 @@ -@testset "AD: ForwardDiff and ReverseDiff" begin +@testset "AD: ForwardDiff, ReverseDiff, and Mooncake" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS f = DynamicPPL.LogDensityFunction(m) rand_param_values = DynamicPPL.TestUtils.rand_prior_true(m) @@ -17,11 +17,21 @@ θ = convert(Vector{Float64}, varinfo[:]) logp, ref_grad = LogDensityProblems.logdensity_and_gradient(ad_forwarddiff_f, θ) - @testset "ReverseDiff with compile=$compile" for compile in (false, true) - adtype = ADTypes.AutoReverseDiff(; compile=compile) - ad_f = LogDensityProblemsAD.ADgradient(adtype, f) - _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ) - @test grad ≈ ref_grad + @testset "$adtype" for adtype in [ + ADTypes.AutoReverseDiff(; compile=false), + ADTypes.AutoReverseDiff(; compile=true), + ADTypes.AutoMooncake(; config=nothing), + ] + # Mooncake can't currently handle something that is going on in + # SimpleVarInfo{<:VarNamedVector}. Disable tests for now. + if adtype isa ADTypes.AutoMooncake && + varinfo isa DynamicPPL.SimpleVarInfo{<:DynamicPPL.VarNamedVector} + @test_broken 1 == 0 + else + ad_f = LogDensityProblemsAD.ADgradient(adtype, f) + _, grad = LogDensityProblems.logdensity_and_gradient(ad_f, θ) + @test grad ≈ ref_grad + end end end end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 5a2798ab8..beda767e6 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -31,18 +31,6 @@ end θ = varinfo[:] @test LogDensityProblems.logdensity(logdensity, θ) ≈ logjoint(model, varinfo) @test LogDensityProblems.dimension(logdensity) == length(θ) - - # Test a single backend on the generic - # ADgradient(::AbstractADType, ::LogDensityFunction) method. This really just - # checks that it runs at all. - if varinfo isa DynamicPPL.TypedVarInfo - ad = ADTypes.AutoMooncake(; config=nothing) - ∇ℓ = LogDensityProblemsAD.ADgradient(ad, logdensity) - @test isa( - LogDensityProblems.logdensity_and_gradient(∇ℓ, θ), - Tuple{Float64, Vector{Float64}}, - ) - end end end end From 4a234d66d461e417d1e7a9c60edfa76485e8e075 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 29 Nov 2024 11:11:00 +0000 Subject: [PATCH 11/17] Internal _make_ad_gradient --- Project.toml | 5 ++++- ext/DynamicPPLMooncakeExt.jl | 8 ++++++++ ext/DynamicPPLReverseDiffExt.jl | 24 +++--------------------- src/logdensityfunction.jl | 2 +- 4 files changed, 16 insertions(+), 23 deletions(-) create mode 100644 ext/DynamicPPLMooncakeExt.jl diff --git a/Project.toml b/Project.toml index 754f56216..1bdc6820d 100644 --- a/Project.toml +++ b/Project.toml @@ -30,6 +30,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" @@ -38,6 +39,7 @@ DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLMCMCChainsExt = ["MCMCChains"] +DynamicPPLMooncakeExt = ["Mooncake"] DynamicPPLReverseDiffExt = ["ReverseDiff"] DynamicPPLZygoteRulesExt = ["ZygoteRules"] @@ -60,10 +62,11 @@ LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" MCMCChains = "6" MacroTools = "0.5.6" +Mooncake = "0.4.52" OrderedCollections = "1" Random = "1.6" -Requires = "1" ReverseDiff = "1" +Requires = "1" Test = "1.6" ZygoteRules = "0.2" julia = "1.10" diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl new file mode 100644 index 000000000..400da3b20 --- /dev/null +++ b/ext/DynamicPPLMooncakeExt.jl @@ -0,0 +1,8 @@ +module DynamicPPLMooncakeExt + +import LogDensityProblemsAD: ADgradient +using DynamicPPL: ADTypes, _make_ad_gradient, LogDensityFunction + +ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction) = _make_ad_gradient(ad, f) + +end # module diff --git a/ext/DynamicPPLReverseDiffExt.jl b/ext/DynamicPPLReverseDiffExt.jl index 3fd174ed1..3728068ce 100644 --- a/ext/DynamicPPLReverseDiffExt.jl +++ b/ext/DynamicPPLReverseDiffExt.jl @@ -1,26 +1,8 @@ module DynamicPPLReverseDiffExt -if isdefined(Base, :get_extension) - using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD - using ReverseDiff -else - using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD - using ..ReverseDiff -end +import LogDensityProblemsAD: ADgradient +using DynamicPPL: ADTypes, _make_ad_gradient, LogDensityFunction -function LogDensityProblemsAD.ADgradient( - ad::ADTypes.AutoReverseDiff{Tcompile}, ℓ::DynamicPPL.LogDensityFunction -) where {Tcompile} - return LogDensityProblemsAD.ADgradient( - Val(:ReverseDiff), - ℓ; - compile=Val(Tcompile), - # `getparams` can return `Vector{Real}`, in which case, `ReverseDiff` will initialize the gradients to Integer 0 - # because at https://github.com/JuliaDiff/ReverseDiff.jl/blob/c982cde5494fc166965a9d04691f390d9e3073fd/src/tracked.jl#L473 - # `zero(D)` will return 0 when D is Real. - # here we use `identity` to possibly concretize the type to `Vector{Float64}` in the case of `Vector{Real}`. - x=map(identity, DynamicPPL.getparams(ℓ)), - ) -end +ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction) = _make_ad_gradient(ad, f) end # module diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 19d87fd5b..d47c6ccd4 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -149,7 +149,7 @@ LogDensityProblems.dimension(f::LogDensityFunction) = length(getparams(f)) # parameters, or DifferentiationInterface will not have sufficient information to e.g. # compile a rule for Mooncake (because it won't know the type of the input), or pre-allocate # a tape when using ReverseDiff.jl. -function LogDensityProblemsAD.ADgradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction) +function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction) x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params return LogDensityProblemsAD.ADgradient(ad, ℓ; x) end From 6fb7f9b26b56cb5db6b362ee99927f85e1d1c3db Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 29 Nov 2024 11:11:14 +0000 Subject: [PATCH 12/17] Mark failing tests as broken --- test/ad.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index d8064a6f1..f63a65945 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -23,9 +23,9 @@ ADTypes.AutoMooncake(; config=nothing), ] # Mooncake can't currently handle something that is going on in - # SimpleVarInfo{<:VarNamedVector}. Disable tests for now. + # SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now. if adtype isa ADTypes.AutoMooncake && - varinfo isa DynamicPPL.SimpleVarInfo{<:DynamicPPL.VarNamedVector} + varinfo isa DynamicPPL.SimpleVarInfo @test_broken 1 == 0 else ad_f = LogDensityProblemsAD.ADgradient(adtype, f) From 99532e05e94156ac3ad446f3d55397d021a1ad11 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 29 Nov 2024 16:53:33 +0000 Subject: [PATCH 13/17] Formatting --- test/ad.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/ad.jl b/test/ad.jl index f63a65945..768a55ad3 100644 --- a/test/ad.jl +++ b/test/ad.jl @@ -24,8 +24,7 @@ ] # Mooncake can't currently handle something that is going on in # SimpleVarInfo{<:VarNamedVector}. Disable all SimpleVarInfo tests for now. - if adtype isa ADTypes.AutoMooncake && - varinfo isa DynamicPPL.SimpleVarInfo + if adtype isa ADTypes.AutoMooncake && varinfo isa DynamicPPL.SimpleVarInfo @test_broken 1 == 0 else ad_f = LogDensityProblemsAD.ADgradient(adtype, f) From 21c2a0aa3bdc15b151186e1ae1d355676b0dbf67 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 2 Dec 2024 12:36:19 +0000 Subject: [PATCH 14/17] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 451a794ca..a4ec7fcbd 100644 --- a/Project.toml +++ b/Project.toml @@ -62,7 +62,7 @@ LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" MCMCChains = "6" MacroTools = "0.5.6" -Mooncake = "0.4.52" +Mooncake = "0.4.54" OrderedCollections = "1" Random = "1.6" ReverseDiff = "1" From 73fbf34d0073e75deece91af8bf24688b86ca7fc Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Wed, 4 Dec 2024 18:42:20 +0000 Subject: [PATCH 15/17] Updates --- Project.toml | 6 ------ ext/DynamicPPLMooncakeExt.jl | 8 -------- ext/DynamicPPLReverseDiffExt.jl | 8 -------- src/logdensityfunction.jl | 7 +++++++ 4 files changed, 7 insertions(+), 22 deletions(-) delete mode 100644 ext/DynamicPPLMooncakeExt.jl delete mode 100644 ext/DynamicPPLReverseDiffExt.jl diff --git a/Project.toml b/Project.toml index a4ec7fcbd..fd8c62a92 100644 --- a/Project.toml +++ b/Project.toml @@ -30,8 +30,6 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" -Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" -ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [extensions] @@ -39,8 +37,6 @@ DynamicPPLChainRulesCoreExt = ["ChainRulesCore"] DynamicPPLEnzymeCoreExt = ["EnzymeCore"] DynamicPPLForwardDiffExt = ["ForwardDiff"] DynamicPPLMCMCChainsExt = ["MCMCChains"] -DynamicPPLMooncakeExt = ["Mooncake"] -DynamicPPLReverseDiffExt = ["ReverseDiff"] DynamicPPLZygoteRulesExt = ["ZygoteRules"] [compat] @@ -62,10 +58,8 @@ LogDensityProblems = "2" LogDensityProblemsAD = "1.7.0" MCMCChains = "6" MacroTools = "0.5.6" -Mooncake = "0.4.54" OrderedCollections = "1" Random = "1.6" -ReverseDiff = "1" Requires = "1" Test = "1.6" ZygoteRules = "0.2" diff --git a/ext/DynamicPPLMooncakeExt.jl b/ext/DynamicPPLMooncakeExt.jl deleted file mode 100644 index 400da3b20..000000000 --- a/ext/DynamicPPLMooncakeExt.jl +++ /dev/null @@ -1,8 +0,0 @@ -module DynamicPPLMooncakeExt - -import LogDensityProblemsAD: ADgradient -using DynamicPPL: ADTypes, _make_ad_gradient, LogDensityFunction - -ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction) = _make_ad_gradient(ad, f) - -end # module diff --git a/ext/DynamicPPLReverseDiffExt.jl b/ext/DynamicPPLReverseDiffExt.jl deleted file mode 100644 index 3728068ce..000000000 --- a/ext/DynamicPPLReverseDiffExt.jl +++ /dev/null @@ -1,8 +0,0 @@ -module DynamicPPLReverseDiffExt - -import LogDensityProblemsAD: ADgradient -using DynamicPPL: ADTypes, _make_ad_gradient, LogDensityFunction - -ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction) = _make_ad_gradient(ad, f) - -end # module diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index d47c6ccd4..214369ab0 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -153,3 +153,10 @@ function _make_ad_gradient(ad::ADTypes.AbstractADType, ℓ::LogDensityFunction) x = map(identity, getparams(ℓ)) # ensure we concretise the elements of the params return LogDensityProblemsAD.ADgradient(ad, ℓ; x) end + +function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoMooncake, f::LogDensityFunction) + return _make_ad_gradient(ad, f) +end +function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, f::LogDensityFunction) + return _make_ad_gradient(ad, f) +end From a785c5cf60014867c7c2d5de92b64d4c3fd63520 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Thu, 5 Dec 2024 14:40:43 +0000 Subject: [PATCH 16/17] Bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index da69345a2..af9c6eee1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.31.1" +version = "0.31.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" From 5b7ab97877621a9a909b55c05d763ff0709e3031 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 6 Dec 2024 22:04:20 +0000 Subject: [PATCH 17/17] Bump patch again --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index af9c6eee1..909be870f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.31.2" +version = "0.31.3" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"