From be8d42b5cc49110206cfef837c95060fcd4a81b0 Mon Sep 17 00:00:00 2001 From: YongchaoHuang Date: Thu, 13 Jul 2023 19:47:43 +0100 Subject: [PATCH 1/2] suggested changes: 1. fixed some typos in `src/model_utils.jl` (e.g. missing !, inconsistent sequence of `out`, etc), removed unused `x` from argument. 2. wrote tests in `test/model_utils.jl` 3. included the test file in `runtests.jl` --- src/model_utils.jl | 37 ++++++++++++++-------------- test/model_utils.jl | 60 +++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 ++ 3 files changed, 81 insertions(+), 18 deletions(-) create mode 100644 test/model_utils.jl diff --git a/src/model_utils.jl b/src/model_utils.jl index 6600cfff0..37aa4aca8 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -6,7 +6,7 @@ Return `true` if all variable names in `model`/`varinfo` are in `chain`. """ varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain) function varnames_in_chain(varinfo::VarInfo, chain) - return all(vn -> varname_in_chain(varinfo, vn, chain), keys(varinfo)) + return all(vn -> varname_in_chain(varinfo, vn, chain, 1, 1), keys(varinfo)) end """ @@ -15,10 +15,10 @@ end Return `out` with `true` for all variable names in `model` that are in `chain`. """ -function varnames_in_chain!(model::Model, chain, out) +function varnames_in_chain!(model::Model, chain, out::OrderedDict) return varnames_in_chain!(VarInfo(model), chain, out) end -function varnames_in_chain!(varinfo::VarInfo, chain, out) +function varnames_in_chain!(varinfo::VarInfo, chain, out::OrderedDict) for vn in keys(varinfo) varname_in_chain!(varinfo, vn, chain, 1, 1, out) end @@ -43,7 +43,7 @@ end function varname_in_chain(x, vn, chain, chain_idx, iteration_idx) out = OrderedDict{VarName,Bool}() - varname_in_chain!(x, vn, chain, out, chain_idx, iteration_idx) + varname_in_chain!(x, vn, chain, chain_idx, iteration_idx, out) return all(values(out)) end @@ -59,24 +59,24 @@ If `chain_idx` and `iteration_idx` are not provided, then they default to `1`. This differs from [`varname_in_chain`](@ref) in that it returns a dictionary rather than a single boolean. This can be quite useful for debugging purposes. """ -function varname_in_chain!(model::Model, vn, chain, out, chain_idx, iteration_idx) +function varname_in_chain!(model::Model, vn, chain, chain_idx, iteration_idx, out::OrderedDict) return varname_in_chain!(VarInfo(model), vn, chain, chain_idx, iteration_idx, out) end function varname_in_chain!( - vi::AbstractVarInfo, vn_parent, chain, out, chain_idx, iteration_idx + vi::AbstractVarInfo, vn_parent, chain, chain_idx, iteration_idx, out::OrderedDict ) - return varname_in_chain!(vi[vn_parent], vn_parent, chain, out, chain_idx, iteration_idx) + return varname_in_chain!(vi[vn_parent], vn_parent, chain, chain_idx, iteration_idx, out) end -function varname_in_chain!(x, vn_parent, chain, out, chain_idx, iteration_idx) - sym = Symbol(vn_parent) - out[vn_parent] = sym ∈ names(chain) && !ismissing(chain[iteration_idx, sym, chain_idx]) +function vn_in_chain(vn_child, chain, chain_idx, iteration_idx, out::OrderedDict) + sym = Symbol(vn_child) + out[vn_child] = sym ∈ names(chain) && !ismissing(chain[iteration_idx, sym, chain_idx]) return out end function varname_in_chain!( - x::AbstractArray, vn_parent::VarName{sym}, chain, out, chain_idx, iteration_idx + x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx, out::OrderedDict ) where {sym} # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens. # This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)` @@ -84,7 +84,8 @@ function varname_in_chain!( for vn in varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. l = AbstractPPL.getlens(vn) - varname_in_chain!(x, vn_parent ∘ l, chain, out, chain_idx, iteration_idx) + println(vn_parent ∘ l) + vn_in_chain(vn_parent ∘ l, chain, chain_idx, iteration_idx, out) end return out end @@ -96,9 +97,9 @@ end Return a dictionary mapping each variable name in `model`/`varinfo` to its value in `chain` at `chain_idx` and `iteration_idx`. """ -function values_from_chain(x, vn_parent, chain, chain_idx, iteration_idx) +function values_from_chain(vn_child, chain, chain_idx, iteration_idx) # HACK: If it's not an array, we fall back to just returning the first value. - return only(chain[iteration_idx, Symbol(vn_parent), chain_idx]) + return only(chain[iteration_idx, Symbol(vn_child), chain_idx]) end function values_from_chain( x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx @@ -131,11 +132,11 @@ end Mutate `out` to map each variable name in `model`/`varinfo` to its value in `chain` at `chain_idx` and `iteration_idx`. """ -function values_from_chain!(model::DynamicPPL.Model, chain, out, chain_idx, iteration_idx) - return values_from_chain(VarInfo(model), chain, out, chain_idx, iteration_idx) +function values_from_chain!(model::DynamicPPL.Model, chain, chain_idx, iteration_idx, out::OrderedDict) + return values_from_chain!(VarInfo(model), chain, chain_idx, iteration_idx, out::OrderedDict) end -function values_from_chain!(vi::AbstractVarInfo, chain, out, chain_idx, iteration_idx) +function values_from_chain!(vi::AbstractVarInfo, chain, chain_idx, iteration_idx, out::OrderedDict) for vn in keys(vi) out[vn] = values_from_chain(vi, vn, chain, chain_idx, iteration_idx) end @@ -204,6 +205,6 @@ function value_iterator_from_chain(vi::AbstractVarInfo, chain) return Iterators.map( Iterators.product(1:size(chain, 1), 1:size(chain, 3)) ) do (iteration_idx, chain_idx) - values_from_chain!(vi, chain, OrderedDict{VarName,Any}(), chain_idx, iteration_idx) + values_from_chain!(vi, chain, chain_idx, iteration_idx, OrderedDict()) end end diff --git a/test/model_utils.jl b/test/model_utils.jl new file mode 100644 index 000000000..8fbe030cb --- /dev/null +++ b/test/model_utils.jl @@ -0,0 +1,60 @@ +#### prepare the models and chains for testing #### +# (1) manually create a chain using MCMCChains - we know what parameter names are in the chain +val = [1 2; 3 4; 5 6; 7 8; 9 10; 11 12; 13 14; 15 16; 17 18; 19 20] +val = Matrix{Int}(val) +chn_1 = Chains(val, [:s, :m]) + +# (2) sample a Turing model to create a chain +@model function gdemo(x) + mu ~ MvNormal([0, 0, 0], [1 0 0; 0 1 0; 0 0 1]) + return x ~ MvNormal(mu, [1 0 0; 0 1 0; 0 0 1]) +end +model_2 = gdemo([0, 0, 0]) +chn_2 = sample(model_2, NUTS(), 100) + +#### test the functions in model_utils.jl #### +@testset "model_utils.jl" begin + @testset "varname_in_chain" begin + @test all(values(vn_in_chain(VarName(:s), chn_1, 1, 1, OrderedDict()))) + @test all(values(vn_in_chain(VarName(:m), chn_1, 1, 1, OrderedDict()))) + @test all(values(vn_in_chain(VarName(Symbol("mu[1]")), chn_2, 1, 1, OrderedDict()))) + + @test all(values(varname_in_chain!(VarInfo(model_2)[VarName(:mu)], VarName(:mu), chn_2, 1, 1, OrderedDict()))) + + @test length(keys(varname_in_chain!(VarInfo(model_2), VarName(:mu), chn_2, 1, 1, OrderedDict())))==3 + @test all(values(varname_in_chain!(VarInfo(model_2), VarName(:mu), chn_2, 1, 1, OrderedDict()))) + + @test length(keys(varname_in_chain!(model_2, VarName(:mu), chn_2, 1, 1, OrderedDict())))==3 + @test all(values(varname_in_chain!(model_2, VarName(:mu), chn_2, 1, 1, OrderedDict()))) + + @test varname_in_chain(VarInfo(model_2), VarName(:mu), chn_2, 1, 1) + @test varname_in_chain(model_2, VarName(:mu), chn_2, 1, 1) + + end + @testset "varnames_in_chain" begin + @test length(varnames_in_chain!(VarInfo(model_2), chn_2, OrderedDict())) == 3 + @test all(values(varnames_in_chain!(VarInfo(model_2), chn_2, OrderedDict()))) + @test length(varnames_in_chain!(model_2, chn_2, OrderedDict())) == 3 + @test all(values(varnames_in_chain!(model_2, chn_2, OrderedDict()))) + + @test varnames_in_chain(model_2, chn_2) + end + @testset "values_from_chain" begin + @test isa(values_from_chain(VarName(:s), chn_1, 1, 1), Number) + @test isa(values_from_chain(VarName(Symbol("mu[1]")), chn_2, 1, 1), Number) + + @test all(isa.(values_from_chain(VarInfo(model_2)[VarName(:mu)], VarName(:mu), chn_2, 1, 1), Number)) + + @test all(isa.(values_from_chain(VarInfo(model_2), VarName(:mu), chn_2, 1, 1), Number)) + + @test all(isa.(collect(values(values_from_chain!(model_2, chn_2, 1, 1, OrderedDict())))[1], Number)) + + @test all(isa.(collect(values(values_from_chain!(VarInfo(model_2), chn_2, 1, 1, OrderedDict())))[1], Number)) + end + @testset "value_iterator_from_chain" begin + all_values = collect(value_iterator_from_chain(model_2, chn_2)) + for ordered_dict in all_values + @test all(isa.(collect(values(ordered_dict))[1], Number)) + end + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 27889b5e5..4916ea3cc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -50,6 +50,8 @@ include("test_util.jl") include("serialization.jl") include("loglikelihoods.jl") + + include("model_utils.jl") end @testset "compat" begin From 43983d0bbf67417e1603ab8da5036433a1932273 Mon Sep 17 00:00:00 2001 From: YongchaoHuang <34540771+YongchaoHuang@users.noreply.github.com> Date: Thu, 13 Jul 2023 20:35:28 +0100 Subject: [PATCH 2/2] Update src/model_utils.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- src/model_utils.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/model_utils.jl b/src/model_utils.jl index 37aa4aca8..9d41075b3 100644 --- a/src/model_utils.jl +++ b/src/model_utils.jl @@ -84,7 +84,6 @@ function varname_in_chain!( for vn in varname_leaves(VarName{sym}(), x) # Update `out`, possibly in place, and return. l = AbstractPPL.getlens(vn) - println(vn_parent ∘ l) vn_in_chain(vn_parent ∘ l, chain, chain_idx, iteration_idx, out) end return out