-
Notifications
You must be signed in to change notification settings - Fork 37
condition and decondition using traits from #286 without ContextualModel
#294
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cd1c46d
b877656
b1106ee
0f00771
c754291
2d3f94c
3e5a79f
d4e4238
85a47eb
c7dae8d
14f5f57
d3b6485
75680b3
1692c03
262c86b
f0ae744
22fcae8
b990bb0
8994cd7
8da41c8
94da453
4e74cf8
f9cdfa9
835a41e
5d110d5
560ca83
5635c3b
e78dc65
5f0e4a8
3a408cf
1d3b11e
e1a7d38
b42c34f
c7c60e6
be67807
0468297
80e3d5f
9419e76
4e566f7
d035c23
5c1f18e
d6cd4ff
b27228a
4935d5c
6196083
65048fc
48dda72
649af29
ffdee05
4010ab8
5b26300
bf39000
798798b
93184cc
21c08e5
f52ccdd
35fad36
9297e61
65f8094
b201ef3
9ebdd0e
274ad23
a02f7b8
cf9d168
589507f
8cc3193
f3698cf
26edb2c
eb31a4a
be61ef1
1a40c9c
bdf3fb4
ffe896a
e2f6fc5
987e4e4
a361a5e
ce160d6
41cc734
b492041
7ae9e3e
3b31d35
cefb443
6086734
e579020
2af625c
1d3fa2b
da34dff
5e15b25
386e985
52c9f04
65e7f71
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,19 +20,66 @@ function isassumption(expr::Union{Symbol,Expr}) | |
|
|
||
| return quote | ||
| let $vn = $(varname(expr)) | ||
| # This branch should compile nicely in all cases except for partial missing data | ||
| # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` | ||
| if !$(DynamicPPL.inargnames)($vn, __model__) || | ||
| $(DynamicPPL.inmissings)($vn, __model__) | ||
| true | ||
| if $(DynamicPPL.contextual_isassumption)(__context__, $vn) | ||
| # Considered an assumption by `__context__` which means either: | ||
| # 1. We hit the default implementation, e.g. using `DefaultContext`, | ||
| # which in turn means that we haven't considered if it's one of | ||
| # the model arguments, hence we need to check this. | ||
| # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, | ||
| # i.e. we're trying to condition one of the latent variables. | ||
| # In this case, the below will return `true` since the first branch | ||
| # will be hit. | ||
| # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, | ||
| # i.e. we're trying to override the value. This is currently NOT supported. | ||
| # TODO: Support by adding context to model, and use `model.args` | ||
| # as the default conditioning. Then we no longer need to check `inargnames` | ||
| # since it will all be handled by `contextual_isassumption`. | ||
| if !($(DynamicPPL.inargnames)($vn, __model__)) || | ||
| $(DynamicPPL.inmissings)($vn, __model__) | ||
| true | ||
| else | ||
| $(maybe_view(expr)) === missing | ||
| end | ||
| else | ||
| # Evaluate the LHS | ||
| $(maybe_view(expr)) === missing | ||
| false | ||
| end | ||
| end | ||
| end | ||
| end | ||
|
|
||
| """ | ||
| contextual_isassumption(context, vn) | ||
|
|
||
| Return `true` if `vn` is considered an assumption by `context`. | ||
|
|
||
| The default implementation for `AbstractContext` always returns `true`. | ||
| """ | ||
| contextual_isassumption(::IsLeaf, context, vn) = true | ||
| function contextual_isassumption(::IsParent, context, vn) | ||
| return contextual_isassumption(childcontext(context), vn) | ||
| end | ||
| function contextual_isassumption(context::AbstractContext, vn) | ||
| return contextual_isassumption(NodeTrait(context), context, vn) | ||
| end | ||
| function contextual_isassumption(context::ConditionContext, vn) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The current implementation of
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm a bit uncertain what you mean here 😕 This pattern we're using all over the place. Traits represents a default implementation, but we we still allow direct overloads, e.g. here Notice how even the trait-impls above call the impl without the trait argument, thus allowing these specific overloads.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not quite sure - mixing traits and types kind of makes the code hard to read. But I'm happy to leave this as a future effort, given that the functionality is correct.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure 👍 |
||
| if hasvalue(context, vn) | ||
| val = getvalue(context, vn) | ||
| # TODO: Do we even need the `>: Missing` to help the compiler? | ||
| if eltype(val) >: Missing && val === missing | ||
| return true | ||
| else | ||
| return false | ||
| end | ||
| end | ||
|
|
||
| # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}` | ||
| # so we defer to `childcontext` if we haven't concluded that anything yet. | ||
| return contextual_isassumption(childcontext(context), vn) | ||
| end | ||
| function contextual_isassumption(context::PrefixContext, vn) | ||
| return contextual_isassumption(childcontext(context), prefix(context, vn)) | ||
| end | ||
|
|
||
| # failsafe: a literal is never an assumption | ||
| isassumption(expr) = :(false) | ||
|
|
||
|
|
@@ -351,6 +398,11 @@ function generate_tilde(left, right) | |
| __varinfo__, | ||
| ) | ||
| else | ||
| # If `vn` is not in `argnames`, we need to make sure that the variable is defined. | ||
| if !$(DynamicPPL.inargnames)($vn, __model__) | ||
| $left = $(DynamicPPL.getvalue_nested)(__context__, $vn) | ||
| end | ||
|
|
||
| $(DynamicPPL.tilde_observe!)( | ||
| __context__, | ||
| $(DynamicPPL.check_tilde_rhs)($right), | ||
|
|
@@ -395,6 +447,11 @@ function generate_dot_tilde(left, right) | |
| __varinfo__, | ||
| ) | ||
| else | ||
| # If `vn` is not in `argnames`, we need to make sure that the variable is defined. | ||
| if !$(DynamicPPL.inargnames)($vn, __model__) | ||
| $left .= $(DynamicPPL.getvalue_nested)(__context__, $vn) | ||
| end | ||
|
|
||
| $(DynamicPPL.dot_tilde_observe!)( | ||
| __context__, | ||
| $(DynamicPPL.check_tilde_rhs)($right), | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,8 +14,17 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) | |
| require_gradient(spl::Sampler) = false | ||
| require_particles(spl::Sampler) = false | ||
|
|
||
| _getindex(x, inds::Tuple) = _getindex(view(x, first(inds)...), Base.tail(inds)) | ||
| _getindex(x, inds::Tuple) = _getindex(Base.maybeview(x, first(inds)...), Base.tail(inds)) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe add some doc on what are the possible
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. All types are allowed with |
||
| _getindex(x, inds::Tuple{}) = x | ||
| _getvalue(x, vn::VarName{sym}) where {sym} = _getindex(getproperty(x, sym), vn.indexing) | ||
| function _getvalue(x, vns::AbstractVector{<:VarName{sym}}) where {sym} | ||
| val = getproperty(x, sym) | ||
|
|
||
| # This should work with both cartesian and linear indexing. | ||
| return map(vns) do vn | ||
| _getindex(val, vn) | ||
| end | ||
| end | ||
|
|
||
| # assume | ||
| """ | ||
|
|
@@ -162,6 +171,8 @@ function tilde_observe(context::SamplingContext, right, left, vi) | |
| end | ||
|
|
||
| # Leaf contexts | ||
| # TODO: Should we maybe not do `args...` here but instead be explicit? | ||
| # Could help avoid stealthy bugs. | ||
| function tilde_observe(context::AbstractContext, args...) | ||
| return tilde_observe(NodeTrait(tilde_observe, context), context, args...) | ||
| end | ||
|
|
@@ -177,13 +188,14 @@ tilde_observe(::PriorContext, sampler, right, left, vi) = 0 | |
| function tilde_observe(context::MiniBatchContext, right, left, vi) | ||
| return context.loglike_scalar * tilde_observe(context.context, right, left, vi) | ||
| end | ||
| function tilde_observe(context::MiniBatchContext, right, left, vname, vi) | ||
| return context.loglike_scalar * tilde_observe(context.context, right, left, vname, vi) | ||
| function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems that we are missing the argument
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So this is actually because we never use This is also why this passed through the tests; we have absolutely no tests for the vname-included versions (and some of contexts are missing impls for this even on master). |
||
| return context.loglike_scalar * | ||
| tilde_observe(context.context, sampler, right, left, vname, vi) | ||
| end | ||
|
|
||
| # `PrefixContext` | ||
| function tilde_observe(context::PrefixContext, right, left, vname, vi) | ||
| return tilde_observe(context.context, right, left, prefix(context, vname), vi) | ||
| function tilde_observe(context::PrefixContext, right, left, vi) | ||
| return tilde_observe(context.context, right, left, vi) | ||
| end | ||
|
|
||
| """ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm confused by what will happen in the 3rd situation: will the current implementation error/warn, or it will fail/override silently?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, no this just means that in this scenario
ConditionContextwon't have any effect, i.e.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool, let us add this to the docstring since it makes things clearer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will do 👍
I'll also add tests now that we've landed on taking this approach.