[Merged by Bors] - Method to extract loglikelihoods #166
Conversation
Nice! Some quick high-level comments:
It's also a bit unfortunate that the approach will be broken with the new interface (which we can switch to soon, hopefully).
How come?
Actually, I think it might be possible to simplify this a bit: since you don't accumulate log probabilities (and hence there is no benefit from using ThreadSafeVarInfo) and don't need to keep track of [lines 87 to 98 in 405546f], you could use `_evaluate` instead (as in line 115 in 405546f).
Agreed. Suggestions?
I 100% agree with this, as I don't like the "implicit" dependencies on MCMCChains.jl due to how we use …
If we wouldn't enforce …
The sampler setup will change and be immutable by design, with a separate state, so it will require changes to the current implementation (but it will still be possible). Making it a context would avoid these issues, I guess.
Yep, that's the important point IMO. Maybe we could add a way to build a "dummy" VarInfo object from a set of samples?
I'm not using a sampler state here though, right?
Not in the "Turing" sense, but the array of likelihoods is like an internal state that is mutated when running the model, isn't it?
I'm confused. In the new interface, isn't that just a change to …
Actually, I'm not certain I quite get how I'd use a …
Probably most parts will be fine, just some of the additional methods might have to be redefined. But what's your opinion about just using a special context here? I guess it could even be interesting in combination with other samplers than …
Sorry, I didn't properly respond to this; I think it's a great idea! Do you know off the top of your head what I'll need to overload here? I'll have a look myself too, ofc.
Probably the context itself would be defined similar to [lines 31 to 34 in 405546f], with

```julia
function tilde_assume(rng, ctx::TrackedLikelihoodsContext, sampler, right, vn, inds, vi)
    return tilde_assume(rng, ctx.context, sampler, right, vn, inds, vi)
end
```

and for the observations

```julia
function tilde_observe(ctx::TrackedLikelihoodsContext, sampler, right, left, vname, vinds, vi)
    # this is slightly unfortunate since it is not completely generic...
    # ideally we would call `tilde_observe` recursively but then we don't get the loglikelihood value
    logp = tilde(ctx.context, sampler, right, left, vi)
    acclogp!(vi, logp)
    # track loglikelihood value
    ...
    return left
end

function tilde_observe(ctx, sampler, right, left, vi)
    # same here
    ...
end
```

and similarly for …
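For concreteness, the observe overload with the tracking step filled in might look roughly like this (a sketch only, mirroring the diff reviewed further below; `loglikelihoods` is assumed to be a `Dict{String, Vector{Float64}}`):

```julia
function tilde_observe(ctx::TrackedLikelihoodsContext, sampler, right, left, vname, vinds, vi)
    # compute and accumulate the log probability as usual
    logp = tilde(ctx.context, sampler, right, left, vi)
    acclogp!(vi, logp)
    # track the loglikelihood value under the variable's name,
    # appending one entry per model evaluation
    ℓ = get!(ctx.loglikelihoods, string(vname), Float64[])
    push!(ℓ, logp)
    return left
end
```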
That looks great; I'll change it at some point today (gotta get some dinner now though!). Thanks, man :)
Implemented
I'm very much for this. But I'd prefer to merge this PR without it, if that's okay. Then we raise an issue in DynamicPPL/MCMCChains (whichever is most suitable) and figure out where to go from there?
Another comment on this: could do something similar to what we do in …
Btw, is this ok to merge?
Hmm that's a bit unfortunate but I'm wondering if we should just consider these statements as a single observation for now, since there are ambiguities and design choices to make anyway. On the other hand, it's a bit unfortunate that it leads to a discrepancy between for-loop and dotted implementations - but the same problem arises if one uses, e.g., MvNormal for a set of observations.
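To make the discrepancy concrete, here is a hypothetical pair of models (illustrative only, not from the PR): the for-loop version would yield one tracked loglikelihood per `xs[i]`, while the dotted version would be treated as a single observation.

```julia
using Turing

@model function demo_loop(xs)
    m ~ Normal()
    for i in eachindex(xs)
        xs[i] ~ Normal(m, 1.0)  # one loglikelihood entry per xs[i]
    end
end

@model function demo_dot(xs)
    m ~ Normal()
    xs .~ Normal(m, 1.0)  # a single loglikelihood entry for all of xs
end
```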
Yes, better to keep this PR focused on the loglikelihoods.
Maybe:

```julia
function varinfo_from_chain(
    model::Turing.Model,
    chain::MCMCChains.Chains;
    sampler = DynamicPPL.SampleFromPrior()
)
    vi = Turing.VarInfo(model)
    vis = map(1:length(chain)) do i
        c = chain[i]
        md = vi.metadata
        for v in keys(md)
            for vn in md[v].vns
                vn_sym = Symbol(vn)
                # Cannot use `vn_sym` to index in the chain
                # so we have to extract the corresponding "linear"
                # indices and use those.
                # `ks` is empty if `vn_sym` not in `c`.
                ks = MCMCChains.namesingroup(c, vn_sym)
                if !isempty(ks)
                    # 1st dimension is of size 1 since `c`
                    # only contains a single sample, and the
                    # last dimension is of size 1 since
                    # we're assuming we're working with a single chain.
                    val = copy(vec(c[ks].value))
                    DynamicPPL.setval!(vi, val, vn)
                    DynamicPPL.settrans!(vi, false, vn)
                else
                    DynamicPPL.set_flag!(vi, vn, "del")
                end
            end
        end
        new_vi = VarInfo(vi, sampler, vi[sampler])
        setlogp!(new_vi, first(chain[i][:lp])) # Is there a better way?
        return new_vi
    end
    return vis
end
```

and reconstruction (at least in the context of …) would be

```julia
vis = varinfo_from_chain(model, chain)
chain_new = AbstractMCMC.bundle_samples(
    rng, model, SampleFromPrior(), length(vis), vis, MCMCChains.Chains
)
```

Doesn't require executing the model or anything.
src/loglikelihoods.jl
Outdated
```julia
struct TrackedLikelihoodContext{A, Ctx, Tvars} <: AbstractContext
    loglikelihoods::A
    ctx::Ctx
    vars::Tvars
end
```
It seems we don't use them currently - so either check for them when tracking the loglikelihoods (preferably IMO) or just remove this field?
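For illustration, such a check inside `tilde_observe` might look like the following fragment (a hypothetical sketch, not the committed code; it assumes `vars === nothing` means "track everything" and that `vars` is otherwise a collection of variable names):

```julia
# only track the loglikelihood if no subset of variables was requested,
# or if this variable is among the requested ones (hypothetical check)
if ctx.vars === nothing || vname in ctx.vars
    ℓ = get!(ctx.loglikelihoods, string(vname), Float64[])
    push!(ℓ, logp)
end
```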
Do you know what the field is supposed to do?
Co-authored-by: David Widmann <[email protected]>
src/loglikelihoods.jl
Outdated
```julia
# Do the usual thing
logp = tilde(ctx.ctx, sampler, right, left, vi)
acclogp!(vi, logp)

# track loglikelihood value
lookup = ctx.loglikelihoods
ℓ = get!(lookup, string(vname), Float64[])
push!(ℓ, logp)

return left
```
Is it possible to merge this with the implementation above?
I just removed it because it doesn't have `vname`, i.e. the method wasn't even runnable. When is this actually called? Seems like it never should be?
```julia
# Update the values
setval!(vi, chain, sample_idx, chain_idx)
```
So here `empty!` is not needed?
Actually, `empty!` will ruin it! I did that initially, but `empty!` also made it so that values would be resampled. So it ended up sampling from the prior instead.
Hmm, but why would it resample values in this case? Shouldn't `setval!` fix them? There's something going on with this `empty!`/`setval!` thing that I don't understand 🤔
Ah, okay, then there's something weird. I thought I had just misunderstood something, but if you also don't know why that's the case then there's something going on 😅
Can it be the fact that `empty!` clears the "del" flag + `setval!` does NOT set it to false? So then, when you run the model again, it will resample?
Yeah, something is wrong:

```julia
julia> using DynamicPPL, Turing

julia> import DynamicPPL: setval!

julia> @model function demo(xs)
           m ~ MvNormal(2, 1.)
           for i in eachindex(xs)
               xs[i] ~ Normal(m[1], 1.)
           end
       end
demo (generic function with 1 method)

julia> model = demo(randn(3));

julia> chain = sample(model, MH(), 10);

julia> var_info = VarInfo(model);

julia> θ_old = var_info[SampleFromPrior()]
2-element Array{Float64,1}:
 0.41438744434831887
 0.373145757716783

julia> θ_chain = vec(MCMCChains.group(chain, :m)[1, :, 1].value)
2-element reshape(::AxisArrays.AxisArray{Float64,3,Array{Float64,3},Tuple{AxisArrays.Axis{:iter,StepRange{Int64,Int64}},AxisArrays.Axis{:var,Array{Symbol,1}},AxisArrays.Axis{:chain,UnitRange{Int64}}}}, 2) with eltype Float64:
  0.3939341450590006
 -1.1020030439893758

julia> empty!(var_info)
VarInfo{NamedTuple{(:m,),Tuple{DynamicPPL.Metadata{Dict{VarName{:m,Tuple{}},Int64},Array{MvNormal{Float64,PDMats.ScalMat{Float64},FillArrays.Zeros{Float64,1,Tuple{Base.OneTo{Int64}}}},1},Array{VarName{:m,Tuple{}},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}((m = DynamicPPL.Metadata{Dict{VarName{:m,Tuple{}},Int64},Array{MvNormal{Float64,PDMats.ScalMat{Float64},FillArrays.Zeros{Float64,1,Tuple{Base.OneTo{Int64}}}},1},Array{VarName{:m,Tuple{}},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}(Dict{VarName{:m,Tuple{}},Int64}(), VarName{:m,Tuple{}}[], UnitRange{Int64}[], Float64[], MvNormal{Float64,PDMats.ScalMat{Float64},FillArrays.Zeros{Float64,1,Tuple{Base.OneTo{Int64}}}}[], Set{DynamicPPL.Selector}[], Int64[], Dict{String,BitArray{1}}("del" => [],"trans" => [])),), Base.RefValue{Float64}(0.0), Base.RefValue{Int64}(0))

julia> setval!(var_info, chain, 1, 1)
VarInfo{NamedTuple{(:m,),Tuple{DynamicPPL.Metadata{Dict{VarName{:m,Tuple{}},Int64},Array{MvNormal{Float64,PDMats.ScalMat{Float64},FillArrays.Zeros{Float64,1,Tuple{Base.OneTo{Int64}}}},1},Array{VarName{:m,Tuple{}},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}((m = DynamicPPL.Metadata{Dict{VarName{:m,Tuple{}},Int64},Array{MvNormal{Float64,PDMats.ScalMat{Float64},FillArrays.Zeros{Float64,1,Tuple{Base.OneTo{Int64}}}},1},Array{VarName{:m,Tuple{}},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}(Dict{VarName{:m,Tuple{}},Int64}(), VarName{:m,Tuple{}}[], UnitRange{Int64}[], Float64[], MvNormal{Float64,PDMats.ScalMat{Float64},FillArrays.Zeros{Float64,1,Tuple{Base.OneTo{Int64}}}}[], Set{DynamicPPL.Selector}[], Int64[], Dict{String,BitArray{1}}("del" => [],"trans" => [])),), Base.RefValue{Float64}(0.0), Base.RefValue{Int64}(0))

julia> θ_new = var_info[SampleFromPrior()]
Float64[]

julia> model(var_info)

julia> var_info[SampleFromPrior()]
2-element Array{Float64,1}:
 -0.7971105060749638
  0.9046609763240063
```
Figured it out: `empty!` clears `vi.metadata.$n.vns` for every `$n` in `names`, so the following (lines 1167 to 1171 in 334cb98) never touches `_setval_kernel!`:

```julia
quote
    for vn in metadata.$n.vns
        _setval_kernel!(vi, vn, values, keys)
    end
end
```
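For reference, the pattern that does work (and what this PR relies on) is to skip `empty!` and simply overwrite the values; a minimal sketch, reusing the names and imports from the snippet above:

```julia
vi = VarInfo(model)  # running the model once populates the metadata

for sample_idx in 1:length(chain)
    # overwrite the values in place; calling `empty!` first would clear
    # `vi.metadata.$n.vns`, so `setval!` would become a no-op and the
    # next model run would resample from the prior
    setval!(vi, chain, sample_idx, 1)
    model(vi)  # re-execute the model with the values fixed
end
```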
Ready for merge?
devmotion left a comment:
I'm happy with it, just some final comments:
- If you increase the version number, we can make a new release immediately (given that bors doesn't complain).
- Models with `.~` are not supported since only `dot_tilde_assume` is implemented. I added a comment on how we could implement `dot_tilde_observe` for now, but it's fine with me if we postpone this to a separate PR and do not support such models for the time being; see the sketch after this list.
- We should improve the `setval!` stuff, maybe copy your debugging steps to a separate issue? (We should also get the `generated_quantities` draft in and probably just use `empty!` there as long as it is not fixed...)
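For the record, a `dot_tilde_observe` overload would presumably mirror the `tilde_observe` method in the diff above; a rough sketch (the signature and the `dot_tilde` call are assumed by analogy, not the committed code):

```julia
function dot_tilde_observe(ctx::TrackedLikelihoodContext, sampler, right, left, vn, inds, vi)
    # for now, treat the whole `.~` statement as a single observation
    logp = dot_tilde(ctx.ctx, sampler, right, left, vi)
    acclogp!(vi, logp)
    # track one loglikelihood value for the entire dotted statement
    ℓ = get!(ctx.loglikelihoods, string(vn), Float64[])
    push!(ℓ, logp)
    return left
end
```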
Will do!
Ah, sorry, forgot to respond to that! So you said:
> At the moment, I make it clear in the docstring that stuff like …

Agree! I actually went back to check the …
I just remembered that we were quite confused about the behaviour of the different suggestions. Basically, I meant that even though it is still unclear to me why …
Just to make sure we're on the same page: I'm also worried that it shouldn't be used in …
I think we need to fix #167 before merging this; this should also fail similarly on models using …
Waiting to merge until #168 has been merged.
bors r+
For several reasons, it would be very nice to have a way of extracting the log-likelihoods from a chain. This PR implements the method `loglikelihoods` to do exactly this.
# Up for discussion
1. **Return value.** Right now it returns a `Dict{String, Vector{Float64}}` with the keys being `string(varname)` and the values being an array whose i-th index corresponds to the log-likelihood for `string(varname)` in `chain[i]`. Alternatives:
   - Dict of the form `Dict(y => Dict(y[1] => ..., y[2] => ...), ...)`, i.e. "hierarchical"
   - Dict of the form `Dict(y[1] => ..., y[2] => ..., ...)`, i.e. "flattened"
   - ????
2. **Project structure.** I'm a bit uncertain where to actually put the implementation. As I've now experienced, what you actually need to implement to make an `AbstractSampler` is a bit unclear, e.g. there are some methods in `varinfo.jl` which also require implementation (e.g. `getindex`). So, should I put it in its own file, like I have now, or should I follow suit with `SampleFromPrior` and `SampleFromUniform`?
# Example
```julia
julia> using DynamicPPL, Turing

julia> @model function demo(xs, y)
           s ~ InverseGamma(2, 3)
           m ~ Normal(0, √s)
           for i in eachindex(xs)
               xs[i] ~ Normal(m, √s)
           end
           y ~ Normal(m, √s)
       end
demo (generic function with 1 method)

julia> model = demo(randn(3), randn());

julia> chain = sample(model, MH(), 10);

julia> DynamicPPL.loglikelihoods(model, chain)
Dict{String,Array{Float64,1}} with 4 entries:
"xs[3]" => [-1.02616, -1.26931, -1.05003, -5.05458, -1.33825, -1.02904, -1.23761, -1.30128, -1.04872, -2.03716]
"xs[1]" => [-2.08205, -2.51387, -3.03175, -2.5981, -2.31322, -2.62284, -2.70874, -1.18617, -1.36281, -4.39839]
"xs[2]" => [-2.20604, -2.63495, -3.22802, -2.48785, -2.40941, -2.78791, -2.85013, -1.24081, -1.46019, -4.59025]
"y" => [-1.36627, -1.21964, -1.03342, -7.46617, -1.3234, -1.14536, -1.14781, -2.48912, -2.23705, -1.26267]
```
Pull request successfully merged into master. Build succeeded.