-
Notifications
You must be signed in to change notification settings - Fork 37
Added functionality for extracting parameter values for a model from chain #481
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ef0ac6f
8f311fe
979db2d
792d1bd
7511e4f
1a2dfc7
642be6c
49b6b93
24efbe3
2631287
ec3e65d
b629553
86d8fb1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,209 @@ | ||
| """ | ||
| varnames_in_chain(model:::Model, chain) | ||
| varnames_in_chain(varinfo::VarInfo, chain) | ||
|
|
||
| Return `true` if all variable names in `model`/`varinfo` are in `chain`. | ||
| """ | ||
| varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain) | ||
| function varnames_in_chain(varinfo::VarInfo, chain) | ||
| return all(vn -> varname_in_chain(varinfo, vn, chain, 1, 1), keys(varinfo)) | ||
| end | ||
|
|
||
| """ | ||
| varnames_in_chain!(model::Model, chain, out) | ||
| varnames_in_chain!(varinfo::VarInfo, chain, out) | ||
|
|
||
| Return `out` with `true` for all variable names in `model` that are in `chain`. | ||
| """ | ||
| function varnames_in_chain!(model::Model, chain, out) | ||
| return varnames_in_chain!(VarInfo(model), chain, out) | ||
| end | ||
| function varnames_in_chain!(varinfo::VarInfo, chain, out) | ||
| for vn in keys(varinfo) | ||
| varname_in_chain!(varinfo, vn, chain, 1, 1, out) | ||
| end | ||
|
|
||
| return out | ||
| end | ||
|
|
||
| """ | ||
| varname_in_chain(model::Model, vn, chain, chain_idx, iteration_idx) | ||
| varname_in_chain(varinfo::VarInfo, vn, chain, chain_idx, iteration_idx) | ||
|
|
||
| Return `true` if `vn` is in `chain` at `chain_idx` and `iteration_idx`. | ||
| """ | ||
| function varname_in_chain(model::Model, vn, chain, chain_idx, iteration_idx) | ||
| return varname_in_chain(VarInfo(model), vn, chain, chain_idx, iteration_idx) | ||
| end | ||
|
|
||
| function varname_in_chain(varinfo::AbstractVarInfo, vn, chain, chain_idx, iteration_idx) | ||
| !haskey(varinfo, vn) && return false | ||
| return varname_in_chain(varinfo[vn], vn, chain, chain_idx, iteration_idx) | ||
| end | ||
|
|
||
| function varname_in_chain(x, vn, chain, chain_idx, iteration_idx) | ||
| out = OrderedDict{VarName,Bool}() | ||
| varname_in_chain!(x, vn, chain, chain_idx, iteration_idx, out) | ||
| return all(values(out)) | ||
| end | ||
|
|
||
| """ | ||
| varname_in_chain!(model::Model, vn, chain, chain_idx, iteration_idx, out) | ||
| varname_in_chain!(varinfo::VarInfo, vn, chain, chain_idx, iteration_idx, out) | ||
|
|
||
| Return a dictionary mapping the varname `vn` to `true` if `vn` is in `chain` at | ||
| `chain_idx` and `iteration_idx`. | ||
|
|
||
| If `chain_idx` and `iteration_idx` are not provided, then they default to `1`. | ||
|
|
||
| This differs from [`varname_in_chain`](@ref) in that it returns a dictionary | ||
| rather than a single boolean. This can be quite useful for debugging purposes. | ||
| """ | ||
| function varname_in_chain!(model::Model, vn, chain, chain_idx, iteration_idx, out) | ||
| return varname_in_chain!(VarInfo(model), vn, chain, chain_idx, iteration_idx, out) | ||
yebai marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| end | ||
|
|
||
| function varname_in_chain!( | ||
| vi::AbstractVarInfo, vn_parent, chain, chain_idx, iteration_idx, out | ||
| ) | ||
| return varname_in_chain!(vi[vn_parent], vn_parent, chain, chain_idx, iteration_idx, out) | ||
| end | ||
|
|
||
| function varname_in_chain!(x, vn_parent, chain, chain_idx, iteration_idx, out) | ||
| sym = Symbol(vn_parent) | ||
| out[vn_parent] = sym ∈ names(chain) && !ismissing(chain[iteration_idx, sym, chain_idx]) | ||
| return out | ||
| end | ||
|
|
||
| function varname_in_chain!( | ||
| x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx, out | ||
| ) where {sym} | ||
| # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens. | ||
| # This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)` | ||
| # to extract the value from the `chain`. | ||
| for vn in varname_leaves(VarName{sym}(), x) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't know if my understanding is correct or not - if we have a chain with variable name it will generate the VarName |
||
| # Update `out`, possibly in place, and return. | ||
| l = AbstractPPL.getlens(vn) | ||
| varname_in_chain!(x, vn_parent ∘ l, chain, chain_idx, iteration_idx, out) | ||
| end | ||
| return out | ||
| end | ||
|
|
||
| """ | ||
| values_from_chain(model::Model, chain, chain_idx, iteration_idx) | ||
| values_from_chain(varinfo::VarInfo, chain, chain_idx, iteration_idx) | ||
|
|
||
| Return a dictionary mapping each variable name in `model`/`varinfo` to its | ||
| value in `chain` at `chain_idx` and `iteration_idx`. | ||
| """ | ||
| function values_from_chain(x, vn_parent, chain, chain_idx, iteration_idx) | ||
| # HACK: If it's not an array, we fall back to just returning the first value. | ||
| return only(chain[iteration_idx, Symbol(vn_parent), chain_idx]) | ||
| end | ||
| function values_from_chain( | ||
| x::AbstractArray, vn_parent::VarName{sym}, chain, chain_idx, iteration_idx | ||
| ) where {sym} | ||
| # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens. | ||
| # This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)` | ||
| # to extract the value from the `chain`. | ||
| out = similar(x) | ||
| for vn in varname_leaves(VarName{sym}(), x) | ||
| # Update `out`, possibly in place, and return. | ||
| l = AbstractPPL.getlens(vn) | ||
| out = Setfield.set( | ||
| out, | ||
| BangBang.prefermutation(l), | ||
| chain[iteration_idx, Symbol(vn_parent ∘ l), chain_idx], | ||
| ) | ||
| end | ||
|
|
||
| return out | ||
| end | ||
| function values_from_chain(vi::AbstractVarInfo, vn_parent, chain, chain_idx, iteration_idx) | ||
| # Use the value `vi[vn_parent]` to obtain a buffer. | ||
| return values_from_chain(vi[vn_parent], vn_parent, chain, chain_idx, iteration_idx) | ||
| end | ||
|
|
||
| """ | ||
| values_from_chain!(model::Model, chain, chain_idx, iteration_idx, out) | ||
| values_from_chain!(varinfo::VarInfo, chain, chain_idx, iteration_idx, out) | ||
|
|
||
| Mutate `out` to map each variable name in `model`/`varinfo` to its value in | ||
| `chain` at `chain_idx` and `iteration_idx`. | ||
| """ | ||
| function values_from_chain!(model::DynamicPPL.Model, chain, chain_idx, iteration_idx, out) | ||
| return values_from_chain(VarInfo(model), chain, chain_idx, iteration_idx, out) | ||
| end | ||
|
|
||
| function values_from_chain!(vi::AbstractVarInfo, chain, chain_idx, iteration_idx, out) | ||
| for vn in keys(vi) | ||
| out[vn] = values_from_chain(vi, vn, chain, chain_idx, iteration_idx) | ||
| end | ||
| return out | ||
| end | ||
|
|
||
| """ | ||
| value_iterator_from_chain(model::Model, chain) | ||
| value_iterator_from_chain(varinfo::AbstractVarInfo, chain) | ||
|
|
||
| Return an iterator over the values in `chain` for each variable in `model`/`varinfo`. | ||
|
|
||
| # Example | ||
| ```julia | ||
| julia> using MCMCChains, DynamicPPL, Distributions, StableRNGs | ||
|
|
||
| julia> rng = StableRNG(42); | ||
|
|
||
| julia> @model function demo_model(x) | ||
| s ~ InverseGamma(2, 3) | ||
| m ~ Normal(0, sqrt(s)) | ||
| for i in eachindex(x) | ||
| x[i] ~ Normal(m, sqrt(s)) | ||
| end | ||
|
|
||
| return s, m | ||
| end | ||
| demo_model (generic function with 2 methods) | ||
|
|
||
| julia> model = demo_model([1.0, 2.0]); | ||
|
|
||
| julia> chain = Chains(rand(rng, 10, 2, 3), [:s, :m]); | ||
|
|
||
| julia> iter = value_iterator_from_chain(model, chain); | ||
|
|
||
| julia> first(iter) | ||
| OrderedDict{VarName, Any} with 2 entries: | ||
| s => 0.580515 | ||
| m => 0.739328 | ||
|
|
||
| julia> collect(iter) | ||
| 10×3 Matrix{OrderedDict{VarName, Any}}: | ||
| OrderedDict(s=>0.580515, m=>0.739328) … OrderedDict(s=>0.186047, m=>0.402423) | ||
| OrderedDict(s=>0.191241, m=>0.627342) OrderedDict(s=>0.776277, m=>0.166342) | ||
| OrderedDict(s=>0.971133, m=>0.637584) OrderedDict(s=>0.651655, m=>0.712044) | ||
| OrderedDict(s=>0.74345, m=>0.110359) OrderedDict(s=>0.469214, m=>0.104502) | ||
| OrderedDict(s=>0.170969, m=>0.598514) OrderedDict(s=>0.853546, m=>0.185399) | ||
| OrderedDict(s=>0.704776, m=>0.322111) … OrderedDict(s=>0.638301, m=>0.853802) | ||
| OrderedDict(s=>0.441044, m=>0.162285) OrderedDict(s=>0.852959, m=>0.0956922) | ||
| OrderedDict(s=>0.803972, m=>0.643369) OrderedDict(s=>0.245049, m=>0.871985) | ||
| OrderedDict(s=>0.772384, m=>0.646323) OrderedDict(s=>0.906603, m=>0.385502) | ||
| OrderedDict(s=>0.70882, m=>0.253105) OrderedDict(s=>0.413222, m=>0.953288) | ||
|
|
||
| julia> # This can be used to `condition` a `Model`. | ||
| conditioned_model = model | first(iter); | ||
|
|
||
| julia> conditioned_model() # <= results in same values as the `first(iter)` above | ||
| (0.5805148626851955, 0.7393275279160691) | ||
| ``` | ||
| """ | ||
| function value_iterator_from_chain(model::DynamicPPL.Model, chain) | ||
| return value_iterator_from_chain(VarInfo(model), chain) | ||
| end | ||
|
|
||
| function value_iterator_from_chain(vi::AbstractVarInfo, chain) | ||
| return Iterators.map( | ||
| Iterators.product(1:size(chain, 1), 1:size(chain, 3)) | ||
| ) do (iteration_idx, chain_idx) | ||
| values_from_chain!(vi, chain, chain_idx, iteration_idx, OrderedDict{VarName,Any}()) | ||
| end | ||
| end | ||
Uh oh!
There was an error while loading. Please reload this page.