-
Notifications
You must be signed in to change notification settings - Fork 37
Added functionality for extracting parameter values for a model from chain #481
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…m:TuringLang/DynamicPPL.jl into torfjelde/extract-model-values-from-chain
Pull Request Test Coverage Report for Build 5577657491
💛 - Coveralls |
Codecov ReportPatch coverage has no change and project coverage change:
Additional details and impacted files@@ Coverage Diff @@
## master #481 +/- ##
==========================================
- Coverage 76.69% 75.10% -1.60%
==========================================
Files 22 23 +1
Lines 2639 2695 +56
==========================================
Hits 2024 2024
- Misses 615 671 +56
☔ View full report in Codecov by Sentry. |
|
I believe tests should now be passing 👍 Should be ready for a look-over. |
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
|
(I need to add tests for this PR) |
src/model_utils.jl
Outdated
| return varname_in_chain!(vi[vn_parent], vn_parent, chain, out, chain_idx, iteration_idx) | ||
| end | ||
|
|
||
| function varname_in_chain!(x, vn_parent, chain, out, chain_idx, iteration_idx) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need the 'x' as an input argument here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup! It's used below for example,, so we need to pass it to every one of these.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for the clarification, yes just found it in varname_leaves.
src/model_utils.jl
Outdated
| """ | ||
| varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain) | ||
| function varnames_in_chain(varinfo::VarInfo, chain) | ||
| return all(vn -> varname_in_chain(varinfo, vn, chain), keys(varinfo)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
varname_in_chain(varinfo, vn, chain) doesn't exist; maybe using varname_in_chain(varinfo, vn, chain,1,1)?
src/model_utils.jl
Outdated
| for vn in varname_leaves(VarName{sym}(), x) | ||
| # Update `out`, possibly in place, and return. | ||
| l = AbstractPPL.getlens(vn) | ||
| varname_in_chain!(x, vn_parent ∘ l, chain, out, chain_idx, iteration_idx) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the function used here varname_in_chain!(x, vn_parent ∘ l, chain, out, chain_idx, iteration_idx) has the same signature as the function we are defining in Line 78 - will it lead to circular dead lop (I found this when testing it on my machine)?
Maybe revise both as follows?
function vn_in_chain(x, vn_parent, chain, out, chain_idx, iteration_idx)
sym = Symbol(vn_parent)
out[vn_parent] = sym ∈ names(chain) && !ismissing(chain[iteration_idx, sym, chain_idx])
return out
end
function varname_in_chain!(
x::AbstractArray, vn_parent::VarName{sym}, chain, out, chain_idx, iteration_idx
) where {sym}
# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens.
# This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)`
# to extract the value from the `chain`.
for vn in varname_leaves(VarName{sym}(), x)
# Update `out`, possibly in place, and return.
l = AbstractPPL.getlens(vn)
println(vn_parent ∘ l)
print(typeof(vn_parent ∘ l))
vn_in_chain(x, vn_parent ∘ l, chain, out, chain_idx, iteration_idx)
end
return out
end
| # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens. | ||
| # This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)` | ||
| # to extract the value from the `chain`. | ||
| for vn in varname_leaves(VarName{sym}(), x) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't know if my understanding is correct or not - if we have a chain with variable name s in it, and if the VarName s is passed to the following:
for vn in varname_leaves(VarName(:s), [1])
l = AbstractPPL.getlens(vn)
println(VarName(:s) ∘ l)
end
it will generate the VarName s[1] which is obviously not present in the chain. This means for variables that take scalar value, the test will fail. I tested this and it did output s[1] => false; but what we want is s => true I guess.
* suggested changes: 1. fixed some typos in `src/model_utils.jl` (e.g. missing !, inconsistent sequence of `out`, etc), removed unused `x` from argument. 2. wrote tests in `test/model_utils.jl` 3. included the test file in `runtests.jl` * Update src/model_utils.jl Co-authored-by: Hong Ge <[email protected]> --------- Co-authored-by: Hong Ge <[email protected]>
00de8fc to
49b6b93
Compare
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…rom-chain' into torfjelde/extract-model-values-from-chain
yebai
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good to me. Let's try to get this PR merged if there is nothing major left. It would pave the way to merge #438
It's currently quite difficult to convert a
MCMCChains.Chainsobject obtained from some DPPLmodelinto something that we can then use in combination withmodelto do post-inference analysis, etc.A very standard scenario is effectively the
predictmethod currently present in Turing.jl, where we do some hacky stuff with strings and setting of values for a givenvarinfo. But now that we haveconditionand friends, a very natural approach would of course be to:Chainsinto an iterator over(Ordered)Dictcorresponding to the parameter realizations.(Orderd)Dictand sample from the predictive model.This PR introduces functionality that allows us to do exactly this.
Specifically, it introduces a method
values_iterator_from_chain(model, chain)which returns such an iterator for a givenmodel.NOTE: This does assume the
modelis static, i.e. we use a singleVarInfoto initialize the process.This is related to #438 as that PR becomes quite trivial with the
values_iterator_from_chainintroduced in this PR.Also related: #478