From ef0ac6fbe7d28bb287cf6dc75960867a7501e985 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 5 Jun 2023 20:22:46 +0100 Subject: [PATCH 01/20] added methods for extracting parameter values for a model from a given chain --- src/DynamicPPL.jl | 4 +- src/model_utils.jl | 208 +++++++++++++++++++++++++++++++++++++++++++ test/turing/model.jl | 15 ++++ 3 files changed, 226 insertions(+), 1 deletion(-) create mode 100644 src/model_utils.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 594084d66..5fcba5d74 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -119,7 +119,8 @@ export AbstractVarInfo, decondition, # Convenience macros @addlogprob!, - @submodel + @submodel, + value_iterator_from_chain # Reexport using Distributions: loglikelihood @@ -165,5 +166,6 @@ include("submodel_macro.jl") include("test_utils.jl") include("transforming.jl") include("logdensityfunction.jl") +include("model_utils.jl") end # module diff --git a/src/model_utils.jl b/src/model_utils.jl new file mode 100644 index 000000000..fe511918a --- /dev/null +++ b/src/model_utils.jl @@ -0,0 +1,208 @@ +""" + varnames_in_chain(model:::Model, chain) + varnames_in_chain(varinfo::VarInfo, chain) + +Return `true` if all variable names in `model`/`varinfo` are in `chain`. +""" +varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain) +function varnames_in_chain(varinfo::VarInfo, chain) + return all(vn -> varname_in_chain(varinfo, vn, chain), keys(varinfo)) +end + +""" + varnames_in_chain!(model::Model, chain, out) + varnames_in_chain!(varinfo::VarInfo, chain, out) + +Return `out` with `true` for all variable names in `model` that are in `chain`. +""" +function varnames_in_chain!(model::Model, chain, out) + return varnames_in_chain!(VarInfo(model), chain, out) +end +function varnames_in_chain!(varinfo::VarInfo, chain, out) + for vn in keys(varinfo) + varname_in_chain!(varinfo, vn, chain, 1, 1, out) + end + + return out +end + +""" + varname_in_chain(model::Model, vn, chain, chain_idx, iteration_idx) + varname_in_chain(varinfo::VarInfo, vn, chain, chain_idx, iteration_idx) + +Return `true` if `vn` is in `chain` at `chain_idx` and `iteration_idx`. +""" +function varname_in_chain(model::Model, vn, chain, chain_idx, iteration_idx) + return varname_in_chain(VarInfo(model), vn, chain, chain_idx, iteration_idx) +end + +function varname_in_chain(varinfo::AbstractVarInfo, vn, chain, chain_idx, iteration_idx) + !haskey(varinfo, vn) && return false + return varname_in_chain(varinfo[vn], vn, chain, chain_idx, iteration_idx) +end + +function varname_in_chain(x, vn, chain, chain_idx, iteration_idx) + out = OrderedDict{VarName,Bool}() + varname_in_chain!(x, vn, chain, out, chain_idx, iteration_idx) + return all(values(out)) +end + +""" + varname_in_chain!(model::Model, vn, chain, out, chain_idx, iteration_idx) + varname_in_chain!(varinfo::VarInfo, vn, chain, out, chain_idx, iteration_idx) + +Return a dictionary mapping the varname `vn` to `true` if `vn` is in `chain` at +`chain_idx` and `iteration_idx`. + +If `chain_idx` and `iteration_idx` are not provided, then they default to `1`. + +This differs from [`varname_in_chain`](@ref) in that it returns a dictionary +rather than a single boolean. This can be quite useful for debugging purposes. +""" +function varname_in_chain!(model::Model, vn, chain, out, chain_idx, iteration_idx) + return varname_in_chain!(VarInfo(model), vn, chain, chain_idx, iteration_idx, out) +end + +function varname_in_chain!( + vi::AbstractVarInfo, vn_parent, chain, out, chain_idx, iteration_idx +) + return varname_in_chain!(vi[vn_parent], vn_parent, chain, out, chain_idx, iteration_idx) +end + +function varname_in_chain!(x, vn_parent, chain, out, chain_idx, iteration_idx) + sym = Symbol(vn_parent) + out[vn_parent] = sym ∈ names(chain) && !ismissing(chain[iteration_idx, sym, chain_idx]) + return out +end + +function varname_in_chain!( + x::AbstractArray, vn_parent::VarName{sym}, chain, out, chain_idx, iteration_idx +) where {sym} + # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens. + # This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)` + # to extract the value from the `chain`. + for vn in varname_leaves(VarName{sym}(), x) + # Update `out`, possibly in place, and return. + l = AbstractPPL.getlens(vn) + varname_in_chain!(x, vn_parent ∘ l, chain, out, chain_idx, iteration_idx) + end + return out +end + +""" + values_from_chain(model::Model, chain, chain_idx, iteration_idx) + values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx) + +Return a dictionary mapping each variable name in `model`/`varinfo` to its +value in `chain` at `chain_idx` and `iteration_idx`. +""" +function values_from_chain(x, vn_parent, chain, chain_idx, iteration_idx) + # HACK: If it's not an array, we fall back to just returning the first value. + return only(chain[iteration_idx, Symbol(vn_parent), chain_idx]) +end +function values_from_chain( + x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx +) where {sym} + # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens. + # This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)` + # to extract the value from the `chain`. + out = similar(x) + for vn in varname_leaves(VarName{sym}(), x) + # Update `out`, possibly in place, and return. + l = AbstractPPL.getlens(vn) + out = Setfield.set( + out, + BangBang.prefermutation(l), + chain[iteration_idx, Symbol(vn_parent ∘ l), chain_idx], + ) + end + + return out +end +function values_from_chain(vi::AbstractVarInfo, vn_parent, chain, chain_idx, iteration_idx) + # Use the value `vi[vn_parent]` to obtain a buffer. + return values_from_chain(vi[vn_parent], vn_parent, chain, chain_idx, iteration_idx) +end + +""" + values_from_chain!(model::Model, chain, out, chain_idx, iteration_idx) + values_from_chain!(varinfo::VarInfo, chain, out, chain_idx, iteration_idx) + +Mutate `out` to map each variable name in `model`/`varinfo` to its value in +`chain` at `chain_idx` and `iteration_idx`. +""" +function values_from_chain!(model::DynamicPPL.Model, chain, out, chain_idx, iteration_idx) + return values_from_chain(VarInfo(model), chain, out, chain_idx, iteration_idx) +end + +function values_from_chain!(vi::AbstractVarInfo, chain, out, chain_idx, iteration_idx) + for vn in keys(vi) + out[vn] = values_from_chain(vi, vn, chain, chain_idx, iteration_idx) + end + return out +end + +""" + value_iterator_from_chain(model::Model, chain) + value_iterator_from_chain(varinfo::AbstractVarInfo, chain) + +Return an iterator over the values in `chain` for each variable in `model`/`varinfo`. + +# Example +```jldoctest +julia> using MCMCChains, DynamicPPL, Distributions, StableRNGs + +julia> rng = StableRNG(42); + +julia> @model function demo_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + for i in eachindex(x) + x[i] ~ Normal(m, sqrt(s)) + end + + return s, m + end; + +julia> model = demo_model([1.0, 2.0]); + +julia> chain = Chains(rand(rng, 10, 2, 3), [:s, :m]); + +julia> iter = value_iterator_from_chain(model, chain); + +julia> first(iter) +OrderedDict{VarName, Any} with 2 entries: + s => 0.580515 + m => 0.739328 + +julia> collect(iter) +10×3 Matrix{OrderedDict{VarName, Any}}: + OrderedDict(s=>0.580515, m=>0.739328) … OrderedDict(s=>0.186047, m=>0.402423) + OrderedDict(s=>0.191241, m=>0.627342) OrderedDict(s=>0.776277, m=>0.166342) + OrderedDict(s=>0.971133, m=>0.637584) OrderedDict(s=>0.651655, m=>0.712044) + OrderedDict(s=>0.74345, m=>0.110359) OrderedDict(s=>0.469214, m=>0.104502) + OrderedDict(s=>0.170969, m=>0.598514) OrderedDict(s=>0.853546, m=>0.185399) + OrderedDict(s=>0.704776, m=>0.322111) … OrderedDict(s=>0.638301, m=>0.853802) + OrderedDict(s=>0.441044, m=>0.162285) OrderedDict(s=>0.852959, m=>0.0956922) + OrderedDict(s=>0.803972, m=>0.643369) OrderedDict(s=>0.245049, m=>0.871985) + OrderedDict(s=>0.772384, m=>0.646323) OrderedDict(s=>0.906603, m=>0.385502) + OrderedDict(s=>0.70882, m=>0.253105) OrderedDict(s=>0.413222, m=>0.953288) + +julia> # This can be used to `condition` a `Model`. + conditioned_model = model | first(iter); + +julia> conditioned_model() # <= results in same values as the `first(iter)` above +(0.5805148626851955, 0.7393275279160691) +``` +""" +function value_iterator_from_chain(model::DynamicPPL.Model, chain) + return value_iterator_from_chain(VarInfo(model), chain) +end + +function value_iterator_from_chain(vi::AbstractVarInfo, chain) + return Iterators.map( + Iterators.product(1:size(chain, 1), 1:size(chain, 3)) + ) do (iteration_idx, chain_idx) + values_from_chain!(vi, chain, OrderedDict{VarName,Any}(), chain_idx, iteration_idx) + end +end diff --git a/test/turing/model.jl b/test/turing/model.jl index e27b177eb..b7d19285d 100644 --- a/test/turing/model.jl +++ b/test/turing/model.jl @@ -92,4 +92,19 @@ # Ensure that they're not all the same (some can be, because rejected samples) @test any(res34[1:(end - 1)] .!= res34[2:end]) end + + @testset "value_iterator_from_chain" begin + @testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS + chain = sample(model, Prior(), 10; progress=false) + for (i, d) in enumerate(value_iterator_from_chain(model, chain)) + for vn in keys(d) + val = DynamicPPL.getvalue(d, vn) + for vn_leaf in DynamicPPL.varname_leaves(vn, val) + val_leaf = DynamicPPL.getvalue(d, vn_leaf) + @test val_leaf == chain[i, Symbol(vn_leaf), 1] + end + end + end + end + end end From 979db2d8b14a8260a3f62fd30a5d745401b05564 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 6 Jun 2023 15:20:56 +0100 Subject: [PATCH 02/20] added MCMCchains as a dep to docs --- docs/Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/Project.toml b/docs/Project.toml index b6cce8b37..028ab7ac3 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,6 +4,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -14,6 +15,7 @@ Distributions = "0.25" Documenter = "0.27" FillArrays = "0.13, 1" LogDensityProblems = "2" +MCMCChains = "6" MLUtils = "0.3, 0.4" Setfield = "0.7.1, 0.8, 1" StableRNGs = "1" From 7511e4f38ae1f5747d8fa990c5c2b22ed11815b6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 7 Jun 2023 01:41:50 +0100 Subject: [PATCH 03/20] attempt at fixing doctests --- src/model_utils.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/model_utils.jl b/src/model_utils.jl index fe511918a..40eebb63b 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -162,7 +162,8 @@ julia> @model function demo_model(x) end return s, m - end; + end +demo_model (generic function with 2 methods) julia> model = demo_model([1.0, 2.0]); From 1a2dfc7dac2f3c494ea03e76ee598e030cf6ed11 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 7 Jun 2023 01:51:11 +0100 Subject: [PATCH 04/20] remove the doctest as it's not working for some reason --- docs/Project.toml | 2 -- src/model_utils.jl | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 028ab7ac3..b6cce8b37 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,7 +4,6 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" @@ -15,7 +14,6 @@ Distributions = "0.25" Documenter = "0.27" FillArrays = "0.13, 1" LogDensityProblems = "2" -MCMCChains = "6" MLUtils = "0.3, 0.4" Setfield = "0.7.1, 0.8, 1" StableRNGs = "1" diff --git a/src/model_utils.jl b/src/model_utils.jl index 40eebb63b..6600cfff0 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -149,7 +149,7 @@ end Return an iterator over the values in `chain` for each variable in `model`/`varinfo`. # Example -```jldoctest +```julia julia> using MCMCChains, DynamicPPL, Distributions, StableRNGs julia> rng = StableRNG(42); From 642be6cb9b9eb9773a307615f077f3011508669d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 7 Jun 2023 02:10:00 +0100 Subject: [PATCH 05/20] added docs --- docs/src/api.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index 2dfda9119..e18455bcc 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -102,10 +102,17 @@ For a chain of samples, one can compute the pointwise log-likelihoods of each ob pointwise_loglikelihoods ``` +For converting a chain into a format that can more easily be fed into a `Model` again, for example using `condition`, you can use [`value_iterator_from_chain`](@ref). + +```@docs +value_iterator_from_chain +``` + ```@docs NamedDist ``` + ## Testing Utilities DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule. From 49b6b931086052c0b92f028aae0c6d6662ed982b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 7 Jun 2023 02:12:35 +0100 Subject: [PATCH 06/20] Update docs/src/api.md Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- docs/src/api.md | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index e18455bcc..714d18e17 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -112,7 +112,6 @@ value_iterator_from_chain NamedDist ``` - ## Testing Utilities DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule. From 299087e2ca417099942d61da06b6405d6d1607d4 Mon Sep 17 00:00:00 2001 From: YongchaoHuang Date: Tue, 4 Jul 2023 18:01:00 +0100 Subject: [PATCH 07/20] added new functions for `varnames_in_chain` and `values_from_chain`. --- src/model_utils.jl | 125 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/src/model_utils.jl b/src/model_utils.jl index 6600cfff0..c6f372bfc 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -1,3 +1,128 @@ +### Yong ############################################################## +# Yong added the below new functions on 2023-07-04, they are doing the some functionalities as Tor's functions. Some redundancy needs to be removed? +using Turing, Distributions, DynamicPPL, MCMCChains, Test + +#### 1. varname_in_chain #### +# here we just check if vn and its leaves are present in the chain; we are not checking its presence in model. So we don't need to pass model or varinfo to this function. +""" + varname_in_chain(vn::VarName, chain, chain_idx, iteration_idx) + +Return `true` if `vn` or any of `vn_child` is in `chain` at `chain_idx` and `iteration_idx`; also returned is the dictionary containing the names related to `vn` presented in the chain, if any. +""" +function varname_in_chain(vn::VarName, chain, chain_idx=1, iteration_idx=1) + out = OrderedDict{Symbol,Bool}() + for vn_child in namesingroup(chain, Symbol(vn)) # namesingroup: https://github.com/TuringLang/MCMCChains.jl/blob/master/src/chains.jl + # print("\n $vn_child of $vn is in chain") + out[vn_child] = Symbol(vn_child) ∈ names(chain) && !ismissing(chain[iteration_idx, Symbol(vn_child), chain_idx]) + end + return !isempty(out), out +end + +#### 2. varnames_in_chain #### +# we iteratively test whether each of keys(VarInfo(model)) is present in the chain or not +""" + varnames_in_chain(model:::Model, chain) + varnames_in_chain(varinfo::VarInfo, chain) + +Return `true` if all variable names in `model`/`varinfo` are in `chain`; also returned is the dictionary containing the names related to `vn` presented in the chain, if any. +""" +varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain) +function varnames_in_chain(varinfo::VarInfo, chain) + out_logical = OrderedDict() + out = OrderedDict() + for vn in keys(varinfo) + out_logical[Symbol(vn)], out[Symbol(vn)] = varname_in_chain(vn, chain, 1, 1) + end + return all(values(out_logical)), out +end + +#### 3. values_from_chain #### +""" + vn_values_from_chain(vn, chain, chain_idx, iteration_idx) + +Return `true` if `vn` or any of its leaves is in `chain`; also returned is the dictionary containing the names related to `vn` presented in the chain, if any. +""" +function vn_values_from_chain(vn::VarName, chain, chain_idx, iteration_idx) + out = OrderedDict() + # no need to test if varname_in_chain(vn, chain)[1] - if vn is not in chain, then out will be empty. + for vn_child in namesingroup(chain, Symbol(vn)) + try + out[vn_child] = chain[iteration_idx, Symbol(vn_child), chain_idx] + catch + println("Error: retrieve value for $vn_child using chain[$iteration_idx, Symbol($vn_child), $chain_idx] not successful!") + end + end + return !isempty(out), out +end + +""" + values_from_chain(model:::Model, chain) + values_from_chain(varinfo::VarInfo, chain) + +Return a dictionary containing the values of all variables in `model`/`varinfo` presented in `chain`, if any. +""" +values_from_chain(model::Model, chain, chain_idx, iteration_idx) = values_from_chain(VarInfo(model), chain, chain_idx, iteration_idx) +function values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx) + out = OrderedDict() + for vn in keys(varinfo) + _, out_vn = vn_values_from_chain(vn, chain, chain_idx, iteration_idx) + merge!(out, out_vn) + end + return out +end + +""" + values_from_chain(varinfo, chain, chain_idx, iteration_idx_range) + +Return a dictionary containing the values of all variables in `model`/`varinfo` presented in `chain`, as per iteration_idx_range. +""" +values_from_chain(model::Model, chain, chain_idx_range, iteration_idx_range) = values_from_chain(VarInfo(model), chain, chain_idx_range, iteration_idx_range) +function values_from_chain(varinfo::VarInfo, chain, chain_idx_range::UnitRange, iteration_idx_range::UnitRange) + all_out = OrderedDict() + for chain_idx in chain_idx_range + out = OrderedDict() + for vn in keys(varinfo) + for iteration_idx in iteration_idx_range + _, out_vn = vn_values_from_chain(vn, chain, chain_idx, iteration_idx) + for key in keys(out_vn) + if haskey(out, key) + out[key] = vcat(out[key], out_vn[key]) + else + out[key] = out_vn[key] + end + end + end + end + all_out["chain_idx_"*string(chain_idx)] = out + end + return all_out +end +function values_from_chain(varinfo::VarInfo, chain, chain_idx_range::Int, iteration_idx_range::UnitRange) + return values_from_chain(varinfo, chain, chain_idx_range:chain_idx_range, iteration_idx_range) +end +function values_from_chain(varinfo::VarInfo, chain, chain_idx_range::UnitRange, iteration_idx_range::Int) + return values_from_chain(varinfo, chain, chain_idx_range, iteration_idx_range:iteration_idx_range) +end +function values_from_chain(varinfo::VarInfo, chain, chain_idx_range::Int, iteration_idx_range::Int) # this is equivalent to values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx) + return values_from_chain(varinfo, chain, chain_idx_range:chain_idx_range, iteration_idx_range:iteration_idx_range) +end +## if either chain_idx_range and iteration_idx_range not specified, then all chains will be included. +function values_from_chain(varinfo::VarInfo, chain, chain_idx_range, iteration_idx_range) + if chain_idx_range === nothing + print("chain_idx_range is missing!") + chain_idx_range = 1:size(chain)[3] + end + if iteration_idx_range === nothing + print("iteration_idx_range is missing!") + iteration_idx_range = 1:size(chain)[1] + end + return values_from_chain(varinfo, chain, chain_idx_range, iteration_idx_range) +end + +### Tor ############################################################## + +""" + Tor(model::Model, chain, chain_idx, iteration_idx) """ varnames_in_chain(model:::Model, chain) varnames_in_chain(varinfo::VarInfo, chain) From 545f7274051d6ee182e70866dcc50081c0ed34cd Mon Sep 17 00:00:00 2001 From: YongchaoHuang Date: Wed, 5 Jul 2023 09:16:53 +0100 Subject: [PATCH 08/20] cleaned up . --- src/model_utils.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/model_utils.jl b/src/model_utils.jl index c6f372bfc..94e037836 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -120,9 +120,6 @@ function values_from_chain(varinfo::VarInfo, chain, chain_idx_range, iteration_i end ### Tor ############################################################## - -""" - Tor(model::Model, chain, chain_idx, iteration_idx) """ varnames_in_chain(model:::Model, chain) varnames_in_chain(varinfo::VarInfo, chain) From 20bdea78b653f182da6c7d98267696bb40d0df1d Mon Sep 17 00:00:00 2001 From: YongchaoHuang Date: Wed, 5 Jul 2023 10:30:59 +0100 Subject: [PATCH 09/20] added doctests for the new functions. --- src/model_utils.jl | 181 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 170 insertions(+), 11 deletions(-) diff --git a/src/model_utils.jl b/src/model_utils.jl index 94e037836..7c5652285 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -1,13 +1,44 @@ ### Yong ############################################################## # Yong added the below new functions on 2023-07-04, they are doing the some functionalities as Tor's functions. Some redundancy needs to be removed? -using Turing, Distributions, DynamicPPL, MCMCChains, Test +using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test #### 1. varname_in_chain #### # here we just check if vn and its leaves are present in the chain; we are not checking its presence in model. So we don't need to pass model or varinfo to this function. """ varname_in_chain(vn::VarName, chain, chain_idx, iteration_idx) -Return `true` if `vn` or any of `vn_child` is in `chain` at `chain_idx` and `iteration_idx`; also returned is the dictionary containing the names related to `vn` presented in the chain, if any. +Return two outputs: + - first output: logical `true` if `vn` or ANY of `vn_child` is in `chain` at `chain_idx` and `iteration_idx`; + - second output: a dictionary containing all leaf names of `vn` from the chain, if any. + +# Example +```julia +julia> using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test + +julia> Random.seed!(111) +MersenneTwister(111) + +julia> @model function test_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x ~ Normal(m, sqrt(s)) + end +test_model (generic function with 2 methods) + +julia> model = test_model(1.5) + +julia> chain = sample(model, NUTS(), 100) + +julia> varname_in_chain(VarName(:s), chain, 1, 1) +(true, OrderedDict{Symbol, Bool}(:s => 1)) + +julia> varname_in_chain(VarName(:m), chain, 1, 1) +(true, OrderedDict{Symbol, Bool}(:m => 1)) + +julia> varname_in_chain(VarName(:x), chain, 1, 1) +(false, OrderedDict{Symbol, Bool}()) + +``` """ function varname_in_chain(vn::VarName, chain, chain_idx=1, iteration_idx=1) out = OrderedDict{Symbol,Bool}() @@ -19,19 +50,44 @@ function varname_in_chain(vn::VarName, chain, chain_idx=1, iteration_idx=1) end #### 2. varnames_in_chain #### -# we iteratively test whether each of keys(VarInfo(model)) is present in the chain or not +# we iteratively check whether each of keys(VarInfo(model)) is present in the chain or not using `varname_in_chain`. """ varnames_in_chain(model:::Model, chain) varnames_in_chain(varinfo::VarInfo, chain) -Return `true` if all variable names in `model`/`varinfo` are in `chain`; also returned is the dictionary containing the names related to `vn` presented in the chain, if any. +Return two outputs: + - first output: logical `true` if ALL variable names in `model`/`varinfo` are in `chain`; + - second output: a dictionary containing all leaf names of `vn` presented in the chain, if any. + +# Example +```julia +julia> using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test + +julia> Random.seed!(111) +MersenneTwister(111) + +julia> @model function test_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x ~ Normal(m, sqrt(s)) +end +test_model (generic function with 2 methods + +julia> model = test_model(1.5) + +julia> chain = sample(model, NUTS(), 100) + +julia> varnames_in_chain(model, chain) +(true, OrderedDict{Any, Any}(:s => OrderedDict{Symbol, Bool}(:s => 1), :m => OrderedDict{Symbol, Bool}(:m => 1))) + +``` """ varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain) function varnames_in_chain(varinfo::VarInfo, chain) out_logical = OrderedDict() out = OrderedDict() for vn in keys(varinfo) - out_logical[Symbol(vn)], out[Symbol(vn)] = varname_in_chain(vn, chain, 1, 1) + out_logical[Symbol(vn)], out[Symbol(vn)] = varname_in_chain(vn, chain, 1, 1) # by default, we check the first chain and the first iteration. end return all(values(out_logical)), out end @@ -40,11 +96,42 @@ end """ vn_values_from_chain(vn, chain, chain_idx, iteration_idx) -Return `true` if `vn` or any of its leaves is in `chain`; also returned is the dictionary containing the names related to `vn` presented in the chain, if any. +Return two outputs: + - first output: logical `true` if `vn` or ANY of its leaves is in `chain` + - second output: a dictionary containing all leaf names (if any) of `vn` and their values at `chain_idx`, `iteration_idx` from the chain. + +# Example +```julia +julia> using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test + +julia> Random.seed!(111) +MersenneTwister(111) + +julia> @model function test_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x ~ Normal(m, sqrt(s)) +end +test_model (generic function with 2 methods) + +julia> model = test_model(1.5) + +julia> chain = sample(model, NUTS(), 100) + +julia> vn_values_from_chain(VarName(:s), chain, 1, 1) +(true, OrderedDict{Any, Any}(:s => 1.385664578516751)) + +julia> vn_values_from_chain(VarName(:m), chain, 1, 1) +(true, OrderedDict{Any, Any}(:m => 0.9529550916018266)) + +julia> vn_values_from_chain(VarName(:x), chain, 1, 1) +(false, OrderedDict{Any, Any}()) + +``` """ function vn_values_from_chain(vn::VarName, chain, chain_idx, iteration_idx) out = OrderedDict() - # no need to test if varname_in_chain(vn, chain)[1] - if vn is not in chain, then out will be empty. + # no need to check if varname_in_chain(vn, chain)[1] - if vn is not in chain, then out will be empty. for vn_child in namesingroup(chain, Symbol(vn)) try out[vn_child] = chain[iteration_idx, Symbol(vn_child), chain_idx] @@ -59,7 +146,33 @@ end values_from_chain(model:::Model, chain) values_from_chain(varinfo::VarInfo, chain) -Return a dictionary containing the values of all variables in `model`/`varinfo` presented in `chain`, if any. +Return one output: + - a dictionary containing the (leaves_name, value) pair of ALL parameters in `model`/`varinfo` at `chain_idx`, `iteration_idx` from the chain (if ANY of the leaves of `vn` is present in the chain). + +# Example +```julia + +julia> using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test + +julia> Random.seed!(111) +MersenneTwister(111) + +julia> @model function test_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x ~ Normal(m, sqrt(s)) +end + +julia> model = test_model(1.5) + +julia> chain = sample(model, NUTS(), 100) + +julia> values_from_chain(model, chain, 1, 1) +OrderedDict{Any, Any} with 2 entries: + :s => 1.38566 + :m => 0.952955 + +``` """ values_from_chain(model::Model, chain, chain_idx, iteration_idx) = values_from_chain(VarInfo(model), chain, chain_idx, iteration_idx) function values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx) @@ -72,9 +185,55 @@ function values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx) end """ - values_from_chain(varinfo, chain, chain_idx, iteration_idx_range) + values_from_chain(varinfo, chain, chain_idx_range, iteration_idx_range) + +Return one output: + - a dictionary containing the values of all leaf names of all parameters in `model`/`varinfo` within `chain_idx_range` and `iteration_idx_range``. -Return a dictionary containing the values of all variables in `model`/`varinfo` presented in `chain`, as per iteration_idx_range. +# Example +```julia + +julia> using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test + +julia> Random.seed!(111) +MersenneTwister(111) + +julia> @model function test_model(x) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x ~ Normal(m, sqrt(s)) +end + +julia> model = test_model(1.5) + +julia> chain = sample(model, NUTS(), 100) + +julia> values_from_chain(model, chain, 1:2, 1:10) +Error: retrieve value for s using chain[1, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[2, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[3, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[4, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[5, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[6, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[7, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[8, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[9, Symbol(s), 2] not successful! +Error: retrieve value for s using chain[10, Symbol(s), 2] not successful! +Error: retrieve value for m using chain[1, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[2, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[3, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[4, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[5, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[6, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[7, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[8, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[9, Symbol(m), 2] not successful! +Error: retrieve value for m using chain[10, Symbol(m), 2] not successful! +OrderedDict{Any, Any} with 2 entries: + "chain_idx_1" => OrderedDict{Any, Any}(:s=>[1.38566, 1.6544, 1.36912, 1.18434, 1.33485, 1.966… + "chain_idx_2" => OrderedDict{Any, Any}() + +``` """ values_from_chain(model::Model, chain, chain_idx_range, iteration_idx_range) = values_from_chain(VarInfo(model), chain, chain_idx_range, iteration_idx_range) function values_from_chain(varinfo::VarInfo, chain, chain_idx_range::UnitRange, iteration_idx_range::UnitRange) @@ -106,7 +265,7 @@ end function values_from_chain(varinfo::VarInfo, chain, chain_idx_range::Int, iteration_idx_range::Int) # this is equivalent to values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx) return values_from_chain(varinfo, chain, chain_idx_range:chain_idx_range, iteration_idx_range:iteration_idx_range) end -## if either chain_idx_range and iteration_idx_range not specified, then all chains will be included. +# if either chain_idx_range or iteration_idx_range is specified as `nothing`, then all chains will be included. function values_from_chain(varinfo::VarInfo, chain, chain_idx_range, iteration_idx_range) if chain_idx_range === nothing print("chain_idx_range is missing!") From de5f684821ddf99d6a5bc7aa364e2fad152537cb Mon Sep 17 00:00:00 2001 From: YongchaoHuang Date: Thu, 6 Jul 2023 18:59:43 +0100 Subject: [PATCH 10/20] added `test/model_utils.jl`. --- test/model_utils.jl | 116 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 116 insertions(+) create mode 100644 test/model_utils.jl diff --git a/test/model_utils.jl b/test/model_utils.jl new file mode 100644 index 000000000..c6fa99af3 --- /dev/null +++ b/test/model_utils.jl @@ -0,0 +1,116 @@ +using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test +Random.seed!(111) + +#### prepare the models and chains for testing #### +# (1) manually create a chain using MCMCChains - we know what parameter names are in the chain +val = [1 2; 3 4; 5 6; 7 8; 9 10; 11 12; 13 14; 15 16; 17 18; 19 20] +val = Matrix{Int}(val) +chn_1 = Chains(val, [:s, :m]) + +# (2) sample a Turing model to create a chain +@model function gdemo(x) + mu ~ MvNormal([0, 0, 0], [1 0 0; 0 1 0; 0 0 1]) + x ~ MvNormal(mu, [1 0 0; 0 1 0; 0 0 1]) +end +model_2 = gdemo([0, 0, 0]) # provide an initial value for `x` +chn_2 = sample(model_2, NUTS(), 100) # NB: the parameter names in an MCMCChains can be retrieved using `namechn_2.name_map[:parameters]_map`: https://github.com/TuringLang/MCMCChains.jl/blob/master/src/chains.jl + +# (3) sample the demo models in DynamicPPL to create a chain - we do not know what parameter names are in the chain beforehand +model_3 = DynamicPPL.TestUtils.DEMO_MODELS[12] +var_info = VarInfo(model_3) +vns = DynamicPPL.TestUtils.varnames(model_3) +# generate a chain. +N = 100 +vals_OrderedDict = mapreduce(hcat, 1:N) do _ + rand(OrderedDict, model_3) +end +vals_mat = mapreduce(hcat, 1:N) do i + [vals_OrderedDict[i][vn] for vn in vns] +end +i = 1 +for col in eachcol(vals_mat) + col_flattened = [] + [push!(col_flattened, x...) for x in col] + if i == 1 + chain_mat = Matrix(reshape(col_flattened, 1, length(col_flattened))) + else + chain_mat = vcat(chain_mat, reshape(col_flattened, 1, length(col_flattened))) + end + i += 1 +end +chain_mat = convert(Matrix{Float64}, chain_mat) +# devise parameter names for chain +sample_values_vec = collect(values(vals_OrderedDict[1])) +symbol_names = [] +chain_sym_map = Dict() +for k in 1:length(keys(var_info)) + vn_parent = keys(var_info)[k] + sym = DynamicPPL.getsym(vn_parent) + vn_children = DynamicPPL.varname_leaves(vn_parent, sample_values_vec[k]) + for vn_child in vn_children + chain_sym_map[Symbol(vn_child)] = sym + symbol_names = [symbol_names; Symbol(vn_child)] + end +end +chn_3 = Chains(chain_mat, symbol_names) + +#### test functions in model_utils.jl #### +@testset "model_utils.jl" begin + @testset "varname_in_chain" begin + # chn_1 + outputs = varname_in_chain(VarName(:s), chn_1, 1, 1) + @test outputs[1] == true && outputs[2][:s] == true + outputs = varname_in_chain(VarName(:m), chn_1, 1, 1) + @test outputs[1] == true && outputs[2][:m] == true + outputs = varname_in_chain(VarName(:x), chn_1, 1, 1) + @test outputs[1] == false && isempty(outputs[2]) + + # chn_2 + outputs = varname_in_chain(VarName(:mu), chn_2, 1, 1) + @test outputs[1] == true && !isempty(outputs[2]) && all(values(outputs[2]) .== 1) + + # chn_3 + outputs = varname_in_chain(VarName(:a), chn_3, 1, 1) + @test outputs[1] == false && isempty(outputs[2]) + outputs = varname_in_chain(VarName(symbol_names[1]), chn_3, 1, 1) + @test !isempty(outputs[2]) && all(values(outputs[2]) .== 1) # note: all(values(outputs[2]) .== 1) is not enough - even an empty dictionary has all(values(outputs[2]) .== 1) + end + @testset "varnames_in_chain" begin + outputs = varnames_in_chain(model_2, chn_2) + @test outputs[1] == true && all(values(outputs[2][:mu])) + outputs = varnames_in_chain(model_3, chn_3) + @test outputs[1] == true + end + @testset "vn_values_from_chain" begin + outputs = vn_values_from_chain(VarName(:mu), chn_2, 1, 1) + @test outputs[1] == true && length(values(outputs[2])) == 3 + outputs = vn_values_from_chain(VarName(:s), chn_3, 1, 1) + @test outputs[1] == true && length(values(outputs[2])) == 2 + outputs = vn_values_from_chain(VarName(Symbol("s[:,1][1]")), chn_3, 1, 2) + @test outputs[1] == true + end + @testset "values_from_chain" begin + output = values_from_chain(model_2, chn_2, 1, 1) + @test length(output["chain_idx_1"]) == 3 + output = values_from_chain(model_3, chn_3, 1, 1) + @test length(output["chain_idx_1"]) == 4 + + output = values_from_chain(model_2, chn_2, 1, 1:10) + @test length(output["chain_idx_1"]) == 3 && all([length(vals) == 10 for vals in values(output["chain_idx_1"])]) + output = values_from_chain(model_3, chn_3, 1, 1:10) + @test length(output["chain_idx_1"]) == 4 && all([length(vals) == 10 for vals in values(output["chain_idx_1"])]) + output = values_from_chain(model_2, chn_2, 1, 1) + @test length(output["chain_idx_1"]) == 3 && all([length(vals) == 1 for vals in values(output["chain_idx_1"])]) + output = values_from_chain(model_2, chn_2, 1:1, 1) + @test length(keys(output))==1 && length(output["chain_idx_1"]) == 3 && all([length(vals) == 1 for vals in values(output["chain_idx_1"])]) + output = values_from_chain(model_2, chn_2, 1, 1:1) + @test length(output["chain_idx_1"]) == 3 && all([length(vals) == 1 for vals in values(output["chain_idx_1"])]) + output = values_from_chain(model_2, chn_2, 1:1, 1:1) + @test length(keys(output))==1 && length(output["chain_idx_1"]) == 3 && all([length(vals) == 1 for vals in values(output["chain_idx_1"])]) + + output = values_from_chain(model_2, chn_2, nothing, nothing) + @test length(keys(output))==1 && length(output["chain_idx_1"]) == 3 && all([length(vals) == 100 for vals in values(output["chain_idx_1"])]) + output = values_from_chain(model_3, chn_3, nothing, nothing) + @test length(keys(output))==1 && length(output["chain_idx_1"]) == 4 && all([length(vals) == 100 for vals in values(output["chain_idx_1"])]) + end +end \ No newline at end of file From e6888c8fa012c466a25ac2bc159b3d081f739192 Mon Sep 17 00:00:00 2001 From: YongchaoHuang Date: Thu, 6 Jul 2023 21:35:16 +0100 Subject: [PATCH 11/20] only new functions are kept. --- src/model_utils.jl | 213 +-------------------------------------------- 1 file changed, 1 insertion(+), 212 deletions(-) diff --git a/src/model_utils.jl b/src/model_utils.jl index 7c5652285..70f16b6a6 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -276,215 +276,4 @@ function values_from_chain(varinfo::VarInfo, chain, chain_idx_range, iteration_i iteration_idx_range = 1:size(chain)[1] end return values_from_chain(varinfo, chain, chain_idx_range, iteration_idx_range) -end - -### Tor ############################################################## -""" - varnames_in_chain(model:::Model, chain) - varnames_in_chain(varinfo::VarInfo, chain) - -Return `true` if all variable names in `model`/`varinfo` are in `chain`. -""" -varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain) -function varnames_in_chain(varinfo::VarInfo, chain) - return all(vn -> varname_in_chain(varinfo, vn, chain), keys(varinfo)) -end - -""" - varnames_in_chain!(model::Model, chain, out) - varnames_in_chain!(varinfo::VarInfo, chain, out) - -Return `out` with `true` for all variable names in `model` that are in `chain`. -""" -function varnames_in_chain!(model::Model, chain, out) - return varnames_in_chain!(VarInfo(model), chain, out) -end -function varnames_in_chain!(varinfo::VarInfo, chain, out) - for vn in keys(varinfo) - varname_in_chain!(varinfo, vn, chain, 1, 1, out) - end - - return out -end - -""" - varname_in_chain(model::Model, vn, chain, chain_idx, iteration_idx) - varname_in_chain(varinfo::VarInfo, vn, chain, chain_idx, iteration_idx) - -Return `true` if `vn` is in `chain` at `chain_idx` and `iteration_idx`. -""" -function varname_in_chain(model::Model, vn, chain, chain_idx, iteration_idx) - return varname_in_chain(VarInfo(model), vn, chain, chain_idx, iteration_idx) -end - -function varname_in_chain(varinfo::AbstractVarInfo, vn, chain, chain_idx, iteration_idx) - !haskey(varinfo, vn) && return false - return varname_in_chain(varinfo[vn], vn, chain, chain_idx, iteration_idx) -end - -function varname_in_chain(x, vn, chain, chain_idx, iteration_idx) - out = OrderedDict{VarName,Bool}() - varname_in_chain!(x, vn, chain, out, chain_idx, iteration_idx) - return all(values(out)) -end - -""" - varname_in_chain!(model::Model, vn, chain, out, chain_idx, iteration_idx) - varname_in_chain!(varinfo::VarInfo, vn, chain, out, chain_idx, iteration_idx) - -Return a dictionary mapping the varname `vn` to `true` if `vn` is in `chain` at -`chain_idx` and `iteration_idx`. - -If `chain_idx` and `iteration_idx` are not provided, then they default to `1`. - -This differs from [`varname_in_chain`](@ref) in that it returns a dictionary -rather than a single boolean. This can be quite useful for debugging purposes. -""" -function varname_in_chain!(model::Model, vn, chain, out, chain_idx, iteration_idx) - return varname_in_chain!(VarInfo(model), vn, chain, chain_idx, iteration_idx, out) -end - -function varname_in_chain!( - vi::AbstractVarInfo, vn_parent, chain, out, chain_idx, iteration_idx -) - return varname_in_chain!(vi[vn_parent], vn_parent, chain, out, chain_idx, iteration_idx) -end - -function varname_in_chain!(x, vn_parent, chain, out, chain_idx, iteration_idx) - sym = Symbol(vn_parent) - out[vn_parent] = sym ∈ names(chain) && !ismissing(chain[iteration_idx, sym, chain_idx]) - return out -end - -function varname_in_chain!( - x::AbstractArray, vn_parent::VarName{sym}, chain, out, chain_idx, iteration_idx -) where {sym} - # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens. - # This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)` - # to extract the value from the `chain`. - for vn in varname_leaves(VarName{sym}(), x) - # Update `out`, possibly in place, and return. - l = AbstractPPL.getlens(vn) - varname_in_chain!(x, vn_parent ∘ l, chain, out, chain_idx, iteration_idx) - end - return out -end - -""" - values_from_chain(model::Model, chain, chain_idx, iteration_idx) - values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx) - -Return a dictionary mapping each variable name in `model`/`varinfo` to its -value in `chain` at `chain_idx` and `iteration_idx`. -""" -function values_from_chain(x, vn_parent, chain, chain_idx, iteration_idx) - # HACK: If it's not an array, we fall back to just returning the first value. - return only(chain[iteration_idx, Symbol(vn_parent), chain_idx]) -end -function values_from_chain( - x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx -) where {sym} - # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens. - # This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)` - # to extract the value from the `chain`. - out = similar(x) - for vn in varname_leaves(VarName{sym}(), x) - # Update `out`, possibly in place, and return. - l = AbstractPPL.getlens(vn) - out = Setfield.set( - out, - BangBang.prefermutation(l), - chain[iteration_idx, Symbol(vn_parent ∘ l), chain_idx], - ) - end - - return out -end -function values_from_chain(vi::AbstractVarInfo, vn_parent, chain, chain_idx, iteration_idx) - # Use the value `vi[vn_parent]` to obtain a buffer. - return values_from_chain(vi[vn_parent], vn_parent, chain, chain_idx, iteration_idx) -end - -""" - values_from_chain!(model::Model, chain, out, chain_idx, iteration_idx) - values_from_chain!(varinfo::VarInfo, chain, out, chain_idx, iteration_idx) - -Mutate `out` to map each variable name in `model`/`varinfo` to its value in -`chain` at `chain_idx` and `iteration_idx`. -""" -function values_from_chain!(model::DynamicPPL.Model, chain, out, chain_idx, iteration_idx) - return values_from_chain(VarInfo(model), chain, out, chain_idx, iteration_idx) -end - -function values_from_chain!(vi::AbstractVarInfo, chain, out, chain_idx, iteration_idx) - for vn in keys(vi) - out[vn] = values_from_chain(vi, vn, chain, chain_idx, iteration_idx) - end - return out -end - -""" - value_iterator_from_chain(model::Model, chain) - value_iterator_from_chain(varinfo::AbstractVarInfo, chain) - -Return an iterator over the values in `chain` for each variable in `model`/`varinfo`. - -# Example -```julia -julia> using MCMCChains, DynamicPPL, Distributions, StableRNGs - -julia> rng = StableRNG(42); - -julia> @model function demo_model(x) - s ~ InverseGamma(2, 3) - m ~ Normal(0, sqrt(s)) - for i in eachindex(x) - x[i] ~ Normal(m, sqrt(s)) - end - - return s, m - end -demo_model (generic function with 2 methods) - -julia> model = demo_model([1.0, 2.0]); - -julia> chain = Chains(rand(rng, 10, 2, 3), [:s, :m]); - -julia> iter = value_iterator_from_chain(model, chain); - -julia> first(iter) -OrderedDict{VarName, Any} with 2 entries: - s => 0.580515 - m => 0.739328 - -julia> collect(iter) -10×3 Matrix{OrderedDict{VarName, Any}}: - OrderedDict(s=>0.580515, m=>0.739328) … OrderedDict(s=>0.186047, m=>0.402423) - OrderedDict(s=>0.191241, m=>0.627342) OrderedDict(s=>0.776277, m=>0.166342) - OrderedDict(s=>0.971133, m=>0.637584) OrderedDict(s=>0.651655, m=>0.712044) - OrderedDict(s=>0.74345, m=>0.110359) OrderedDict(s=>0.469214, m=>0.104502) - OrderedDict(s=>0.170969, m=>0.598514) OrderedDict(s=>0.853546, m=>0.185399) - OrderedDict(s=>0.704776, m=>0.322111) … OrderedDict(s=>0.638301, m=>0.853802) - OrderedDict(s=>0.441044, m=>0.162285) OrderedDict(s=>0.852959, m=>0.0956922) - OrderedDict(s=>0.803972, m=>0.643369) OrderedDict(s=>0.245049, m=>0.871985) - OrderedDict(s=>0.772384, m=>0.646323) OrderedDict(s=>0.906603, m=>0.385502) - OrderedDict(s=>0.70882, m=>0.253105) OrderedDict(s=>0.413222, m=>0.953288) - -julia> # This can be used to `condition` a `Model`. - conditioned_model = model | first(iter); - -julia> conditioned_model() # <= results in same values as the `first(iter)` above -(0.5805148626851955, 0.7393275279160691) -``` -""" -function value_iterator_from_chain(model::DynamicPPL.Model, chain) - return value_iterator_from_chain(VarInfo(model), chain) -end - -function value_iterator_from_chain(vi::AbstractVarInfo, chain) - return Iterators.map( - Iterators.product(1:size(chain, 1), 1:size(chain, 3)) - ) do (iteration_idx, chain_idx) - values_from_chain!(vi, chain, OrderedDict{VarName,Any}(), chain_idx, iteration_idx) - end -end +end \ No newline at end of file From 77c37c9a039fdc5ed735cee629a216cf14f73d5a Mon Sep 17 00:00:00 2001 From: YongchaoHuang <34540771+YongchaoHuang@users.noreply.github.com> Date: Fri, 7 Jul 2023 18:37:05 +0100 Subject: [PATCH 12/20] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/model_utils.jl | 51 +++++++++++++++++++++++++++++++++------------ test/model_utils.jl | 28 ++++++++++++++++++------- 2 files changed, 58 insertions(+), 21 deletions(-) diff --git a/src/model_utils.jl b/src/model_utils.jl index 70f16b6a6..9d0f4870c 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -44,7 +44,9 @@ function varname_in_chain(vn::VarName, chain, chain_idx=1, iteration_idx=1) out = OrderedDict{Symbol,Bool}() for vn_child in namesingroup(chain, Symbol(vn)) # namesingroup: https://github.com/TuringLang/MCMCChains.jl/blob/master/src/chains.jl # print("\n $vn_child of $vn is in chain") - out[vn_child] = Symbol(vn_child) ∈ names(chain) && !ismissing(chain[iteration_idx, Symbol(vn_child), chain_idx]) + out[vn_child] = + Symbol(vn_child) ∈ names(chain) && + !ismissing(chain[iteration_idx, Symbol(vn_child), chain_idx]) end return !isempty(out), out end @@ -136,7 +138,9 @@ function vn_values_from_chain(vn::VarName, chain, chain_idx, iteration_idx) try out[vn_child] = chain[iteration_idx, Symbol(vn_child), chain_idx] catch - println("Error: retrieve value for $vn_child using chain[$iteration_idx, Symbol($vn_child), $chain_idx] not successful!") + println( + "Error: retrieve value for $vn_child using chain[$iteration_idx, Symbol($vn_child), $chain_idx] not successful!", + ) end end return !isempty(out), out @@ -174,7 +178,9 @@ OrderedDict{Any, Any} with 2 entries: ``` """ -values_from_chain(model::Model, chain, chain_idx, iteration_idx) = values_from_chain(VarInfo(model), chain, chain_idx, iteration_idx) +function values_from_chain(model::Model, chain, chain_idx, iteration_idx) + return values_from_chain(VarInfo(model), chain, chain_idx, iteration_idx) +end function values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx) out = OrderedDict() for vn in keys(varinfo) @@ -235,8 +241,12 @@ OrderedDict{Any, Any} with 2 entries: ``` """ -values_from_chain(model::Model, chain, chain_idx_range, iteration_idx_range) = values_from_chain(VarInfo(model), chain, chain_idx_range, iteration_idx_range) -function values_from_chain(varinfo::VarInfo, chain, chain_idx_range::UnitRange, iteration_idx_range::UnitRange) +function values_from_chain(model::Model, chain, chain_idx_range, iteration_idx_range) + return values_from_chain(VarInfo(model), chain, chain_idx_range, iteration_idx_range) +end +function values_from_chain( + varinfo::VarInfo, chain, chain_idx_range::UnitRange, iteration_idx_range::UnitRange +) all_out = OrderedDict() for chain_idx in chain_idx_range out = OrderedDict() @@ -252,18 +262,33 @@ function values_from_chain(varinfo::VarInfo, chain, chain_idx_range::UnitRange, end end end - all_out["chain_idx_"*string(chain_idx)] = out + all_out["chain_idx_" * string(chain_idx)] = out end return all_out end -function values_from_chain(varinfo::VarInfo, chain, chain_idx_range::Int, iteration_idx_range::UnitRange) - return values_from_chain(varinfo, chain, chain_idx_range:chain_idx_range, iteration_idx_range) +function values_from_chain( + varinfo::VarInfo, chain, chain_idx_range::Int, iteration_idx_range::UnitRange +) + return values_from_chain( + varinfo, chain, chain_idx_range:chain_idx_range, iteration_idx_range + ) end -function values_from_chain(varinfo::VarInfo, chain, chain_idx_range::UnitRange, iteration_idx_range::Int) - return values_from_chain(varinfo, chain, chain_idx_range, iteration_idx_range:iteration_idx_range) +function values_from_chain( + varinfo::VarInfo, chain, chain_idx_range::UnitRange, iteration_idx_range::Int +) + return values_from_chain( + varinfo, chain, chain_idx_range, iteration_idx_range:iteration_idx_range + ) end -function values_from_chain(varinfo::VarInfo, chain, chain_idx_range::Int, iteration_idx_range::Int) # this is equivalent to values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx) - return values_from_chain(varinfo, chain, chain_idx_range:chain_idx_range, iteration_idx_range:iteration_idx_range) +function values_from_chain( + varinfo::VarInfo, chain, chain_idx_range::Int, iteration_idx_range::Int +) # this is equivalent to values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx) + return values_from_chain( + varinfo, + chain, + chain_idx_range:chain_idx_range, + iteration_idx_range:iteration_idx_range, + ) end # if either chain_idx_range or iteration_idx_range is specified as `nothing`, then all chains will be included. function values_from_chain(varinfo::VarInfo, chain, chain_idx_range, iteration_idx_range) @@ -272,7 +297,7 @@ function values_from_chain(varinfo::VarInfo, chain, chain_idx_range, iteration_i chain_idx_range = 1:size(chain)[3] end if iteration_idx_range === nothing - print("iteration_idx_range is missing!") + print("iteration_idx_range is missing!") iteration_idx_range = 1:size(chain)[1] end return values_from_chain(varinfo, chain, chain_idx_range, iteration_idx_range) diff --git a/test/model_utils.jl b/test/model_utils.jl index c6fa99af3..701f760bb 100644 --- a/test/model_utils.jl +++ b/test/model_utils.jl @@ -96,21 +96,33 @@ chn_3 = Chains(chain_mat, symbol_names) @test length(output["chain_idx_1"]) == 4 output = values_from_chain(model_2, chn_2, 1, 1:10) - @test length(output["chain_idx_1"]) == 3 && all([length(vals) == 10 for vals in values(output["chain_idx_1"])]) + @test length(output["chain_idx_1"]) == 3 && + all([length(vals) == 10 for vals in values(output["chain_idx_1"])]) output = values_from_chain(model_3, chn_3, 1, 1:10) - @test length(output["chain_idx_1"]) == 4 && all([length(vals) == 10 for vals in values(output["chain_idx_1"])]) + @test length(output["chain_idx_1"]) == 4 && + all([length(vals) == 10 for vals in values(output["chain_idx_1"])]) output = values_from_chain(model_2, chn_2, 1, 1) - @test length(output["chain_idx_1"]) == 3 && all([length(vals) == 1 for vals in values(output["chain_idx_1"])]) + @test length(output["chain_idx_1"]) == 3 && + all([length(vals) == 1 for vals in values(output["chain_idx_1"])]) output = values_from_chain(model_2, chn_2, 1:1, 1) - @test length(keys(output))==1 && length(output["chain_idx_1"]) == 3 && all([length(vals) == 1 for vals in values(output["chain_idx_1"])]) + @test length(keys(output)) == 1 && + length(output["chain_idx_1"]) == 3 && + all([length(vals) == 1 for vals in values(output["chain_idx_1"])]) output = values_from_chain(model_2, chn_2, 1, 1:1) - @test length(output["chain_idx_1"]) == 3 && all([length(vals) == 1 for vals in values(output["chain_idx_1"])]) + @test length(output["chain_idx_1"]) == 3 && + all([length(vals) == 1 for vals in values(output["chain_idx_1"])]) output = values_from_chain(model_2, chn_2, 1:1, 1:1) - @test length(keys(output))==1 && length(output["chain_idx_1"]) == 3 && all([length(vals) == 1 for vals in values(output["chain_idx_1"])]) + @test length(keys(output)) == 1 && + length(output["chain_idx_1"]) == 3 && + all([length(vals) == 1 for vals in values(output["chain_idx_1"])]) output = values_from_chain(model_2, chn_2, nothing, nothing) - @test length(keys(output))==1 && length(output["chain_idx_1"]) == 3 && all([length(vals) == 100 for vals in values(output["chain_idx_1"])]) + @test length(keys(output)) == 1 && + length(output["chain_idx_1"]) == 3 && + all([length(vals) == 100 for vals in values(output["chain_idx_1"])]) output = values_from_chain(model_3, chn_3, nothing, nothing) - @test length(keys(output))==1 && length(output["chain_idx_1"]) == 4 && all([length(vals) == 100 for vals in values(output["chain_idx_1"])]) + @test length(keys(output)) == 1 && + length(output["chain_idx_1"]) == 4 && + all([length(vals) == 100 for vals in values(output["chain_idx_1"])]) end end \ No newline at end of file From 140045baac68ce884dedc2a3802d6e4dd695023a Mon Sep 17 00:00:00 2001 From: YongchaoHuang Date: Tue, 11 Jul 2023 13:23:06 +0100 Subject: [PATCH 13/20] added `using Turing` in `DynamicPPL.jl` --- src/DynamicPPL.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 51c2c8761..3b5ef5139 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -21,6 +21,8 @@ using DocStringExtensions using Random: Random +using Turing + import Base: Symbol, ==, From d75a6595cd16071d615a479eef597772321b64f6 Mon Sep 17 00:00:00 2001 From: YongchaoHuang Date: Tue, 11 Jul 2023 13:43:23 +0100 Subject: [PATCH 14/20] update `Project.toml`. --- docs/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/Project.toml b/docs/Project.toml index b6cce8b37..54e8b3f82 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -7,6 +7,7 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] DataStructures = "0.18" From 05085f4d839fd4c44019737b4987022aaadee35b Mon Sep 17 00:00:00 2001 From: YongchaoHuang Date: Tue, 11 Jul 2023 13:50:18 +0100 Subject: [PATCH 15/20] applied reviewdog suggestion. --- test/model_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/model_utils.jl b/test/model_utils.jl index 701f760bb..a1e949f68 100644 --- a/test/model_utils.jl +++ b/test/model_utils.jl @@ -10,7 +10,7 @@ chn_1 = Chains(val, [:s, :m]) # (2) sample a Turing model to create a chain @model function gdemo(x) mu ~ MvNormal([0, 0, 0], [1 0 0; 0 1 0; 0 0 1]) - x ~ MvNormal(mu, [1 0 0; 0 1 0; 0 0 1]) + return x ~ MvNormal(mu, [1 0 0; 0 1 0; 0 0 1]) end model_2 = gdemo([0, 0, 0]) # provide an initial value for `x` chn_2 = sample(model_2, NUTS(), 100) # NB: the parameter names in an MCMCChains can be retrieved using `namechn_2.name_map[:parameters]_map`: https://github.com/TuringLang/MCMCChains.jl/blob/master/src/chains.jl From db24f4eeba7cfd2390f4c69f0b559aebce9a52aa Mon Sep 17 00:00:00 2001 From: YongchaoHuang Date: Tue, 11 Jul 2023 14:46:55 +0100 Subject: [PATCH 16/20] removed `using Turing`. --- docs/Project.toml | 1 - src/DynamicPPL.jl | 2 -- 2 files changed, 3 deletions(-) diff --git a/docs/Project.toml b/docs/Project.toml index 54e8b3f82..b6cce8b37 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -7,7 +7,6 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" -Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] DataStructures = "0.18" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 3b5ef5139..51c2c8761 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -21,8 +21,6 @@ using DocStringExtensions using Random: Random -using Turing - import Base: Symbol, ==, From f70004453319b795b6629d3a60009dda4f17bdd1 Mon Sep 17 00:00:00 2001 From: YongchaoHuang Date: Tue, 11 Jul 2023 14:46:59 +0100 Subject: [PATCH 17/20] removed `USING tURING` --- .vscode/settings.json | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..0217c48c2 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "julia.environmentPath": "/home/yongchao/.julia/dev/DynamicPPL/docs" +} \ No newline at end of file From a4d111ffacc49e97a97507fa8a8c71f121d4cd28 Mon Sep 17 00:00:00 2001 From: YongchaoHuang Date: Tue, 11 Jul 2023 14:52:49 +0100 Subject: [PATCH 18/20] removed Turing dependency in `src/model_utils.jl` --- src/model_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model_utils.jl b/src/model_utils.jl index 9d0f4870c..d101c8f7d 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -1,6 +1,6 @@ ### Yong ############################################################## # Yong added the below new functions on 2023-07-04, they are doing the some functionalities as Tor's functions. Some redundancy needs to be removed? -using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test +using Distributions, DynamicPPL, MCMCChains, Random, Test #### 1. varname_in_chain #### # here we just check if vn and its leaves are present in the chain; we are not checking its presence in model. So we don't need to pass model or varinfo to this function. From 57cd078ded0facfe9976805ccc07ee0fbda47f13 Mon Sep 17 00:00:00 2001 From: YongchaoHuang Date: Tue, 11 Jul 2023 14:57:48 +0100 Subject: [PATCH 19/20] removed dependencies in `src/model_utils.jl` --- src/model_utils.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/model_utils.jl b/src/model_utils.jl index d101c8f7d..4f317a419 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -1,6 +1,7 @@ ### Yong ############################################################## # Yong added the below new functions on 2023-07-04, they are doing the some functionalities as Tor's functions. Some redundancy needs to be removed? -using Distributions, DynamicPPL, MCMCChains, Random, Test +# using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test +using Distributions, Random, Test #### 1. varname_in_chain #### # here we just check if vn and its leaves are present in the chain; we are not checking its presence in model. So we don't need to pass model or varinfo to this function. From cf30820cf4965ec4c8e96b9fb2dbd9b51c1699c8 Mon Sep 17 00:00:00 2001 From: YongchaoHuang Date: Tue, 11 Jul 2023 16:37:32 +0100 Subject: [PATCH 20/20] removed unecessary and improper imports. --- src/model_utils.jl | 3 --- test/model_utils.jl | 3 --- 2 files changed, 6 deletions(-) diff --git a/src/model_utils.jl b/src/model_utils.jl index 4f317a419..6710aea05 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -1,8 +1,5 @@ ### Yong ############################################################## # Yong added the below new functions on 2023-07-04, they are doing the some functionalities as Tor's functions. Some redundancy needs to be removed? -# using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test -using Distributions, Random, Test - #### 1. varname_in_chain #### # here we just check if vn and its leaves are present in the chain; we are not checking its presence in model. So we don't need to pass model or varinfo to this function. """ diff --git a/test/model_utils.jl b/test/model_utils.jl index a1e949f68..5c490234f 100644 --- a/test/model_utils.jl +++ b/test/model_utils.jl @@ -1,6 +1,3 @@ -using Turing, Distributions, DynamicPPL, MCMCChains, Random, Test -Random.seed!(111) - #### prepare the models and chains for testing #### # (1) manually create a chain using MCMCChains - we know what parameter names are in the chain val = [1 2; 3 4; 5 6; 7 8; 9 10; 11 12; 13 14; 15 16; 17 18; 19 20]