Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,13 @@ For a chain of samples, one can compute the pointwise log-likelihoods of each ob
pointwise_loglikelihoods
```

For converting a chain into a format that can more easily be fed into a `Model` again, for example using `condition`, you can use [`value_iterator_from_chain`](@ref).

```@docs
value_iterator_from_chain

```

Sometimes it can be useful to extract the priors of a model. This is the possible using [`extract_priors`](@ref).

```@docs
Expand Down
4 changes: 3 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ export AbstractVarInfo,
unfix,
# Convenience macros
@addlogprob!,
@submodel
@submodel,
value_iterator_from_chain

# Reexport
using Distributions: loglikelihood
Expand Down Expand Up @@ -169,6 +170,7 @@ include("submodel_macro.jl")
include("test_utils.jl")
include("transforming.jl")
include("logdensityfunction.jl")
include("model_utils.jl")
include("extract_priors.jl")

end # module
209 changes: 209 additions & 0 deletions src/model_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
"""
varnames_in_chain(model:::Model, chain)
varnames_in_chain(varinfo::VarInfo, chain)

Return `true` if all variable names in `model`/`varinfo` are in `chain`.
"""
varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain)
function varnames_in_chain(varinfo::VarInfo, chain)
return all(vn -> varname_in_chain(varinfo, vn, chain, 1, 1), keys(varinfo))
end

"""
varnames_in_chain!(model::Model, chain, out)
varnames_in_chain!(varinfo::VarInfo, chain, out)

Return `out` with `true` for all variable names in `model` that are in `chain`.
"""
function varnames_in_chain!(model::Model, chain, out)
return varnames_in_chain!(VarInfo(model), chain, out)
end
function varnames_in_chain!(varinfo::VarInfo, chain, out)
for vn in keys(varinfo)
varname_in_chain!(varinfo, vn, chain, 1, 1, out)
end

return out
end

"""
varname_in_chain(model::Model, vn, chain, chain_idx, iteration_idx)
varname_in_chain(varinfo::VarInfo, vn, chain, chain_idx, iteration_idx)

Return `true` if `vn` is in `chain` at `chain_idx` and `iteration_idx`.
"""
function varname_in_chain(model::Model, vn, chain, chain_idx, iteration_idx)
return varname_in_chain(VarInfo(model), vn, chain, chain_idx, iteration_idx)
end

function varname_in_chain(varinfo::AbstractVarInfo, vn, chain, chain_idx, iteration_idx)
!haskey(varinfo, vn) && return false
return varname_in_chain(varinfo[vn], vn, chain, chain_idx, iteration_idx)
end

function varname_in_chain(x, vn, chain, chain_idx, iteration_idx)
out = OrderedDict{VarName,Bool}()
varname_in_chain!(x, vn, chain, chain_idx, iteration_idx, out)
return all(values(out))
end

"""
varname_in_chain!(model::Model, vn, chain, chain_idx, iteration_idx, out)
varname_in_chain!(varinfo::VarInfo, vn, chain, chain_idx, iteration_idx, out)

Return a dictionary mapping the varname `vn` to `true` if `vn` is in `chain` at
`chain_idx` and `iteration_idx`.

If `chain_idx` and `iteration_idx` are not provided, then they default to `1`.

This differs from [`varname_in_chain`](@ref) in that it returns a dictionary
rather than a single boolean. This can be quite useful for debugging purposes.
"""
function varname_in_chain!(model::Model, vn, chain, chain_idx, iteration_idx, out)
return varname_in_chain!(VarInfo(model), vn, chain, chain_idx, iteration_idx, out)
end

function varname_in_chain!(
vi::AbstractVarInfo, vn_parent, chain, chain_idx, iteration_idx, out
)
return varname_in_chain!(vi[vn_parent], vn_parent, chain, chain_idx, iteration_idx, out)
end

function varname_in_chain!(x, vn_parent, chain, chain_idx, iteration_idx, out)
sym = Symbol(vn_parent)
out[vn_parent] = sym ∈ names(chain) && !ismissing(chain[iteration_idx, sym, chain_idx])
return out
end

function varname_in_chain!(
x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx, out
) where {sym}
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens.
# This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)`
# to extract the value from the `chain`.
for vn in varname_leaves(VarName{sym}(), x)
Copy link
Contributor

@YongchaoHuang YongchaoHuang Jul 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't know if my understanding is correct or not - if we have a chain with variable name s in it, and if the VarName s is passed to the following:

for vn in varname_leaves(VarName(:s), [1])
    l = AbstractPPL.getlens(vn)
    println(VarName(:s) ∘ l)
end

it will generate the VarName s[1] which is obviously not present in the chain. This means for variables that take scalar value, the test will fail. I tested this and it did output s[1] => false; but what we want is s => true I guess.

# Update `out`, possibly in place, and return.
l = AbstractPPL.getlens(vn)
varname_in_chain!(x, vn_parent ∘ l, chain, chain_idx, iteration_idx, out)
end
return out
end

