From e233f6a4c9486324adaa0919c55cba98ad910821 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 1 Jul 2024 12:45:33 +0100 Subject: [PATCH 01/12] initial copy and paste --- ext/DynamicPPLReverseDiffExt.jl | 13 +++++++++++-- src/logdensityfunction.jl | 25 +++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/ext/DynamicPPLReverseDiffExt.jl b/ext/DynamicPPLReverseDiffExt.jl index b2b378d45..ced36a6c3 100644 --- a/ext/DynamicPPLReverseDiffExt.jl +++ b/ext/DynamicPPLReverseDiffExt.jl @@ -1,10 +1,10 @@ module DynamicPPLReverseDiffExt if isdefined(Base, :get_extension) - using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD + using DynamicPPL: Accessors, ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD using ReverseDiff else - using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD + using ..DynamicPPL: Accessors, ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD using ..ReverseDiff end @@ -23,4 +23,13 @@ function LogDensityProblemsAD.ADgradient( ) end +function DynamicPPL.setmodel(f::LogDensityProblemsAD.ReverseDiffLogDensity{L,Nothing}, model::DynamicPPL.Model) where {L} + return Accessors.@set f.ℓ = setmodel(f.ℓ, model) +end + +function DynamicPPL.setmodel(f::LogDensityProblemsAD.ReverseDiffLogDensity{L,C}, model::DynamicPPL.Model) where {L,C} + new_f = LogDensityProblemsAD.ADGradient(Val(:ReverseDiff), f.ℓ; compile=Val(true)) # TODO: without a input, can get error + return Accessors.@set new_f.ℓ = setmodel(f.ℓ, model) +end + end # module diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 8935edc12..23ab6a3ef 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -76,6 +76,31 @@ function getcontext(f::LogDensityFunction) return f.context === nothing ? leafcontext(f.model.context) : f.context end +""" + getmodel(f) + +Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. +""" +getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = getmodel(parent(f)) +getmodel(f::DynamicPPL.LogDensityFunction) = f.model + +""" + setmodel(f, model) + +Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. + +!!! warning + Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a + `DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f` + might require recompilation of the gradient tape, depending on the AD backend. +""" +function setmodel(f::LogDensityProblemsAD.ADGradientWrapper, model::DynamicPPL.Model) + return Accessors.@set f.ℓ = setmodel(f.ℓ, model) +end +function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) + return Accessors.@set f.model = model +end + # HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time # we need to define these annoying methods to ensure that we stay compatible with everything. getsampler(f::LogDensityFunction) = getsampler(getcontext(f)) From 66367a7710921896cb5d1dd1eff5fccc1ce6dced Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Mon, 1 Jul 2024 14:13:28 +0100 Subject: [PATCH 02/12] add some test --- test/logdensityfunction.jl | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index ea70ace29..763bd9fd8 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -1,5 +1,23 @@ using Test, DynamicPPL, LogDensityProblems +@testset "`getmodel` and `setmodel`" begin + # TODO: does it worth to test all demo models? + model = DynamicPPL.TestUtils.DEMO_MODELS[1] + ℓ = DynamicPPL.LogDensityFunction(model) + @test DynamicPPL.getmodel(ℓ) == model + @test DynamicPPL.setmodel(ℓ, model).model == model + + # ReverseDiff related + ∇ℓ = LogDensityProblems.ADgradient(:ReverseDiff, ℓ; compile=Val(false)) + @test DynamicPPL.getmodel(∇ℓ) == model + @test getmodel(DynamicPPL.setmodel(∇ℓ, model)) == model + + ∇ℓ = LogDensityProblems.ADgradient(:ReverseDiff, ℓ; compile=Val(true)) + new_∇ℓ = DynamicPPL.setmodel(∇ℓ, model) + @test DynamicPPL.getmodel(new_∇ℓ) == model + @test new_∇ℓ.ℓ.compiledtape != ∇ℓ.ℓ.compiledtape +end + @testset "LogDensityFunction" begin @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS example_values = DynamicPPL.TestUtils.rand_prior_true(model) From 7345a88ad3e38ae5065ad79b046d25c521240251 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Mon, 1 Jul 2024 21:15:18 +0800 Subject: [PATCH 03/12] Update ext/DynamicPPLReverseDiffExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/DynamicPPLReverseDiffExt.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/DynamicPPLReverseDiffExt.jl b/ext/DynamicPPLReverseDiffExt.jl index ced36a6c3..d365d7845 100644 --- a/ext/DynamicPPLReverseDiffExt.jl +++ b/ext/DynamicPPLReverseDiffExt.jl @@ -1,7 +1,8 @@ module DynamicPPLReverseDiffExt if isdefined(Base, :get_extension) - using DynamicPPL: Accessors, ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD + using DynamicPPL: + Accessors, ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD using ReverseDiff else using ..DynamicPPL: Accessors, ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD From e978bc844078e0551c51587f9153e43af507088e Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Mon, 1 Jul 2024 21:15:22 +0800 Subject: [PATCH 04/12] Update ext/DynamicPPLReverseDiffExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/DynamicPPLReverseDiffExt.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/DynamicPPLReverseDiffExt.jl b/ext/DynamicPPLReverseDiffExt.jl index d365d7845..2586408c9 100644 --- a/ext/DynamicPPLReverseDiffExt.jl +++ b/ext/DynamicPPLReverseDiffExt.jl @@ -5,7 +5,8 @@ if isdefined(Base, :get_extension) Accessors, ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD using ReverseDiff else - using ..DynamicPPL: Accessors, ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD + using ..DynamicPPL: + Accessors, ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD using ..ReverseDiff end From 50889968f9988fe7541f157657ad1abe536e8154 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Mon, 1 Jul 2024 21:15:35 +0800 Subject: [PATCH 05/12] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- ext/DynamicPPLReverseDiffExt.jl | 8 ++++++-- test/logdensityfunction.jl | 1 - 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/ext/DynamicPPLReverseDiffExt.jl b/ext/DynamicPPLReverseDiffExt.jl index 2586408c9..a1970bd71 100644 --- a/ext/DynamicPPLReverseDiffExt.jl +++ b/ext/DynamicPPLReverseDiffExt.jl @@ -25,11 +25,15 @@ function LogDensityProblemsAD.ADgradient( ) end -function DynamicPPL.setmodel(f::LogDensityProblemsAD.ReverseDiffLogDensity{L,Nothing}, model::DynamicPPL.Model) where {L} +function DynamicPPL.setmodel( + f::LogDensityProblemsAD.ReverseDiffLogDensity{L,Nothing}, model::DynamicPPL.Model +) where {L} return Accessors.@set f.ℓ = setmodel(f.ℓ, model) end -function DynamicPPL.setmodel(f::LogDensityProblemsAD.ReverseDiffLogDensity{L,C}, model::DynamicPPL.Model) where {L,C} +function DynamicPPL.setmodel( + f::LogDensityProblemsAD.ReverseDiffLogDensity{L,C}, model::DynamicPPL.Model +) where {L,C} new_f = LogDensityProblemsAD.ADGradient(Val(:ReverseDiff), f.ℓ; compile=Val(true)) # TODO: without a input, can get error return Accessors.@set new_f.ℓ = setmodel(f.ℓ, model) end diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 763bd9fd8..890308cc9 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -11,7 +11,6 @@ using Test, DynamicPPL, LogDensityProblems ∇ℓ = LogDensityProblems.ADgradient(:ReverseDiff, ℓ; compile=Val(false)) @test DynamicPPL.getmodel(∇ℓ) == model @test getmodel(DynamicPPL.setmodel(∇ℓ, model)) == model - ∇ℓ = LogDensityProblems.ADgradient(:ReverseDiff, ℓ; compile=Val(true)) new_∇ℓ = DynamicPPL.setmodel(∇ℓ, model) @test DynamicPPL.getmodel(new_∇ℓ) == model From 338d0e34071aaf605acfb92f90d81c85611da37a Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Tue, 16 Jul 2024 17:37:41 +0100 Subject: [PATCH 06/12] update to the new implementation according to Turing --- ext/DynamicPPLReverseDiffExt.jl | 19 ++----------------- src/logdensityfunction.jl | 26 +++++++++++++++----------- 2 files changed, 17 insertions(+), 28 deletions(-) diff --git a/ext/DynamicPPLReverseDiffExt.jl b/ext/DynamicPPLReverseDiffExt.jl index a1970bd71..b2b378d45 100644 --- a/ext/DynamicPPLReverseDiffExt.jl +++ b/ext/DynamicPPLReverseDiffExt.jl @@ -1,12 +1,10 @@ module DynamicPPLReverseDiffExt if isdefined(Base, :get_extension) - using DynamicPPL: - Accessors, ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD + using DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD using ReverseDiff else - using ..DynamicPPL: - Accessors, ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD + using ..DynamicPPL: ADTypes, DynamicPPL, LogDensityProblems, LogDensityProblemsAD using ..ReverseDiff end @@ -25,17 +23,4 @@ function LogDensityProblemsAD.ADgradient( ) end -function DynamicPPL.setmodel( - f::LogDensityProblemsAD.ReverseDiffLogDensity{L,Nothing}, model::DynamicPPL.Model -) where {L} - return Accessors.@set f.ℓ = setmodel(f.ℓ, model) -end - -function DynamicPPL.setmodel( - f::LogDensityProblemsAD.ReverseDiffLogDensity{L,C}, model::DynamicPPL.Model -) where {L,C} - new_f = LogDensityProblemsAD.ADGradient(Val(:ReverseDiff), f.ℓ; compile=Val(true)) # TODO: without a input, can get error - return Accessors.@set new_f.ℓ = setmodel(f.ℓ, model) -end - end # module diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 23ab6a3ef..d016a3d63 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -77,15 +77,7 @@ function getcontext(f::LogDensityFunction) end """ - getmodel(f) - -Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. -""" -getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = getmodel(parent(f)) -getmodel(f::DynamicPPL.LogDensityFunction) = f.model - -""" - setmodel(f, model) + setmodel(f, model[, adtype]) Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. @@ -94,8 +86,20 @@ Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. `DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f` might require recompilation of the gradient tape, depending on the AD backend. """ -function setmodel(f::LogDensityProblemsAD.ADGradientWrapper, model::DynamicPPL.Model) - return Accessors.@set f.ℓ = setmodel(f.ℓ, model) +function setmodel( + f::LogDensityProblemsAD.ADGradientWrapper, + model::DynamicPPL.Model, + adtype::ADTypes.AbstractADType +) + # TODO: Should we handle `SciMLBase.NoAD`? + # For an `ADGradientWrapper` we do the following: + # 1. Update the `Model` in the underlying `LogDensityFunction`. + # 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype` + # to ensure that the recompilation of gradient tapes, etc. also occur. For example, + # ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just + # replacing the corresponding field with the new model won't be sufficient to obtain + # the correct gradients. + return LogDensityProblemsAD.ADgradient(adtype, setmodel(parent(f), model)) end function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) return Accessors.@set f.model = model From 63c5b765410291eda0ae7fa9aca7a7b0f7bba569 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Wed, 17 Jul 2024 00:40:02 +0800 Subject: [PATCH 07/12] Update src/logdensityfunction.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/logdensityfunction.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index d016a3d63..87a55ea3a 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -89,7 +89,7 @@ Set the `DynamicPPL.Model` in the given log-density function `f` to `model`. function setmodel( f::LogDensityProblemsAD.ADGradientWrapper, model::DynamicPPL.Model, - adtype::ADTypes.AbstractADType + adtype::ADTypes.AbstractADType, ) # TODO: Should we handle `SciMLBase.NoAD`? # For an `ADGradientWrapper` we do the following: From e2d06eec8a15ab740a7fcea5ab91487ba8d56e69 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Wed, 17 Jul 2024 07:57:28 +0100 Subject: [PATCH 08/12] error fixes --- src/logdensityfunction.jl | 10 +++++++++- test/logdensityfunction.jl | 31 +++++++++++++++++-------------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 87a55ea3a..b3468d902 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -76,6 +76,14 @@ function getcontext(f::LogDensityFunction) return f.context === nothing ? leafcontext(f.model.context) : f.context end +""" + getmodel(f) + +Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. +""" +getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = getmodel(LogDensityProblemsAD.parent(f)) +getmodel(f::DynamicPPL.LogDensityFunction) = f.model + """ setmodel(f, model[, adtype]) @@ -99,7 +107,7 @@ function setmodel( # ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just # replacing the corresponding field with the new model won't be sufficient to obtain # the correct gradients. - return LogDensityProblemsAD.ADgradient(adtype, setmodel(parent(f), model)) + return LogDensityProblemsAD.ADgradient(adtype, setmodel(LogDensityProblemsAD.parent(f), model)) end function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) return Accessors.@set f.model = model diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 890308cc9..beda767e6 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -1,20 +1,23 @@ -using Test, DynamicPPL, LogDensityProblems +using Test, DynamicPPL, ADTypes, LogDensityProblems, LogDensityProblemsAD, ReverseDiff @testset "`getmodel` and `setmodel`" begin - # TODO: does it worth to test all demo models? - model = DynamicPPL.TestUtils.DEMO_MODELS[1] - ℓ = DynamicPPL.LogDensityFunction(model) - @test DynamicPPL.getmodel(ℓ) == model - @test DynamicPPL.setmodel(ℓ, model).model == model + @testset "$(nameof(model))" for model in DynamicPPL.TestUtils.DEMO_MODELS + model = DynamicPPL.TestUtils.DEMO_MODELS[1] + ℓ = DynamicPPL.LogDensityFunction(model) + @test DynamicPPL.getmodel(ℓ) == model + @test DynamicPPL.setmodel(ℓ, model).model == model - # ReverseDiff related - ∇ℓ = LogDensityProblems.ADgradient(:ReverseDiff, ℓ; compile=Val(false)) - @test DynamicPPL.getmodel(∇ℓ) == model - @test getmodel(DynamicPPL.setmodel(∇ℓ, model)) == model - ∇ℓ = LogDensityProblems.ADgradient(:ReverseDiff, ℓ; compile=Val(true)) - new_∇ℓ = DynamicPPL.setmodel(∇ℓ, model) - @test DynamicPPL.getmodel(new_∇ℓ) == model - @test new_∇ℓ.ℓ.compiledtape != ∇ℓ.ℓ.compiledtape + # ReverseDiff related + ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(false)) + @test DynamicPPL.getmodel(∇ℓ) == model + @test DynamicPPL.getmodel(DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff())) == + model + ∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ; compile=Val(true)) + new_∇ℓ = DynamicPPL.setmodel(∇ℓ, model, AutoReverseDiff()) + @test DynamicPPL.getmodel(new_∇ℓ) == model + # HACK(sunxd): rely on internal implementation detail, i.e., naming of `compiledtape` + @test new_∇ℓ.compiledtape != ∇ℓ.compiledtape + end end @testset "LogDensityFunction" begin From d48a5ffe21be4c5821f757c0300ea39bda43a83b Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Wed, 17 Jul 2024 15:06:09 +0800 Subject: [PATCH 09/12] Update src/logdensityfunction.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/logdensityfunction.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index b3468d902..b3d68e61e 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -81,7 +81,8 @@ end Return the `DynamicPPL.Model` wrapped in the given log-density function `f`. """ -getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = getmodel(LogDensityProblemsAD.parent(f)) +getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = + getmodel(LogDensityProblemsAD.parent(f)) getmodel(f::DynamicPPL.LogDensityFunction) = f.model """ From 1b763a03266a5444c6bff9c5d84fa81a2ede9632 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Wed, 17 Jul 2024 15:06:13 +0800 Subject: [PATCH 10/12] Update src/logdensityfunction.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/logdensityfunction.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index b3d68e61e..9e86590fa 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -108,7 +108,9 @@ function setmodel( # ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just # replacing the corresponding field with the new model won't be sufficient to obtain # the correct gradients. - return LogDensityProblemsAD.ADgradient(adtype, setmodel(LogDensityProblemsAD.parent(f), model)) + return LogDensityProblemsAD.ADgradient( + adtype, setmodel(LogDensityProblemsAD.parent(f), model) + ) end function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model) return Accessors.@set f.model = model From f2d0234d0496458e3c8fe06ee2db059b96ce90be Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 18 Jul 2024 08:43:48 +0100 Subject: [PATCH 11/12] add `HypothesisTests` to turing test dep --- test/turing/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/test/turing/Project.toml b/test/turing/Project.toml index 501359253..ed2b08ce5 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -1,6 +1,7 @@ [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" +HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" From 496e358f474097e77b528b943022399b8dbd0407 Mon Sep 17 00:00:00 2001 From: Xianda Sun Date: Thu, 18 Jul 2024 14:32:22 +0100 Subject: [PATCH 12/12] version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 06cba77ab..7ee7d2f97 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.28.1" +version = "0.28.2" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"