Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions src/model_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Return `true` if all variable names in `model`/`varinfo` are in `chain`.
"""
varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain)
function varnames_in_chain(varinfo::VarInfo, chain)
return all(vn -> varname_in_chain(varinfo, vn, chain), keys(varinfo))
return all(vn -> varname_in_chain(varinfo, vn, chain, 1, 1), keys(varinfo))
end

"""
Expand All @@ -15,10 +15,10 @@ end

Return `out` with `true` for all variable names in `model` that are in `chain`.
"""
function varnames_in_chain!(model::Model, chain, out)
function varnames_in_chain!(model::Model, chain, out::OrderedDict)
return varnames_in_chain!(VarInfo(model), chain, out)
end
function varnames_in_chain!(varinfo::VarInfo, chain, out)
function varnames_in_chain!(varinfo::VarInfo, chain, out::OrderedDict)
for vn in keys(varinfo)
varname_in_chain!(varinfo, vn, chain, 1, 1, out)
end
Expand All @@ -43,7 +43,7 @@ end

function varname_in_chain(x, vn, chain, chain_idx, iteration_idx)
out = OrderedDict{VarName,Bool}()
varname_in_chain!(x, vn, chain, out, chain_idx, iteration_idx)
varname_in_chain!(x, vn, chain, chain_idx, iteration_idx, out)
return all(values(out))
end

Expand All @@ -59,32 +59,32 @@ If `chain_idx` and `iteration_idx` are not provided, then they default to `1`.
This differs from [`varname_in_chain`](@ref) in that it returns a dictionary
rather than a single boolean. This can be quite useful for debugging purposes.
"""
function varname_in_chain!(model::Model, vn, chain, out, chain_idx, iteration_idx)
function varname_in_chain!(model::Model, vn, chain, chain_idx, iteration_idx, out::OrderedDict)
return varname_in_chain!(VarInfo(model), vn, chain, chain_idx, iteration_idx, out)
end

function varname_in_chain!(
vi::AbstractVarInfo, vn_parent, chain, out, chain_idx, iteration_idx
vi::AbstractVarInfo, vn_parent, chain, chain_idx, iteration_idx, out::OrderedDict
)
return varname_in_chain!(vi[vn_parent], vn_parent, chain, out, chain_idx, iteration_idx)
return varname_in_chain!(vi[vn_parent], vn_parent, chain, chain_idx, iteration_idx, out)
end

function varname_in_chain!(x, vn_parent, chain, out, chain_idx, iteration_idx)
sym = Symbol(vn_parent)
out[vn_parent] = sym ∈ names(chain) && !ismissing(chain[iteration_idx, sym, chain_idx])
function vn_in_chain(vn_child, chain, chain_idx, iteration_idx, out::OrderedDict)
sym = Symbol(vn_child)
out[vn_child] = sym ∈ names(chain) && !ismissing(chain[iteration_idx, sym, chain_idx])
return out
end

function varname_in_chain!(
x::AbstractArray, vn_parent::VarName{sym}, chain, out, chain_idx, iteration_idx
x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx, out::OrderedDict
) where {sym}
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens.
# This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)`
# to extract the value from the `chain`.
for vn in varname_leaves(VarName{sym}(), x)
# Update `out`, possibly in place, and return.
l = AbstractPPL.getlens(vn)
varname_in_chain!(x, vn_parent ∘ l, chain, out, chain_idx, iteration_idx)
vn_in_chain(vn_parent ∘ l, chain, chain_idx, iteration_idx, out)
end
return out
end
Expand All @@ -96,9 +96,9 @@ end
Return a dictionary mapping each variable name in `model`/`varinfo` to its
value in `chain` at `chain_idx` and `iteration_idx`.
"""
function values_from_chain(x, vn_parent, chain, chain_idx, iteration_idx)
function values_from_chain(vn_child, chain, chain_idx, iteration_idx)
# HACK: If it's not an array, we fall back to just returning the first value.
return only(chain[iteration_idx, Symbol(vn_parent), chain_idx])
return only(chain[iteration_idx, Symbol(vn_child), chain_idx])
end
function values_from_chain(
x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx
Expand Down Expand Up @@ -131,11 +131,11 @@ end
Mutate `out` to map each variable name in `model`/`varinfo` to its value in
`chain` at `chain_idx` and `iteration_idx`.
"""
function values_from_chain!(model::DynamicPPL.Model, chain, out, chain_idx, iteration_idx)
return values_from_chain(VarInfo(model), chain, out, chain_idx, iteration_idx)
function values_from_chain!(model::DynamicPPL.Model, chain, chain_idx, iteration_idx, out::OrderedDict)
return values_from_chain!(VarInfo(model), chain, chain_idx, iteration_idx, out::OrderedDict)
end

