-
Notifications
You must be signed in to change notification settings - Fork 230
ESS produces the wrong result for certain models #1633
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
48b8463
ca81eb0
df8bb42
9898865
7351562
20267ce
4a931c1
48030eb
92cabdf
966b724
2cc253b
2b5c5e1
cab751a
20daa3e
169a014
64a816a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -112,7 +112,9 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior) | |
| sampler = p.sampler | ||
| varinfo = p.varinfo | ||
| vns = _getvns(varinfo, sampler) | ||
| set_flag!(varinfo, vns[1][1], "del") | ||
| for vn in Iterators.flatten(values(vns)) | ||
| set_flag!(varinfo, vn, "del") | ||
| end | ||
| p.model(rng, varinfo, sampler) | ||
| return varinfo[sampler] | ||
| end | ||
|
|
@@ -155,6 +157,6 @@ function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, | |
| end | ||
| end | ||
|
|
||
| function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) | ||
| return DynamicPPL.dot_tilde(rng, ctx, SampleFromPrior(), right, left, vi) | ||
| function DynamicPPL.dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) | ||
| return DynamicPPL.dot_tilde(ctx, SampleFromPrior(), right, left, vi) | ||
|
Comment on lines
+160
to
+161
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Was this a bug, or is it needed because of recent changes in DynamicPPL?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This was a bug. This is never actually hit, and given the arguments I'm assuming it was intended to be a |
||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,14 +34,15 @@ function Bijectors.bijector( | |
| end | ||
|
|
||
| bs = Bijectors.bijector.(tuple(dists...)) | ||
| rs = tuple(ranges...) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I guess this is not related to ESS?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Correct; this is just to ensure that we have the same behavior as we had before the most recent release of Bijectors.jl. I just noticed it when trying to figure out why the tests weren't passing. |
||
|
|
||
| if sym2ranges | ||
| return ( | ||
| Bijectors.Stacked(bs, ranges), | ||
| Bijectors.Stacked(bs, rs), | ||
| (; collect(zip(keys(sym_lookup), values(sym_lookup)))...), | ||
| ) | ||
| else | ||
| return Bijectors.Stacked(bs, ranges) | ||
| return Bijectors.Stacked(bs, rs) | ||
| end | ||
| end | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -51,3 +51,94 @@ MoGtest_default = MoGtest([1.0 1.0 4.0 4.0]) | |
|
|
||
| # Declare empty model to make the Sampler constructor work. | ||
| @model empty_model() = begin x = 1; end | ||
|
|
||
| # A collection of models for which the mean-of-means for the posterior should | ||
| # be same. | ||
|
Comment on lines
+55
to
+56
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sorry, what exactly do you mean with
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I want to have a collection of models which tries out all the combinations of — buuuuut now that we're comparing to the true mean rather than pitting the different models against each other, I guess we don't need to do that 😅 |
||
| @model function gdemo1(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} | ||
| # `dot_assume` and `observe` | ||
| m = TV(undef, length(x)) | ||
| m .~ Normal() | ||
| x ~ MvNormal(m, 0.5 * ones(length(x))) | ||
| end | ||
|
|
||
| @model function gdemo2(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} | ||
| # `assume` with indexing and `observe` | ||
| m = TV(undef, length(x)) | ||
| for i in eachindex(m) | ||
| m[i] ~ Normal() | ||
| end | ||
| x ~ MvNormal(m, 0.5 * ones(length(x))) | ||
| end | ||
|
|
||
| @model function gdemo3(x = 10 * ones(2)) | ||
| # Multivariate `assume` and `observe` | ||
| m ~ MvNormal(length(x), 1.0) | ||
| x ~ MvNormal(m, 0.5 * ones(length(x))) | ||
| end | ||
|
|
||
| @model function gdemo4(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} | ||
| # `dot_assume` and `observe` with indexing | ||
| m = TV(undef, length(x)) | ||
| m .~ Normal() | ||
| for i in eachindex(x) | ||
| x[i] ~ Normal(m[i], 0.5) | ||
| end | ||
| end | ||
|
|
||
| # Using vector of `length` 1 here so the posterior of `m` is the same | ||
| # as the others. | ||
| @model function gdemo5(x = 10 * ones(1)) | ||
| # `assume` and `dot_observe` | ||
| m ~ Normal() | ||
| x .~ Normal(m, 0.5) | ||
| end | ||
|
|
||
| @model function gdemo6() | ||
| # `assume` and literal `observe` | ||
| m ~ MvNormal(2, 1.0) | ||
| [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) | ||
| end | ||
|
|
||
| @model function gdemo7(::Type{TV} = Vector{Float64}) where {TV} | ||
| # `dot_assume` and literal `observe` with indexing | ||
| m = TV(undef, 2) | ||
| m .~ Normal() | ||
| for i in eachindex(m) | ||
| 10.0 ~ Normal(m[i], 0.5) | ||
| end | ||
| end | ||
|
|
||
| @model function gdemo8() | ||
| # `assume` and literal `dot_observe` | ||
| m ~ Normal() | ||
| [10.0, ] .~ Normal(m, 0.5) | ||
| end | ||
|
|
||
| @model function _prior_dot_assume(::Type{TV} = Vector{Float64}) where {TV} | ||
| m = TV(undef, 2) | ||
| m .~ Normal() | ||
|
|
||
| return m | ||
| end | ||
|
|
||
| @model function gdemo9() | ||
| # Submodel prior | ||
| m = @submodel _prior_dot_assume() | ||
| for i in eachindex(m) | ||
| 10.0 ~ Normal(m[i], 0.5) | ||
| end | ||
| end | ||
|
|
||
| @model function _likelihood_dot_observe(m, x) | ||
| x ~ MvNormal(m, 0.5 * ones(length(m))) | ||
| end | ||
|
|
||
| @model function gdemo10(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} | ||
| m = TV(undef, length(x)) | ||
| m .~ Normal() | ||
|
|
||
| # Submodel likelihood | ||
| @submodel _likelihood_dot_observe(m, x) | ||
| end | ||
|
|
||
| const gdemo_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo6(), gdemo7(), gdemo8(), gdemo9(), gdemo10()) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should add a context that does this 😛 At least it would be more convenient than dealing with DynamicPPL internals here. I remember that I was very confused and uncertain if I did it correctly when I implemented this. It seemed to work 😬
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It works sometimes because in `assume` we only check if the flag is set for `vns[1]`, but that is only for "sub-symbols". This is the line I'm referring to: https://github.com/TuringLang/DynamicPPL.jl/blob/9083299db3f623136895cae80ef5f10d7fcf8d2c/src/context_implementations.jl#L268. But this won't work if we have more than one key in `vi.metadata`, or if `tilde_assume` and the others are called with a subsumed varname, e.g. `m[2]`. And this won't be an issue once we have a clear separation between sampling and evaluation. These sorts of bugs show up soooo often (and I agree, it's super-confusing), so looking forward to not having those :)