Skip to content

Conversation

@torfjelde
Copy link
Member

It's currently quite difficult to convert a MCMCChains.Chains object obtained from some DPPL model into something that we can then use in combination with model to do post-inference analysis, etc.

A very standard scenario is effectively the predict method currently present in Turing.jl, where we do some hacky stuff with strings and setting of values for a given varinfo. But now that we have condition and friends, a very natural approach would of course be to:

  1. Convert a Chains into an iterator over (Ordered)Dict corresponding to the parameter realizations.
  2. Condition on said (Orderd)Dict and sample from the predictive model.

This PR introduces functionality that allows us to do exactly this.

Specifically, it introduces a method values_iterator_from_chain(model, chain) which returns such an iterator for a given model.

NOTE: This does assume the model is static, i.e. we use a single VarInfo to initialize the process.

This is related to #438 as that PR becomes quite trivial with the values_iterator_from_chain introduced in this PR.

Also related: #478

@coveralls
Copy link

coveralls commented Jun 7, 2023

Pull Request Test Coverage Report for Build 5577657491

  • 0 of 56 (0.0%) changed or added relevant lines in 1 file are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage decreased (-1.6%) to 75.102%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/model_utils.jl 0 56 0.0%
Totals Coverage Status
Change from base Build 5564023611: -1.6%
Covered Lines: 2024
Relevant Lines: 2695

💛 - Coveralls

@codecov
Copy link

codecov bot commented Jun 7, 2023

Codecov Report

Patch coverage has no change and project coverage change: -1.60 ⚠️

Comparison is base (3f80199) 76.69% compared to head (86d8fb1) 75.10%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #481      +/-   ##
==========================================
- Coverage   76.69%   75.10%   -1.60%     
==========================================
  Files          22       23       +1     
  Lines        2639     2695      +56     
==========================================
  Hits         2024     2024              
- Misses        615      671      +56     
Impacted Files Coverage Δ
src/DynamicPPL.jl 100.00% <ø> (ø)
src/model_utils.jl 0.00% <0.00%> (ø)

☔ View full report in Codecov by Sentry.
📢 Do you have feedback about the report comment? Let us know in this issue.

@torfjelde
Copy link
Member Author

I believe tests should now be passing 👍 Should be ready for a look-over.

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@torfjelde
Copy link
Member Author

(I need to add tests for this PR)

return varname_in_chain!(vi[vn_parent], vn_parent, chain, out, chain_idx, iteration_idx)
end

function varname_in_chain!(x, vn_parent, chain, out, chain_idx, iteration_idx)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the 'x' as an input argument here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup! It's used below for example,, so we need to pass it to every one of these.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the clarification, yes just found it in varname_leaves.

@yebai yebai changed the base branch from master to yongchao/logp.jl June 24, 2023 11:05
@yebai yebai changed the base branch from yongchao/logp.jl to master June 24, 2023 11:06
"""
varnames_in_chain(model::Model, chain) = varnames_in_chain(VarInfo(model), chain)
function varnames_in_chain(varinfo::VarInfo, chain)
return all(vn -> varname_in_chain(varinfo, vn, chain), keys(varinfo))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

varname_in_chain(varinfo, vn, chain) doesn't exist; maybe using varname_in_chain(varinfo, vn, chain,1,1)?

for vn in varname_leaves(VarName{sym}(), x)
# Update `out`, possibly in place, and return.
l = AbstractPPL.getlens(vn)
varname_in_chain!(x, vn_parent l, chain, out, chain_idx, iteration_idx)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the function used here varname_in_chain!(x, vn_parent ∘ l, chain, out, chain_idx, iteration_idx) has the same signature as the function we are defining in Line 78 - will it lead to circular dead lop (I found this when testing it on my machine)?
Maybe revise both as follows?

function vn_in_chain(x, vn_parent, chain, out, chain_idx, iteration_idx)
    sym = Symbol(vn_parent)
    out[vn_parent] = sym ∈ names(chain) && !ismissing(chain[iteration_idx, sym, chain_idx])
    return out
end
function varname_in_chain!(
    x::AbstractArray, vn_parent::VarName{sym}, chain, out, chain_idx, iteration_idx
) where {sym}
    # We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens.
    # This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)`
    # to extract the value from the `chain`.
    for vn in varname_leaves(VarName{sym}(), x)
        # Update `out`, possibly in place, and return.
        l = AbstractPPL.getlens(vn)
        println(vn_parent ∘ l)
        print(typeof(vn_parent ∘ l))
        vn_in_chain(x, vn_parent ∘ l, chain, out, chain_idx, iteration_idx)
    end
    return out
end

# We use `VarName{sym}()` so that the resulting leaf `vn` only contains the tail of the lens.
# This way we can use `getlens(vn)` to extract the value from `x` and use `vn_parent ∘ getlens(vn)`
# to extract the value from the `chain`.
for vn in varname_leaves(VarName{sym}(), x)
Copy link
Contributor

@YongchaoHuang YongchaoHuang Jul 13, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't know if my understanding is correct or not - if we have a chain with variable name s in it, and if the VarName s is passed to the following:

for vn in varname_leaves(VarName(:s), [1])
    l = AbstractPPL.getlens(vn)
    println(VarName(:s) ∘ l)
end

it will generate the VarName s[1] which is obviously not present in the chain. This means for variables that take scalar value, the test will fail. I tested this and it did output s[1] => false; but what we want is s => true I guess.

yebai added a commit that referenced this pull request Jul 15, 2023
* suggested changes:
1. fixed some typos in `src/model_utils.jl` (e.g. missing !, inconsistent sequence of `out`, etc), removed unused `x` from argument.
2. wrote tests in `test/model_utils.jl`
3. included the test file in `runtests.jl`

* Update src/model_utils.jl

Co-authored-by: Hong Ge <[email protected]>

---------

Co-authored-by: Hong Ge <[email protected]>
@torfjelde torfjelde force-pushed the torfjelde/extract-model-values-from-chain branch from 00de8fc to 49b6b93 Compare July 17, 2023 11:47
torfjelde and others added 4 commits July 17, 2023 12:53
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
…rom-chain' into torfjelde/extract-model-values-from-chain
Copy link
Member

@yebai yebai left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me. Let's try to get this PR merged if there is nothing major left. It would pave the way to merge #438

@yebai yebai enabled auto-merge July 21, 2023 12:32
@yebai yebai added this pull request to the merge queue Jul 21, 2023
Merged via the queue into master with commit ba206f4 Jul 21, 2023
@yebai yebai deleted the torfjelde/extract-model-values-from-chain branch July 21, 2023 13:00
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants