From 3b5afedc3ef91c8c536cd8002487b31739fd7add Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Jan 2023 18:19:34 +0000 Subject: [PATCH 01/10] use context from model in default constructor for LogDensityFunction --- src/logdensityfunction.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 2c836ca07..a012e1b9e 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -60,7 +60,7 @@ end function LogDensityFunction( model::Model, varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=DefaultContext(), + context::AbstractContext=model.context, ) return LogDensityFunction(varinfo, model, context) end From 25e3c41f49f3ab466026654395a9bae92b5a6e39 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Jan 2023 18:21:50 +0000 Subject: [PATCH 02/10] move varname_leaves to main module because of its usefulness --- src/test_utils.jl | 22 ++-------------------- src/utils.jl | 28 ++++++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 605952d88..b5f3a80b3 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -10,26 +10,8 @@ using Random: Random using Bijectors: Bijectors using Setfield: Setfield -""" - varname_leaves(vn::VarName, val) - -Return iterator over all varnames that are represented by `vn` on `val`, -e.g. `varname_leaves(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. -""" -varname_leaves(vn::VarName, val::Real) = [vn] -function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) - return ( - VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for - I in CartesianIndices(val) - ) -end -function varname_leaves(vn::VarName, val::AbstractArray) - return Iterators.flatten( - varname_leaves( - VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I] - ) for I in CartesianIndices(val) - ) -end +# For backwards compat. +using DynamicPPL: varname_leaves """ update_values!!(vi::AbstractVarInfo, vals::NamedTuple, vns) diff --git a/src/utils.jl b/src/utils.jl index 8f076efee..7da2ccf63 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -740,3 +740,31 @@ infer_nested_eltype(::Type{<:AbstractDict{<:Any,ET}}) where {ET} = infer_nested_ # No need + causes issues for some AD backends, e.g. Zygote. ChainRulesCore.@non_differentiable infer_nested_eltype(x) + +""" + varname_leaves(vn::VarName, val) + +Return iterator over all varnames that are represented by `vn` on `val`, +e.g. `varname_leaves(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. +""" +varname_leaves(vn::VarName, ::Real) = [vn] +function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) + return ( + VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for + I in CartesianIndices(val) + ) +end +function varname_leaves(vn::VarName, val::AbstractArray) + return Iterators.flatten( + varname_leaves( + VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I] + ) for I in CartesianIndices(val) + ) +end +function varname_leaves(vn::DynamicPPL.VarName, val::NamedTuple) + iter = Iterators.map(keys(val)) do sym + lens = DynamicPPL.Setfield.PropertyLens{sym}() + varname_leaves(vn ∘ lens, get(val, lens)) + end + return Iterators.flatten(iter) +end From 8ad15941f1e29ca0ddb5cffeb9348577ad5c93cd Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Jan 2023 18:26:31 +0000 Subject: [PATCH 03/10] bump minor version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 8ddb095d6..562180c08 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.21.5" +version = "0.21.6" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 8520975e2aed09340aa908fb26ab161d856932bb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Jan 2023 20:32:41 +0000 Subject: [PATCH 04/10] Apply suggestions from code review Co-authored-by: David Widmann --- src/utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 7da2ccf63..f201d527f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -744,26 +744,26 @@ ChainRulesCore.@non_differentiable infer_nested_eltype(x) """ varname_leaves(vn::VarName, val) -Return iterator over all varnames that are represented by `vn` on `val`, +Return an iterator over all varnames that are represented by `vn` on `val`, e.g. `varname_leaves(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. """ varname_leaves(vn::VarName, ::Real) = [vn] function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) return ( - VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for + VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for I in CartesianIndices(val) ) end function varname_leaves(vn::VarName, val::AbstractArray) return Iterators.flatten( varname_leaves( - VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I] + VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I] ) for I in CartesianIndices(val) ) end function varname_leaves(vn::DynamicPPL.VarName, val::NamedTuple) iter = Iterators.map(keys(val)) do sym - lens = DynamicPPL.Setfield.PropertyLens{sym}() + lens = Setfield.PropertyLens{sym}() varname_leaves(vn ∘ lens, get(val, lens)) end return Iterators.flatten(iter) From 9f61ffe40a5d495d85e99ca78ae2ccbd1a04bba0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Jan 2023 20:46:45 +0000 Subject: [PATCH 05/10] added doctests for varname_leaves --- src/utils.jl | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 7da2ccf63..6d925303a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -744,8 +744,23 @@ ChainRulesCore.@non_differentiable infer_nested_eltype(x) """ varname_leaves(vn::VarName, val) -Return iterator over all varnames that are represented by `vn` on `val`, -e.g. `varname_leaves(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. +Return iterator over all varnames that are represented by `vn` on `val`. + +# Examples +```jldoctest +julia> foreach(println, varname_leaves(@varname(x), rand(2))) +x[1] +x[2] + +julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2))) +x[1:2][1] +x[1:2][2] + +julia> foreach(println, varname_leaves(@varname(x), x)) +x.y +x.z[1][1] +x.z[2][1] +``` """ varname_leaves(vn::VarName, ::Real) = [vn] function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) From b5aafb5289768e492410d322c90c8c2cff63fe69 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Jan 2023 20:50:27 +0000 Subject: [PATCH 06/10] added test for LogDensityFunction regarding respecting contextualized models --- src/logdensityfunction.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index a012e1b9e..007dfef11 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -10,7 +10,7 @@ $(FIELDS) ```jldoctest julia> using Distributions -julia> using DynamicPPL: LogDensityFunction +julia> using DynamicPPL: LogDensityFunction, contextualize julia> @model function demo(x) m ~ Normal() @@ -36,6 +36,12 @@ julia> # By default it uses `VarInfo` under the hood, but this is not necessary. julia> LogDensityProblems.logdensity(f, [0.0]) -2.3378770664093453 + +julia> # This also respects the context in `model`. + f_prior = LogDensityFunction(contextualize(model, DynamicPPL.PriorContext()), VarInfo(model)); + +julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0) +true ``` """ struct LogDensityFunction{V,M,C} From b4e690eb92d2547490b6c968428d4ad0ed101f3f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 23 Jan 2023 20:51:00 +0000 Subject: [PATCH 07/10] Update src/utils.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/utils.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 0eb6e62b3..174d84913 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -771,9 +771,8 @@ function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) end function varname_leaves(vn::VarName, val::AbstractArray) return Iterators.flatten( - varname_leaves( - VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I] - ) for I in CartesianIndices(val) + varname_leaves(VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I]) for + I in CartesianIndices(val) ) end function varname_leaves(vn::DynamicPPL.VarName, val::NamedTuple) From fa239ea259c2b3d5b55eb0890aa66a0ba90af16d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 24 Jan 2023 21:39:56 +0000 Subject: [PATCH 08/10] fixed doctest --- src/utils.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index 174d84913..2f3262bc3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -748,6 +748,8 @@ Return an iterator over all varnames that are represented by `vn` on `val`. # Examples ```jldoctest +julia> using DynamicPPL: varname_leaves + julia> foreach(println, varname_leaves(@varname(x), rand(2))) x[1] x[2] From 77ca04803b33c0a596a4a9d58339e83346aae125 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 24 Jan 2023 22:57:43 +0000 Subject: [PATCH 09/10] okay now fixed doctests --- src/utils.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/utils.jl b/src/utils.jl index 2f3262bc3..7c2bd11a9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -758,6 +758,8 @@ julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2))) x[1:2][1] x[1:2][2] +julia> x = (y = 1, z = [[2.0], [3.0]]) + julia> foreach(println, varname_leaves(@varname(x), x)) x.y x.z[1][1] From 9d3bea6c35a394dfd20aef245eed06eab8c6e172 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 24 Jan 2023 23:02:37 +0000 Subject: [PATCH 10/10] maybe now --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index 7c2bd11a9..78595cb90 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -758,7 +758,7 @@ julia> foreach(println, varname_leaves(@varname(x[1:2]), rand(2))) x[1:2][1] x[1:2][2] -julia> x = (y = 1, z = [[2.0], [3.0]]) +julia> x = (y = 1, z = [[2.0], [3.0]]); julia> foreach(println, varname_leaves(@varname(x), x)) x.y