"""
values_from_chain(model::Model, chain, chain_idx, iteration_idx)
values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx)

Return a dictionary mapping each variable name in `model`/`varinfo` to its
value in `chain` at `chain_idx` and `iteration_idx`.
"""
function values_from_chain(x, vn_parent, chain, chain_idx, iteration_idx)
# HACK: If it's not an array, we fall back to just returning the first value.
return only(chain[iteration_idx, Symbol(vn_parent), chain_idx])
end
function values_from_chain(
x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx
) where {sym}
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens.
# This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)`
# to extract the value from the `chain`.
out = similar(x)
for vn in varname_leaves(VarName{sym}(), x)
# Update `out`, possibly in place, and return.
l = AbstractPPL.getlens(vn)
out = Setfield.set(
out,
BangBang.prefermutation(l),
chain[iteration_idx, Symbol(vn_parent ∘ l), chain_idx],
)
end

return out
end
function values_from_chain(vi::AbstractVarInfo, vn_parent, chain, chain_idx, iteration_idx)
# Use the value `vi[vn_parent]` to obtain a buffer.
return values_from_chain(vi[vn_parent], vn_parent, chain, chain_idx, iteration_idx)
end

"""
values_from_chain!(model::Model, chain, chain_idx, iteration_idx, out)
values_from_chain!(varinfo::VarInfo, chain, chain_idx, iteration_idx, out)

Mutate `out` to map each variable name in `model`/`varinfo` to its value in
`chain` at `chain_idx` and `iteration_idx`.
"""
function values_from_chain!(model::DynamicPPL.Model, chain, chain_idx, iteration_idx, out)
return values_from_chain(VarInfo(model), chain, chain_idx, iteration_idx, out)
end

function values_from_chain!(vi::AbstractVarInfo, chain, chain_idx, iteration_idx, out)
for vn in keys(vi)
out[vn] = values_from_chain(vi, vn, chain, chain_idx, iteration_idx)
end
return out
end

"""
value_iterator_from_chain(model::Model, chain)
value_iterator_from_chain(varinfo::AbstractVarInfo, chain)

Return an iterator over the values in `chain` for each variable in `model`/`varinfo`.

# Example
```julia
julia> using MCMCChains, DynamicPPL, Distributions, StableRNGs

julia> rng = StableRNG(42);

julia> @model function demo_model(x)
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
for i in eachindex(x)
x[i] ~ Normal(m, sqrt(s))
end

return s, m
end
demo_model (generic function with 2 methods)

julia> model = demo_model([1.0, 2.0]);

julia> chain = Chains(rand(rng, 10, 2, 3), [:s, :m]);

julia> iter = value_iterator_from_chain(model, chain);

julia> first(iter)
OrderedDict{VarName, Any} with 2 entries:
s => 0.580515
m => 0.739328

julia> collect(iter)
10×3 Matrix{OrderedDict{VarName, Any}}:
OrderedDict(s=>0.580515, m=>0.739328) … OrderedDict(s=>0.186047, m=>0.402423)
OrderedDict(s=>0.191241, m=>0.627342) OrderedDict(s=>0.776277, m=>0.166342)
OrderedDict(s=>0.971133, m=>0.637584) OrderedDict(s=>0.651655, m=>0.712044)
OrderedDict(s=>0.74345, m=>0.110359) OrderedDict(s=>0.469214, m=>0.104502)
OrderedDict(s=>0.170969, m=>0.598514) OrderedDict(s=>0.853546, m=>0.185399)
OrderedDict(s=>0.704776, m=>0.322111) … OrderedDict(s=>0.638301, m=>0.853802)
OrderedDict(s=>0.441044, m=>0.162285) OrderedDict(s=>0.852959, m=>0.0956922)
OrderedDict(s=>0.803972, m=>0.643369) OrderedDict(s=>0.245049, m=>0.871985)
OrderedDict(s=>0.772384, m=>0.646323) OrderedDict(s=>0.906603, m=>0.385502)
OrderedDict(s=>0.70882, m=>0.253105) OrderedDict(s=>0.413222, m=>0.953288)

julia> # This can be used to `condition` a `Model`.
conditioned_model = model | first(iter);

julia> conditioned_model() # <= results in same values as the `first(iter)` above
(0.5805148626851955, 0.7393275279160691)
```
"""
function value_iterator_from_chain(model::DynamicPPL.Model, chain)
return value_iterator_from_chain(VarInfo(model), chain)
end

function value_iterator_from_chain(vi::AbstractVarInfo, chain)
return Iterators.map(
Iterators.product(1:size(chain, 1), 1:size(chain, 3))
) do (iteration_idx, chain_idx)
values_from_chain!(vi, chain, chain_idx, iteration_idx, OrderedDict{VarName,Any}())
end
end
15 changes: 15 additions & 0 deletions test/turing/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,19 @@
test_setval!(model, MCMCChains.get_sections(chain, :parameters))
end
end

@testset "value_iterator_from_chain" begin
@testset "$model" for model in DynamicPPL.TestUtils.DEMO_MODELS
chain = sample(model, Prior(), 10; progress=false)
for (i, d) in enumerate(value_iterator_from_chain(model, chain))
for vn in keys(d)
val = DynamicPPL.getvalue(d, vn)
for vn_leaf in DynamicPPL.varname_leaves(vn, val)
val_leaf = DynamicPPL.getvalue(d, vn_leaf)
@test val_leaf == chain[i, Symbol(vn_leaf), 1]
end
end
end
end
end
end