function values_from_chain!(vi::AbstractVarInfo, chain, out, chain_idx, iteration_idx)
function values_from_chain!(vi::AbstractVarInfo, chain, chain_idx, iteration_idx, out::OrderedDict)
for vn in keys(vi)
out[vn] = values_from_chain(vi, vn, chain, chain_idx, iteration_idx)
end
Expand Down Expand Up @@ -204,6 +204,6 @@ function value_iterator_from_chain(vi::AbstractVarInfo, chain)
return Iterators.map(
Iterators.product(1:size(chain, 1), 1:size(chain, 3))
) do (iteration_idx, chain_idx)
values_from_chain!(vi, chain, OrderedDict{VarName,Any}(), chain_idx, iteration_idx)
values_from_chain!(vi, chain, chain_idx, iteration_idx, OrderedDict())
end
end
60 changes: 60 additions & 0 deletions test/model_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#### prepare the models and chains for testing ####
# (1) manually create a chain using MCMCChains - we know what parameter names are in the chain
val = [1 2; 3 4; 5 6; 7 8; 9 10; 11 12; 13 14; 15 16; 17 18; 19 20]
val = Matrix{Int}(val)
chn_1 = Chains(val, [:s, :m])

# (2) sample a Turing model to create a chain
@model function gdemo(x)
mu ~ MvNormal([0, 0, 0], [1 0 0; 0 1 0; 0 0 1])
return x ~ MvNormal(mu, [1 0 0; 0 1 0; 0 0 1])
end
model_2 = gdemo([0, 0, 0])
chn_2 = sample(model_2, NUTS(), 100)

#### test the functions in model_utils.jl ####
@testset "model_utils.jl" begin
@testset "varname_in_chain" begin
@test all(values(vn_in_chain(VarName(:s), chn_1, 1, 1, OrderedDict())))
@test all(values(vn_in_chain(VarName(:m), chn_1, 1, 1, OrderedDict())))
@test all(values(vn_in_chain(VarName(Symbol("mu[1]")), chn_2, 1, 1, OrderedDict())))

@test all(values(varname_in_chain!(VarInfo(model_2)[VarName(:mu)], VarName(:mu), chn_2, 1, 1, OrderedDict())))

@test length(keys(varname_in_chain!(VarInfo(model_2), VarName(:mu), chn_2, 1, 1, OrderedDict())))==3
@test all(values(varname_in_chain!(VarInfo(model_2), VarName(:mu), chn_2, 1, 1, OrderedDict())))

@test length(keys(varname_in_chain!(model_2, VarName(:mu), chn_2, 1, 1, OrderedDict())))==3
@test all(values(varname_in_chain!(model_2, VarName(:mu), chn_2, 1, 1, OrderedDict())))

@test varname_in_chain(VarInfo(model_2), VarName(:mu), chn_2, 1, 1)
@test varname_in_chain(model_2, VarName(:mu), chn_2, 1, 1)

end
@testset "varnames_in_chain" begin
@test length(varnames_in_chain!(VarInfo(model_2), chn_2, OrderedDict())) == 3
@test all(values(varnames_in_chain!(VarInfo(model_2), chn_2, OrderedDict())))
@test length(varnames_in_chain!(model_2, chn_2, OrderedDict())) == 3
@test all(values(varnames_in_chain!(model_2, chn_2, OrderedDict())))

@test varnames_in_chain(model_2, chn_2)
end
@testset "values_from_chain" begin
@test isa(values_from_chain(VarName(:s), chn_1, 1, 1), Number)
@test isa(values_from_chain(VarName(Symbol("mu[1]")), chn_2, 1, 1), Number)

@test all(isa.(values_from_chain(VarInfo(model_2)[VarName(:mu)], VarName(:mu), chn_2, 1, 1), Number))

@test all(isa.(values_from_chain(VarInfo(model_2), VarName(:mu), chn_2, 1, 1), Number))

@test all(isa.(collect(values(values_from_chain!(model_2, chn_2, 1, 1, OrderedDict())))[1], Number))

@test all(isa.(collect(values(values_from_chain!(VarInfo(model_2), chn_2, 1, 1, OrderedDict())))[1], Number))
end
@testset "value_iterator_from_chain" begin
all_values = collect(value_iterator_from_chain(model_2, chn_2))
for ordered_dict in all_values
@test all(isa.(collect(values(ordered_dict))[1], Number))
end
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ include("test_util.jl")
include("serialization.jl")

include("loglikelihoods.jl")

include("model_utils.jl")
end

@testset "compat" begin
Expand Down