-
Notifications
You must be signed in to change notification settings - Fork 37
condition and decondition
#278
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
cd1c46d
b877656
b1106ee
0f00771
c754291
2d3f94c
3e5a79f
d4e4238
85a47eb
c7dae8d
14f5f57
d3b6485
75680b3
1692c03
262c86b
f0ae744
22fcae8
b990bb0
8994cd7
8da41c8
94da453
4e74cf8
f9cdfa9
835a41e
5d110d5
560ca83
5635c3b
e78dc65
5f0e4a8
e1a7d38
4e566f7
5c1f18e
6196083
65048fc
649af29
5b26300
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ using Distributions | |
| using Bijectors | ||
|
|
||
| using AbstractMCMC: AbstractMCMC | ||
| using BangBang: BangBang | ||
| using ChainRulesCore: ChainRulesCore | ||
| using MacroTools: MacroTools | ||
| using ZygoteRules: ZygoteRules | ||
|
|
@@ -67,6 +68,7 @@ export AbstractVarInfo, | |
| vectorize, | ||
| # Model | ||
| Model, | ||
| ContextualModel, | ||
| getmissings, | ||
| getargnames, | ||
| generated_quantities, | ||
|
|
@@ -81,6 +83,7 @@ export AbstractVarInfo, | |
| PriorContext, | ||
| MiniBatchContext, | ||
| PrefixContext, | ||
| ConditionContext, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similarly, maybe keep this
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The reason why I exported it is because essentially all other contexts are being exported. It also makes |
||
| assume, | ||
| dot_assume, | ||
| observe, | ||
|
|
@@ -99,6 +102,8 @@ export AbstractVarInfo, | |
| logprior, | ||
| logjoint, | ||
| pointwise_loglikelihoods, | ||
| condition, | ||
| decondition, | ||
| # Convenience macros | ||
| @addlogprob!, | ||
| @submodel | ||
|
|
@@ -129,5 +134,6 @@ include("prob_macro.jl") | |
| include("compat/ad.jl") | ||
| include("loglikelihoods.jl") | ||
| include("submodel_macro.jl") | ||
| include("contextual_model.jl") | ||
|
|
||
| end # module | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,19 +20,60 @@ function isassumption(expr::Union{Symbol,Expr}) | |
|
|
||
| return quote | ||
| let $vn = $(varname(expr)) | ||
| # This branch should compile nicely in all cases except for partial missing data | ||
| # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` | ||
| if !$(DynamicPPL.inargnames)($vn, __model__) || | ||
| $(DynamicPPL.inmissings)($vn, __model__) | ||
| true | ||
| if $(DynamicPPL.contextual_isassumption)(__context__, $vn) | ||
| # Considered an assumption by `__context__` which means either: | ||
| # 1. We hit the default implementation, e.g. using `DefaultContext`, | ||
| # which in turn means that we haven't considered if it's one of | ||
| # the model arguments, hence we need to check this. | ||
| # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, | ||
| # i.e. we're trying to condition one of the latent variables. | ||
| # In this case, the below will return `true` since the first branch | ||
| # will be hit. | ||
| # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, | ||
| # i.e. we're trying to override the value. This is currently NOT supported. | ||
| # TODO: Support by adding context to model, and use `model.args` | ||
| # as the default conditioning. Then we no longer need to check `inargnames` | ||
| # since it will all be handled by `contextual_isassumption`. | ||
|
Comment on lines
+34
to
+36
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO this should be deferred for now. It will require a bit of thinking I believe. |
||
| if !($(DynamicPPL.inargnames)($vn, __model__)) | ||
| true | ||
| else | ||
| $expr === missing | ||
| end | ||
| else | ||
| # Evaluate the LHS | ||
| $(maybe_view(expr)) === missing | ||
| false | ||
| end | ||
| end | ||
| end | ||
| end | ||
|
|
||
| """ | ||
| contextual_isassumption(context, vn) | ||
|
|
||
| Return `true` if `vn` is considered an assumption by `context`. | ||
|
|
||
| The default implementation for `AbstractContext` always returns `true`. | ||
| """ | ||
| contextual_isassumption(context::AbstractContext, vn) = true | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is too general, e.g. we want this for |
||
| function contextual_isassumption(context::ConditionContext, vn) | ||
| # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`. | ||
|
|
||
| # We have either of the following cases: | ||
| # 1. `context` considers `vn` as an observation, i.e. it has `vn` as a key, | ||
| # which means we have a value to replace with and we don't need to recurse. | ||
| # 2. One of the decendant contexts consider it as an observation, i.e. | ||
| # `contextual_isassumption` evaluates to `false`. | ||
| # The below then evaluates to `!(false || true) === false`. | ||
| # 3. Neither `context` nor any of it's decendants considers it an observation, | ||
| # in which case the below evaluates to `!(false || false) === true`. | ||
| return !(haskey(context, vn) || !contextual_isassumption(context.context, vn)) | ||
| end | ||
| function contextual_isassumption(context::PrefixContext, vn) | ||
| return contextual_isassumption(context.context, prefix(context, vn)) | ||
| end | ||
| function contextual_isassumption(context::MiniBatchContext, vn) | ||
| return contextual_isassumption(context.context, vn) | ||
| end | ||
|
|
||
| # failsafe: a literal is never an assumption | ||
| isassumption(expr) = :(false) | ||
|
|
||
|
|
@@ -336,6 +377,11 @@ function generate_tilde(left, right) | |
| __varinfo__, | ||
| ) | ||
| else | ||
| # If `vn` is not in `argnames`, we need to make sure that the variable is defined. | ||
| if !$(DynamicPPL.inargnames)($vn, __model__) | ||
| $left = $(DynamicPPL.getvalue)(__context__, $vn) | ||
| end | ||
|
Comment on lines
+380
to
+383
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will compile away when possible, just as isassumption, and will preserve existing functionality/meaning of model arguments hence will be non-breaking. |
||
|
|
||
| $(DynamicPPL.tilde_observe!)( | ||
| __context__, | ||
| $(DynamicPPL.check_tilde_rhs)($right), | ||
|
|
@@ -380,6 +426,11 @@ function generate_dot_tilde(left, right) | |
| __varinfo__, | ||
| ) | ||
| else | ||
| # If `vn` is not in `argnames`, we need to make sure that the variable is defined. | ||
| if !$(DynamicPPL.inargnames)($vn, __model__) | ||
| $left .= $(DynamicPPL.getvalue)(__context__, $vn) | ||
| end | ||
|
|
||
| $(DynamicPPL.dot_tilde_observe!)( | ||
| __context__, | ||
| $(DynamicPPL.check_tilde_rhs)($right), | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -106,3 +106,73 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} | |
| VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) | ||
| end | ||
| end | ||
|
|
||
| struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext | ||
| values::Values | ||
| context::Ctx | ||
phipsgabler marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| function ConditionContext{Values}( | ||
| values::Values, context::AbstractContext | ||
| ) where {names,Values<:NamedTuple{names}} | ||
| return new{names,typeof(values),typeof(context)}(values, context) | ||
| end | ||
| end | ||
|
|
||
| function ConditionContext(context::ConditionContext, child_context::AbstractContext) | ||
phipsgabler marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return ConditionContext(context.values, child_context) | ||
| end | ||
| function ConditionContext(values::NamedTuple) | ||
| return ConditionContext(values, DefaultContext()) | ||
| end | ||
| function ConditionContext(values::NamedTuple, context::AbstractContext) | ||
| return ConditionContext{typeof(values)}(values, context) | ||
| end | ||
|
|
||
| # Try to avoid nested `ConditionContext`. | ||
| function ConditionContext( | ||
| values::NamedTuple{Names}, context::ConditionContext | ||
| ) where {Names} | ||
| # Note that this potentially overrides values from `context`, thus giving | ||
| # precedence to the outmost `ConditionContext`. | ||
| return ConditionContext(merge(context.values, values), context.context) | ||
| end | ||
|
|
||
| function Base.show(io::IO, context::ConditionContext) | ||
| return print(io, "ConditionContext($(context.values), $(context.context))") | ||
| end | ||
|
|
||
| function getvalue(context::ConditionContext, vn) | ||
| return if haskey(context, vn) | ||
| _getvalue(context.values, vn) | ||
| else | ||
| getvalue(context.context, vn) | ||
| end | ||
| end | ||
|
Comment on lines
+144
to
+150
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This currently behaves a bit strangely in certain cases: julia> @model function outer_with_inner_condition(x)
y = Vector(undef, length(x))
for i in eachindex(x)
# Here we'll end up with `PrefixContext{..., <:ConditionContext}`
y[i] = @submodel $(Symbol("y[$i]")) condition(inner(x[i]), m=10.0)
end
return y
end
outer_with_inner_condition (generic function with 1 method)
julia> m = outer_with_inner_condition([1.0, 2.0])
Model{var"#49#50", (:x,), (), (), Tuple{Vector{Float64}}, Tuple{}}(:outer_with_inner_condition, var"#49#50"(), (x = [1.0, 2.0],), NamedTuple())
julia> m() # (✓) both `m` are set as specified within `outer_with_inner_condition`
2-element Vector{Any}:
(m = 10.0, x = 1.0)
(m = 10.0, x = 2.0)
julia> condition(m, var"y[1].m"=5.0)() # (×) try to override one of the inner `m`
2-element Vector{Any}:
(m = 10.0, x = 1.0)
(m = 10.0, x = 2.0)But we can change function getvalue(context::ConditionContext, vn)
# Return early if we've already found our value, thus giving precedence
# to the inner-most `ConditionContext`.
maybeval = getvalue(context.context, vn)
maybeval === nothing || return maybeval
return haskey(context, vn) ? _getvalue(context.values, vn) : nothing
endfor which we then get the "desired" behavior in the above example: julia> condition(m, var"y[1].m"=5.0)() # (✓) try to override one of the inner `m`
2-element Vector{Any}:
(m = 5.0, x = 1.0)
(m = 10.0, x = 2.0)but this does mean that we give precedence to the inner-most We could potentially introduce functionality that will recurse into the context-tree to remove all other mentionings of the variables, i.e. at most one |
||
| getvalue(context::AbstractContext, vn) = getvalue(context.context, vn) | ||
| getvalue(context::PrefixContext, vn) = getvalue(context.context, prefix(context, vn)) | ||
|
|
||
| function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} | ||
| # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. | ||
| return sym in vars | ||
| end | ||
|
|
||
| function Base.haskey( | ||
| context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}} | ||
| ) where {vars,sym} | ||
| # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. | ||
| return sym in vars | ||
| end | ||
|
|
||
| # TODO: Can we maybe do this in a better way? | ||
| # When no second argument is given, we remove _all_ conditioned variables. | ||
| # TODO: Should we remove this and just return `context.context`? | ||
| # That will work better if `Model` becomes like `ContextualModel`. | ||
|
Comment on lines
+166
to
+169
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can be addressed nicely with #286 |
||
| decondition(context::ConditionContext) = ConditionContext(NamedTuple(), context.context) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe rename this to
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO the different types for the arguments already disambiguiates the two 😕 |
||
| function decondition(context::ConditionContext, sym) | ||
| return ConditionContext(BangBang.delete!!(context.values, sym), context.context) | ||
| end | ||
| function decondition(context::ConditionContext, sym, syms...) | ||
| return decondition( | ||
| ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... | ||
| ) | ||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| struct ContextualModel{Ctx<:AbstractContext,M<:Model} <: AbstractModel | ||
| context::Ctx | ||
| model::M | ||
| end | ||
|
|
||
| function contextualize(model::AbstractModel, context::AbstractContext) | ||
| return ContextualModel(context, model) | ||
| end | ||
|
|
||
| # TODO: What do we do for other contexts? Could handle this in general if we had a | ||
| # notion of wrapper-, primitive-context, etc. | ||
| function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) | ||
| # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as | ||
| # `ConditionContext` child. | ||
| return _evaluate( | ||
| cmodel.model, varinfo, ConditionContext(cmodel.context.values, context) | ||
| ) | ||
| end | ||
|
|
||
| condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) | ||
| condition(model::AbstractModel; values...) = condition(model, (; values...)) | ||
| function condition(cmodel::ContextualModel{<:ConditionContext}, values) | ||
| return contextualize(cmodel.model, ConditionContext(values, cmodel.context)) | ||
| end | ||
|
|
||
| decondition(model::AbstractModel, args...) = model | ||
| function decondition(cmodel::ContextualModel{<:ConditionContext}, syms...) | ||
| return contextualize(cmodel.model, decondition(cmodel.context, syms...)) | ||
| end |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,5 @@ | ||
| abstract type AbstractModel <: AbstractProbabilisticProgram end | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we set up a type alias instead of creating a new level of the hierarchy?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Unfortunately not because then methods such as function (model::AbstractModel)(args...)
return model(Random.GLOBAL_RNG, args...)
endbecome type-piracy. |
||
|
|
||
| """ | ||
| struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} | ||
| name::Symbol | ||
|
|
@@ -32,8 +34,7 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition | |
| Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) | ||
| ``` | ||
| """ | ||
| struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: | ||
| AbstractProbabilisticProgram | ||
| struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel | ||
| name::Symbol | ||
| f::F | ||
| args::NamedTuple{argnames,Targs} | ||
|
|
@@ -82,7 +83,7 @@ Sample from the `model` using the `sampler` with random number generator `rng` a | |
| The method resets the log joint probability of `varinfo` and increases the evaluation | ||
| number of `sampler`. | ||
| """ | ||
| function (model::Model)( | ||
| function (model::AbstractModel)( | ||
| rng::Random.AbstractRNG, | ||
| varinfo::AbstractVarInfo=VarInfo(), | ||
| sampler::AbstractSampler=SampleFromPrior(), | ||
|
|
@@ -91,26 +92,26 @@ function (model::Model)( | |
| return model(varinfo, SamplingContext(rng, sampler, context)) | ||
| end | ||
|
|
||
| (model::Model)(context::AbstractContext) = model(VarInfo(), context) | ||
| function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) | ||
| (model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context) | ||
| function (model::AbstractModel)(varinfo::AbstractVarInfo, context::AbstractContext) | ||
| if Threads.nthreads() == 1 | ||
| return evaluate_threadunsafe(model, varinfo, context) | ||
| else | ||
| return evaluate_threadsafe(model, varinfo, context) | ||
| end | ||
| end | ||
|
|
||
| function (model::Model)(args...) | ||
| function (model::AbstractModel)(args...) | ||
| return model(Random.GLOBAL_RNG, args...) | ||
| end | ||
|
|
||
| # without VarInfo | ||
| function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) | ||
| function (model::AbstractModel)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) | ||
| return model(rng, VarInfo(), sampler, args...) | ||
| end | ||
|
|
||
| # without VarInfo and without AbstractSampler | ||
| function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) | ||
| function (model::AbstractModel)(rng::Random.AbstractRNG, context::AbstractContext) | ||
| return model(rng, VarInfo(), SampleFromPrior(), context) | ||
| end | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe keep this
internaluntil we have a good case of exporting it.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think any case that can be made for exporting
Modelnow also can be made forContextualModel, don't you think?