From ef0ac6fbe7d28bb287cf6dc75960867a7501e985 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 5 Jun 2023 20:22:46 +0100 Subject: [PATCH 1/9] added methods for extracting parameter values for a model from a given chain --- src/DynamicPPL.jl | 4 +- src/model_utils.jl | 208 +++++++++++++++++++++++++++++++++++++++++++ test/turing/model.jl | 15 ++++ 3 files changed, 226 insertions(+), 1 deletion(-) create mode 100644 src/model_utils.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 594084d66..5fcba5d74 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -119,7 +119,8 @@ export AbstractVarInfo, decondition, # Convenience macros @addlogprob!, - @submodel + @submodel, + value_iterator_from_chain # Reexport using Distributions: loglikelihood @@ -165,5 +166,6 @@ include("submodel_macro.jl") include("test_utils.jl") include("transforming.jl") include("logdensityfunction.jl") +include("model_utils.jl") end # module diff --git a/src/model_utils.jl b/src/model_utils.jl new file mode 100644 index 000000000..fe511918a --- /dev/null +++ b/src/model_utils.jl @@ -0,0 +1,208 @@ +""" + varnames_in_chain(model:::Model, chain) + varnames_in_chain(varinfo::VarInfo, chain) + +Return `true` if all variable names in `model`/`varinfo` are in `chain`. +""" +varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain) +function varnames_in_chain(varinfo::VarInfo, chain) + return all(vn -> varname_in_chain(varinfo, vn, chain), keys(varinfo)) +end + +""" + varnames_in_chain!(model::Model, chain, out) + varnames_in_chain!(varinfo::VarInfo, chain, out) + +Return `out` with `true` for all variable names in `model` that are in `chain`. 
+""" +function varnames_in_chain!(model::Model, chain, out) + return varnames_in_chain!(VarInfo(model), chain, out) +end +function varnames_in_chain!(varinfo::VarInfo, chain, out) + for vn in keys(varinfo) + varname_in_chain!(varinfo, vn, chain, 1, 1, out) + end + + return out +end + +""" + varname_in_chain(model::Model, vn, chain, chain_idx, iteration_idx) + varname_in_chain(varinfo::VarInfo, vn, chain, chain_idx, iteration_idx) + +Return `true` if `vn` is in `chain` at `chain_idx` and `iteration_idx`. +""" +function varname_in_chain(model::Model, vn, chain, chain_idx, iteration_idx) + return varname_in_chain(VarInfo(model), vn, chain, chain_idx, iteration_idx) +end + +function varname_in_chain(varinfo::AbstractVarInfo, vn, chain, chain_idx, iteration_idx) + !haskey(varinfo, vn) && return false + return varname_in_chain(varinfo[vn], vn, chain, chain_idx, iteration_idx) +end + +function varname_in_chain(x, vn, chain, chain_idx, iteration_idx) + out = OrderedDict{VarName,Bool}() + varname_in_chain!(x, vn, chain, out, chain_idx, iteration_idx) + return all(values(out)) +end + +""" + varname_in_chain!(model::Model, vn, chain, out, chain_idx, iteration_idx) + varname_in_chain!(varinfo::VarInfo, vn, chain, out, chain_idx, iteration_idx) + +Return a dictionary mapping the varname `vn` to `true` if `vn` is in `chain` at +`chain_idx` and `iteration_idx`. + +If `chain_idx` and `iteration_idx` are not provided, then they default to `1`. + +This differs from [`varname_in_chain`](@ref) in that it returns a dictionary +rather than a single boolean. This can be quite useful for debugging purposes. 
+""" +function varname_in_chain!(model::Model, vn, chain, out, chain_idx, iteration_idx) + return varname_in_chain!(VarInfo(model), vn, chain, chain_idx, iteration_idx, out) +end + +function varname_in_chain!( + vi::AbstractVarInfo, vn_parent, chain, out, chain_idx, iteration_idx +) + return varname_in_chain!(vi[vn_parent], vn_parent, chain, out, chain_idx, iteration_idx) +end + +function varname_in_chain!(x, vn_parent, chain, out, chain_idx, iteration_idx) + sym = Symbol(vn_parent) + out[vn_parent] = sym ∈ names(chain) && !ismissing(chain[iteration_idx, sym, chain_idx]) + return out +end + +function varname_in_chain!( + x::AbstractArray, vn_parent::VarName{sym}, chain, out, chain_idx, iteration_idx +) where {sym} + # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens. + # This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)` + # to extract the value from the `chain`. + for vn in varname_leaves(VarName{sym}(), x) + # Update `out`, possibly in place, and return. + l = AbstractPPL.getlens(vn) + varname_in_chain!(x, vn_parent ∘ l, chain, out, chain_idx, iteration_idx) + end + return out +end + +""" + values_from_chain(model::Model, chain, chain_idx, iteration_idx) + values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx) + +Return a dictionary mapping each variable name in `model`/`varinfo` to its +value in `chain` at `chain_idx` and `iteration_idx`. +""" +function values_from_chain(x, vn_parent, chain, chain_idx, iteration_idx) + # HACK: If it's not an array, we fall back to just returning the first value. + return only(chain[iteration_idx, Symbol(vn_parent), chain_idx]) +end +function values_from_chain( + x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx +) where {sym} + # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens. 
+ # This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)` + # to extract the value from the `chain`. + out = similar(x) + for vn in varname_leaves(VarName{sym}(), x) + # Update `out`, possibly in place, and return. + l = AbstractPPL.getlens(vn) + out = Setfield.set( + out, + BangBang.prefermutation(l), + chain[iteration_idx, Symbol(vn_parent ∘ l), chain_idx], + ) + end + + return out +end +function values_from_chain(vi::AbstractVarInfo, vn_parent, chain, chain_idx, iteration_idx) + # Use the value `vi[vn_parent]` to obtain a buffer. + return values_from_chain(vi[vn_parent], vn_parent, chain, chain_idx, iteration_idx) +end + +""" + values_from_chain!(model::Model, chain, out, chain_idx, iteration_idx) + values_from_chain!(varinfo::VarInfo, chain, out, chain_idx, iteration_idx) + +Mutate `out` to map each variable name in `model`/`varinfo` to its value in +`chain` at `chain_idx` and `iteration_idx`. +""" +function values_from_chain!(model::DynamicPPL.Model, chain, out, chain_idx, iteration_idx) + return values_from_chain(VarInfo(model), chain, out, chain_idx, iteration_idx) +end + +function values_from_chain!(vi::AbstractVarInfo, chain, out, chain_idx, iteration_idx) + for vn in keys(vi) + out[vn] = values_from_chain(vi, vn, chain, chain_idx, iteration_idx) + end + return out +end + +""" + value_iterator_from_chain(model::Model, chain) + value_iterator_from_chain(varinfo::AbstractVarInfo, chain) + +Return an iterator over the values in `chain` for each variable in `model`/`varinfo`. 
+ +# Example +```jldoctest +julia> using MCMCChains, DynamicPPL, Distributions, StableRNGs + +julia> rng = StableRNG(42); + +julia> @model function demo_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + for i in eachindex(x) + x[i] ~ Normal(m, sqrt(s)) + end + + return s, m + end; + +julia> model = demo_model([1.0, 2.0]); + +julia> chain = Chains(rand(rng, 10, 2, 3), [:s, :m]); + +julia> iter = value_iterator_from_chain(model, chain); + +julia> first(iter) +OrderedDict{VarName, Any} with 2 entries: + s => 0.580515 + m => 0.739328 + +julia> collect(iter) +10×3 Matrix{OrderedDict{VarName, Any}}: + OrderedDict(s=>0.580515, m=>0.739328) … OrderedDict(s=>0.186047, m=>0.402423) + OrderedDict(s=>0.191241, m=>0.627342) OrderedDict(s=>0.776277, m=>0.166342) + OrderedDict(s=>0.971133, m=>0.637584) OrderedDict(s=>0.651655, m=>0.712044) + OrderedDict(s=>0.74345, m=>0.110359) OrderedDict(s=>0.469214, m=>0.104502) + OrderedDict(s=>0.170969, m=>0.598514) OrderedDict(s=>0.853546, m=>0.185399) + OrderedDict(s=>0.704776, m=>0.322111) … OrderedDict(s=>0.638301, m=>0.853802) + OrderedDict(s=>0.441044, m=>0.162285) OrderedDict(s=>0.852959, m=>0.0956922) + OrderedDict(s=>0.803972, m=>0.643369) OrderedDict(s=>0.245049, m=>0.871985) + OrderedDict(s=>0.772384, m=>0.646323) OrderedDict(s=>0.906603, m=>0.385502) + OrderedDict(s=>0.70882, m=>0.253105) OrderedDict(s=>0.413222, m=>0.953288) + +julia> # This can be used to `condition` a `Model`. 
+ conditioned_model = model | first(iter); + +julia> conditioned_model() # <= results in same values as the `first(iter)` above +(0.5805148626851955, 0.7393275279160691) +``` +""" +function value_iterator_from_chain(model::DynamicPPL.Model, chain) + return value_iterator_from_chain(VarInfo(model), chain) +end + +function value_iterator_from_chain(vi::AbstractVarInfo, chain) + return Iterators.map( + Iterators.product(1:size(chain, 1), 1:size(chain, 3)) + ) do (iteration_idx, chain_idx) + values_from_chain!(vi, chain, OrderedDict{VarName,Any}(), chain_idx, iteration_idx) + end +end diff --git a/test/turing/model.jl b/test/turing/model.jl index e27b177eb..b7d19285d 100644 --- a/test/turing/model.jl +++ b/test/turing/model.jl @@ -92,4 +92,19 @@ # Ensure that they're not all the same (some can be, because rejected samples) @test any(res34[1:(end - 1)] .!= res34[2:end]) end + + @testset "value_iterator_from_chain" begin + @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS + chain = sample(model, Prior(), 10; progress=false) + for (i, d) in enumerate(value_iterator_from_chain(model, chain)) + for vn in keys(d) + val = DynamicPPL.getvalue(d, vn) + for vn_leaf in DynamicPPL.varname_leaves(vn, val) + val_leaf = DynamicPPL.getvalue(d, vn_leaf) + @test val_leaf == chain[i, Symbol(vn_leaf), 1] + end + end + end + end + end end From 979db2d8b14a8260a3f62fd30a5d745401b05564 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 6 Jun 2023 15:20:56 +0100 Subject: [PATCH 2/9] added MCMCchains as a dep to docs --- docs/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/Project.toml b/docs/Project.toml index b6cce8b37..028ab7ac3 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +MCMCChains = 
"c7f686f2-ff18-58e9-bc7b-31028e88f75d" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -14,6 +15,7 @@ Distributions = "0.25" Documenter = "0.27" FillArrays = "0.13, 1" LogDensityProblems = "2" +MCMCChains = "6" MLUtils = "0.3, 0.4" Setfield = "0.7.1, 0.8, 1" StableRNGs = "1" From 7511e4f38ae1f5747d8fa990c5c2b22ed11815b6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 7 Jun 2023 01:41:50 +0100 Subject: [PATCH 3/9] attempt at fixing doctests --- src/model_utils.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/model_utils.jl b/src/model_utils.jl index fe511918a..40eebb63b 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -162,7 +162,8 @@ julia> @model function demo_model(x) end return s, m - end; + end +demo_model (generic function with 2 methods) julia> model = demo_model([1.0, 2.0]); From 1a2dfc7dac2f3c494ea03e76ee598e030cf6ed11 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 7 Jun 2023 01:51:11 +0100 Subject: [PATCH 4/9] remove the doctest as it's not working for some reason --- docs/Project.toml | 2 -- src/model_utils.jl | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 028ab7ac3..b6cce8b37 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,7 +4,6 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -15,7 +14,6 @@ Distributions = "0.25" Documenter = "0.27" FillArrays = "0.13, 1" LogDensityProblems = "2" -MCMCChains = "6" MLUtils = "0.3, 0.4" Setfield = "0.7.1, 
0.8, 1" StableRNGs = "1" diff --git a/src/model_utils.jl b/src/model_utils.jl index 40eebb63b..6600cfff0 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -149,7 +149,7 @@ end Return an iterator over the values in `chain` for each variable in `model`/`varinfo`. # Example -```jldoctest +```julia julia> using MCMCChains, DynamicPPL, Distributions, StableRNGs julia> rng = StableRNG(42); From 642be6cb9b9eb9773a307615f077f3011508669d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 7 Jun 2023 02:10:00 +0100 Subject: [PATCH 5/9] added docs --- docs/src/api.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index 2dfda9119..e18455bcc 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -102,10 +102,17 @@ For a chain of samples, one can compute the pointwise log-likelihoods of each ob pointwise_loglikelihoods ``` +For converting a chain into a format that can more easily be fed into a `Model` again, for example using `condition`, you can use [`value_iterator_from_chain`](@ref). + +```@docs +value_iterator_from_chain +``` + ```@docs NamedDist ``` + ## Testing Utilities DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule. From 49b6b931086052c0b92f028aae0c6d6662ed982b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 7 Jun 2023 02:12:35 +0100 Subject: [PATCH 6/9] Update docs/src/api.md Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/api.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index e18455bcc..714d18e17 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -112,7 +112,6 @@ value_iterator_from_chain NamedDist ``` - ## Testing Utilities DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule. 
From 2631287030227234e82374b884aa408e741b5988 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 17 Jul 2023 12:53:04 +0100 Subject: [PATCH 7/9] Update docs/src/api.md Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/api.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/src/api.md b/docs/src/api.md index 660f904b9..d2be26c56 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -136,6 +136,7 @@ For converting a chain into a format that can more easily be fed into a `Model` value_iterator_from_chain ``` + Sometimes it can be useful to extract the priors of a model. This is the possible using [`extract_priors`](@ref). ```@docs From ec3e65de3952aa3b7c66a0ebb1c561d63de73470 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 17 Jul 2023 16:30:39 +0100 Subject: [PATCH 8/9] fixed incorrect function call as pointed out by @YongchaoHuang --- src/model_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model_utils.jl b/src/model_utils.jl index 6600cfff0..3f3be5209 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -6,7 +6,7 @@ Return `true` if all variable names in `model`/`varinfo` are in `chain`. 
""" varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain) function varnames_in_chain(varinfo::VarInfo, chain) - return all(vn -> varname_in_chain(varinfo, vn, chain), keys(varinfo)) + return all(vn -> varname_in_chain(varinfo, vn, chain, 1, 1), keys(varinfo)) end """ From b62955333257c8e03c932b031b49cb43f0281730 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 17 Jul 2023 16:34:20 +0100 Subject: [PATCH 9/9] moved out argument to the end up of the function signature --- src/model_utils.jl | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/model_utils.jl b/src/model_utils.jl index 3f3be5209..3e686b917 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -43,13 +43,13 @@ end function varname_in_chain(x, vn, chain, chain_idx, iteration_idx) out = OrderedDict{VarName,Bool}() - varname_in_chain!(x, vn, chain, out, chain_idx, iteration_idx) + varname_in_chain!(x, vn, chain, chain_idx, iteration_idx, out) return all(values(out)) end """ - varname_in_chain!(model::Model, vn, chain, out, chain_idx, iteration_idx) - varname_in_chain!(varinfo::VarInfo, vn, chain, out, chain_idx, iteration_idx) + varname_in_chain!(model::Model, vn, chain, chain_idx, iteration_idx, out) + varname_in_chain!(varinfo::VarInfo, vn, chain, chain_idx, iteration_idx, out) Return a dictionary mapping the varname `vn` to `true` if `vn` is in `chain` at `chain_idx` and `iteration_idx`. @@ -59,24 +59,24 @@ If `chain_idx` and `iteration_idx` are not provided, then they default to `1`. This differs from [`varname_in_chain`](@ref) in that it returns a dictionary rather than a single boolean. This can be quite useful for debugging purposes. 
""" -function varname_in_chain!(model::Model, vn, chain, out, chain_idx, iteration_idx) +function varname_in_chain!(model::Model, vn, chain, chain_idx, iteration_idx, out) return varname_in_chain!(VarInfo(model), vn, chain, chain_idx, iteration_idx, out) end function varname_in_chain!( - vi::AbstractVarInfo, vn_parent, chain, out, chain_idx, iteration_idx + vi::AbstractVarInfo, vn_parent, chain, chain_idx, iteration_idx, out ) - return varname_in_chain!(vi[vn_parent], vn_parent, chain, out, chain_idx, iteration_idx) + return varname_in_chain!(vi[vn_parent], vn_parent, chain, chain_idx, iteration_idx, out) end -function varname_in_chain!(x, vn_parent, chain, out, chain_idx, iteration_idx) +function varname_in_chain!(x, vn_parent, chain, chain_idx, iteration_idx, out) sym = Symbol(vn_parent) out[vn_parent] = sym ∈ names(chain) && !ismissing(chain[iteration_idx, sym, chain_idx]) return out end function varname_in_chain!( - x::AbstractArray, vn_parent::VarName{sym}, chain, out, chain_idx, iteration_idx + x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx, out ) where {sym} # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens. # This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)` @@ -84,7 +84,7 @@ function varname_in_chain!( for vn in varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. 
l = AbstractPPL.getlens(vn) - varname_in_chain!(x, vn_parent ∘ l, chain, out, chain_idx, iteration_idx) + varname_in_chain!(x, vn_parent ∘ l, chain, chain_idx, iteration_idx, out) end return out end @@ -125,17 +125,17 @@ function values_from_chain(vi::AbstractVarInfo, vn_parent, chain, chain_idx, ite end """ - values_from_chain!(model::Model, chain, out, chain_idx, iteration_idx) - values_from_chain!(varinfo::VarInfo, chain, out, chain_idx, iteration_idx) + values_from_chain!(model::Model, chain, chain_idx, iteration_idx, out) + values_from_chain!(varinfo::VarInfo, chain, chain_idx, iteration_idx, out) Mutate `out` to map each variable name in `model`/`varinfo` to its value in `chain` at `chain_idx` and `iteration_idx`. """ -function values_from_chain!(model::DynamicPPL.Model, chain, out, chain_idx, iteration_idx) - return values_from_chain(VarInfo(model), chain, out, chain_idx, iteration_idx) +function values_from_chain!(model::DynamicPPL.Model, chain, chain_idx, iteration_idx, out) + return values_from_chain!(VarInfo(model), chain, chain_idx, iteration_idx, out) end -function values_from_chain!(vi::AbstractVarInfo, chain, out, chain_idx, iteration_idx) +function values_from_chain!(vi::AbstractVarInfo, chain, chain_idx, iteration_idx, out) for vn in keys(vi) out[vn] = values_from_chain(vi, vn, chain, chain_idx, iteration_idx) end @@ -204,6 +204,6 @@ function value_iterator_from_chain(vi::AbstractVarInfo, chain) return Iterators.map( Iterators.product(1:size(chain, 1), 1:size(chain, 3)) ) do (iteration_idx, chain_idx) - values_from_chain!(vi, chain, OrderedDict{VarName,Any}(), chain_idx, iteration_idx) + values_from_chain!(vi, chain, chain_idx, iteration_idx, OrderedDict{VarName,Any}()) end end