Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
cd1c46d
added ConditionContext and ContextualModel
torfjelde Jul 15, 2021
b877656
Merge branch 'master' into tor/conditioning
torfjelde Jul 15, 2021
b1106ee
removed redundant definition
torfjelde Jul 15, 2021
0f00771
return condition model by default
torfjelde Jul 15, 2021
c754291
formatting
torfjelde Jul 15, 2021
2d3f94c
forgot to include contextual model in previous commit
torfjelde Jul 15, 2021
3e5a79f
fixed typos
torfjelde Jul 15, 2021
d4e4238
added some niceties to ConditionContext
torfjelde Jul 15, 2021
85a47eb
added support for vectors of VarName
torfjelde Jul 16, 2021
c7dae8d
Update src/contexts.jl
torfjelde Jul 16, 2021
14f5f57
Merge branch 'master' into tor/conditioning
yebai Jul 16, 2021
d3b6485
Merge branch 'master' into tor/conditioning
yebai Jul 18, 2021
75680b3
upper-bound Distributions.jl in tests
torfjelde Jul 19, 2021
1692c03
Merge branch 'master' into tor/conditioning
torfjelde Jul 19, 2021
262c86b
Merge branch 'tor/conditioning' of github.com:TuringLang/DynamicPPL.j…
torfjelde Jul 19, 2021
f0ae744
make the isassumption check using context extensible and nicer
torfjelde Jul 19, 2021
22fcae8
Merge branch 'tor/upper-bound-distributions' into tor/conditioning
torfjelde Jul 19, 2021
b990bb0
renamed type-parameter for ConditionContext
torfjelde Jul 19, 2021
8994cd7
Update src/contexts.jl
torfjelde Jul 19, 2021
8da41c8
Merge branch 'master' into tor/conditioning
torfjelde Jul 21, 2021
94da453
introduced convenient _getvalue method
torfjelde Jul 21, 2021
4e74cf8
overload tilde_assume rather than tilde_assume! and others for Condit…
torfjelde Jul 21, 2021
f9cdfa9
added contextual_isassumption for PrefixContext
torfjelde Jul 21, 2021
835a41e
implemented contextual_isassumption for all contexts
torfjelde Jul 21, 2021
5d110d5
improved the way ConditionContext works signficantly
torfjelde Jul 21, 2021
560ca83
forgot impl of dot_tilde_observe for ConditionContext
torfjelde Jul 21, 2021
5635c3b
Apply suggestions from code review
torfjelde Jul 21, 2021
e78dc65
overload _evaluate rather than the model-call directly
torfjelde Jul 22, 2021
5f0e4a8
Merge branch 'tor/conditioning' of github.com:TuringLang/DynamicPPL.j…
torfjelde Jul 22, 2021
e1a7d38
formatting
torfjelde Jul 24, 2021
4e566f7
remove the drop_missing as it is no longer needed
torfjelde Jul 24, 2021
5c1f18e
Merge branch 'master' into tor/conditioning
torfjelde Jul 24, 2021
6196083
made show a bit nicer for ConditionContext
torfjelde Jul 24, 2021
65048fc
use print instead of println in show
torfjelde Jul 24, 2021
649af29
formatting
torfjelde Jul 24, 2021
5b26300
Merge branch 'master' into tor/conditioning
torfjelde Jul 28, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.13.0"
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
6 changes: 6 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Distributions
using Bijectors

using AbstractMCMC: AbstractMCMC
using BangBang: BangBang
using ChainRulesCore: ChainRulesCore
using MacroTools: MacroTools
using ZygoteRules: ZygoteRules
Expand Down Expand Up @@ -67,6 +68,7 @@ export AbstractVarInfo,
vectorize,
# Model
Model,
ContextualModel,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe keep this internal until we have a good case of exporting it.

Copy link
Member Author

@torfjelde torfjelde Jul 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think any case that can be made for exporting Model now also can be made for ContextualModel, don't you think?

getmissings,
getargnames,
generated_quantities,
Expand All @@ -81,6 +83,7 @@ export AbstractVarInfo,
PriorContext,
MiniBatchContext,
PrefixContext,
ConditionContext,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, maybe keep this internal until we have a good case of exporting it.

Copy link
Member Author

@torfjelde torfjelde Jul 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason why I exported it is because essentially all other contexts are being exported. It also makes show nicer, tbh. But I'm also fine with not exporting, up to you 👍

assume,
dot_assume,
observe,
Expand All @@ -99,6 +102,8 @@ export AbstractVarInfo,
logprior,
logjoint,
pointwise_loglikelihoods,
condition,
decondition,
# Convenience macros
@addlogprob!,
@submodel
Expand Down Expand Up @@ -129,5 +134,6 @@ include("prob_macro.jl")
include("compat/ad.jl")
include("loglikelihoods.jl")
include("submodel_macro.jl")
include("contextual_model.jl")

end # module
65 changes: 58 additions & 7 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,60 @@ function isassumption(expr::Union{Symbol,Expr})

return quote
let $vn = $(varname(expr))
# This branch should compile nicely in all cases except for partial missing data
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
if !$(DynamicPPL.inargnames)($vn, __model__) ||
$(DynamicPPL.inmissings)($vn, __model__)
true
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
# Considered an assumption by `__context__` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
# the model arguments, hence we need to check this.
# 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments,
# i.e. we're trying to condition one of the latent variables.
# In this case, the below will return `true` since the first branch
# will be hit.
# 3. We are working with a `ConditionContext` _and_ it's in the model arguments,
# i.e. we're trying to override the value. This is currently NOT supported.
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
Comment on lines +34 to +36
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this should be deferred for now. It will require a bit of thinking I believe.

if !($(DynamicPPL.inargnames)($vn, __model__))
true
else
$expr === missing
end
else
# Evaluate the LHS
$(maybe_view(expr)) === missing
false
end
end
end
end

"""
contextual_isassumption(context, vn)

Return `true` if `vn` is considered an assumption by `context`.

The default implementation for `AbstractContext` always returns `true`.
"""
contextual_isassumption(context::AbstractContext, vn) = true
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is too general, e.g. we want this for DefaultContext and other "leaf" contexts but not for "parent" contexts, e.g. MiniBatchContext.

function contextual_isassumption(context::ConditionContext, vn)
# We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`.

# We have either of the following cases:
# 1. `context` considers `vn` as an observation, i.e. it has `vn` as a key,
# which means we have a value to replace with and we don't need to recurse.
# 2. One of the decendant contexts consider it as an observation, i.e.
# `contextual_isassumption` evaluates to `false`.
# The below then evaluates to `!(false || true) === false`.
# 3. Neither `context` nor any of it's decendants considers it an observation,
# in which case the below evaluates to `!(false || false) === true`.
return !(haskey(context, vn) || !contextual_isassumption(context.context, vn))
end
function contextual_isassumption(context::PrefixContext, vn)
return contextual_isassumption(context.context, prefix(context, vn))
end
function contextual_isassumption(context::MiniBatchContext, vn)
return contextual_isassumption(context.context, vn)
end

# failsafe: a literal is never an assumption
isassumption(expr) = :(false)

Expand Down Expand Up @@ -336,6 +377,11 @@ function generate_tilde(left, right)
__varinfo__,
)
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left = $(DynamicPPL.getvalue)(__context__, $vn)
end
Comment on lines +380 to +383
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will compile away when possible, just as isassumption, and will preserve existing functionality/meaning of model arguments hence will be non-breaking.


$(DynamicPPL.tilde_observe!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
Expand Down Expand Up @@ -380,6 +426,11 @@ function generate_dot_tilde(left, right)
__varinfo__,
)
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left .= $(DynamicPPL.getvalue)(__context__, $vn)
end

$(DynamicPPL.dot_tilde_observe!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
Expand Down
44 changes: 43 additions & 1 deletion src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,17 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
require_gradient(spl::Sampler) = false
require_particles(spl::Sampler) = false

_getindex(x, inds::Tuple) = _getindex(view(x, first(inds)...), Base.tail(inds))
_getindex(x, inds::Tuple) = _getindex(Base.maybeview(x, first(inds)...), Base.tail(inds))
_getindex(x, inds::Tuple{}) = x
_getvalue(x, vn::VarName{sym}) where {sym} = _getindex(getproperty(x, sym), vn.indexing)
function _getvalue(x, vns::AbstractVector{<:VarName{sym}}) where {sym}
val = getproperty(x, sym)

# This should work with both cartesian and linear indexing.
return map(vns) do vn
_getindex(val, vn)
end
end

# assume
"""
Expand Down Expand Up @@ -118,6 +127,15 @@ function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi)
return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi)
end

# `ConditionContext`
function tilde_assume(context::ConditionContext, right, vn, inds, vi)
return tilde_assume(context.context, right, vn, inds, vi)
end

function tilde_assume(rng, context::ConditionContext, sampler, right, vn, inds, vi)
return tilde_assume(rng, context.context, sampler, right, vn, inds, vi)
end

"""
tilde_assume!(context, right, vn, inds, vi)

Expand Down Expand Up @@ -189,6 +207,14 @@ function tilde_observe(context::PrefixContext, right, left, vi)
return tilde_observe(context.context, right, left, vi)
end

# `ConditionContext`
function tilde_observe(context::ConditionContext, right, left, vname, vi)
return tilde_observe(context.context, right, left, vname, vi)
end
function tilde_observe(context::ConditionContext, right, left, vi)
return tilde_observe(context.context, right, left, vi)
end

"""
tilde_observe!(context, right, left, vname, vinds, vi)

Expand Down Expand Up @@ -402,6 +428,17 @@ function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn,
)
end

# `ConditionContext`
function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi)
return dot_tilde_assume(context.context, right, left, vn, inds, vi)
end

function dot_tilde_assume(
rng, context::ConditionContext, sampler, right, left, vn, inds, vi
)
return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi)
end

"""
dot_tilde_assume!(context, right, left, vn, inds, vi)

Expand Down Expand Up @@ -609,6 +646,11 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi)
return dot_tilde_observe(context.context, right, left, vi)
end

# `ConditionContext`
function dot_tilde_observe(context::ConditionContext, right, left, vi)
return dot_tilde_observe(context.context, right, left, vi)
end

"""
dot_tilde_observe!(context, right, left, vname, vinds, vi)

Expand Down
70 changes: 70 additions & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,73 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing)
end
end

struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext
values::Values
context::Ctx

function ConditionContext{Values}(
values::Values, context::AbstractContext
) where {names,Values<:NamedTuple{names}}
return new{names,typeof(values),typeof(context)}(values, context)
end
end

function ConditionContext(context::ConditionContext, child_context::AbstractContext)
return ConditionContext(context.values, child_context)
end
function ConditionContext(values::NamedTuple)
return ConditionContext(values, DefaultContext())
end
function ConditionContext(values::NamedTuple, context::AbstractContext)
return ConditionContext{typeof(values)}(values, context)
end

# Try to avoid nested `ConditionContext`.
function ConditionContext(
values::NamedTuple{Names}, context::ConditionContext
) where {Names}
# Note that this potentially overrides values from `context`, thus giving
# precedence to the outmost `ConditionContext`.
return ConditionContext(merge(context.values, values), context.context)
end

function Base.show(io::IO, context::ConditionContext)
return print(io, "ConditionContext($(context.values), $(context.context))")
end

function getvalue(context::ConditionContext, vn)
return if haskey(context, vn)
_getvalue(context.values, vn)
else
getvalue(context.context, vn)
end
end
Comment on lines +144 to +150
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This currently behaves a bit strangely in certain cases:

julia> @model function outer_with_inner_condition(x)
           y = Vector(undef, length(x))
           for i in eachindex(x)
               # Here we'll end up with `PrefixContext{..., <:ConditionContext}`
               y[i] = @submodel $(Symbol("y[$i]")) condition(inner(x[i]), m=10.0)
           end
           return y
       end
outer_with_inner_condition (generic function with 1 method)

julia> m = outer_with_inner_condition([1.0, 2.0])
Model{var"#49#50", (:x,), (), (), Tuple{Vector{Float64}}, Tuple{}}(:outer_with_inner_condition, var"#49#50"(), (x = [1.0, 2.0],), NamedTuple())

julia> m() # (✓) both `m` are set as specified within `outer_with_inner_condition`
2-element Vector{Any}:
 (m = 10.0, x = 1.0)
 (m = 10.0, x = 2.0)

julia> condition(m, var"y[1].m"=5.0)() # (×) try to override one of the inner `m`
2-element Vector{Any}:
 (m = 10.0, x = 1.0)
 (m = 10.0, x = 2.0)

But we can change getvalue to:

function getvalue(context::ConditionContext, vn)
    # Return early if we've already found our value, thus giving precedence
    # to the inner-most `ConditionContext`.
    maybeval = getvalue(context.context, vn)
    maybeval === nothing || return maybeval

    return haskey(context, vn) ? _getvalue(context.values, vn) : nothing
end

for which we then get the "desired" behavior in the above example:

julia> condition(m, var"y[1].m"=5.0)() # (✓) try to override one of the inner `m`
2-element Vector{Any}:
 (m = 5.0, x = 1.0)
 (m = 10.0, x = 2.0)

but this does mean that we give precedence to the inner-most ConditionContext rather than the outermost, i.e. the one that was applied most recently, which is counter-intuitive and can be difficult to work with.

We could potentially introduce functionality that will recurse into the context-tree to remove all other mentionings of the variables, i.e. at most one CouplingContext in the context-tree contains a mentioning of any specific variable. cthis also makes sense when considering decondition(ctx, :x) since we'd expect this to traverse the entire context-tree and remove every mentioning of x rather than only doing so in the outer-most ConditionContext. Buuuut this brings us back to #254 again, since it would require functionality for such a tree-traversal and other things.

getvalue(context::AbstractContext, vn) = getvalue(context.context, vn)
getvalue(context::PrefixContext, vn) = getvalue(context.context, prefix(context, vn))

function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym}
# TODO: Add possibility of indexed variables, e.g. `x[1]`, etc.
return sym in vars
end

function Base.haskey(
context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}}
) where {vars,sym}
# TODO: Add possibility of indexed variables, e.g. `x[1]`, etc.
return sym in vars
end

# TODO: Can we maybe do this in a better way?
# When no second argument is given, we remove _all_ conditioned variables.
# TODO: Should we remove this and just return `context.context`?
# That will work better if `Model` becomes like `ContextualModel`.
Comment on lines +166 to +169
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be addressed nicely with #286

decondition(context::ConditionContext) = ConditionContext(NamedTuple(), context.context)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe rename this to _decondition to avoid confusion with decondition(m, ...)?

@phipsgabler

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO the different types for the arguments already disambiguiates the two 😕

function decondition(context::ConditionContext, sym)
return ConditionContext(BangBang.delete!!(context.values, sym), context.context)
end
function decondition(context::ConditionContext, sym, syms...)
return decondition(
ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms...
)
end
29 changes: 29 additions & 0 deletions src/contextual_model.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
struct ContextualModel{Ctx<:AbstractContext,M<:Model} <: AbstractModel
context::Ctx
model::M
end

function contextualize(model::AbstractModel, context::AbstractContext)
return ContextualModel(context, model)
end

# TODO: What do we do for other contexts? Could handle this in general if we had a
# notion of wrapper-, primitive-context, etc.
function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context)
# Wrap `context` in the model-associated `ConditionContext`, but now using `context` as
# `ConditionContext` child.
return _evaluate(
cmodel.model, varinfo, ConditionContext(cmodel.context.values, context)
)
end

condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values))
condition(model::AbstractModel; values...) = condition(model, (; values...))
function condition(cmodel::ContextualModel{<:ConditionContext}, values)
return contextualize(cmodel.model, ConditionContext(values, cmodel.context))
end

decondition(model::AbstractModel, args...) = model
function decondition(cmodel::ContextualModel{<:ConditionContext}, syms...)
return contextualize(cmodel.model, decondition(cmodel.context, syms...))
end
8 changes: 8 additions & 0 deletions src/loglikelihoods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ function PointwiseLikelihoodContext(
)
end

function contextual_isassumption(context::PointwiseLikelihoodContext, vn)
return contextual_isassumption(context.context, vn)
end

function contextual_isobservation(context::PointwiseLikelihoodContext, vn)
return contextual_isobservation(context.context, vn)
end

function Base.push!(
context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}},
vn::VarName,
Expand Down
17 changes: 9 additions & 8 deletions src/model.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
abstract type AbstractModel <: AbstractProbabilisticProgram end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we set up a type alias instead of creating a new level of the hierarchy?

Copy link
Member Author

@torfjelde torfjelde Jul 19, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately not because then methods such as

function (model::AbstractModel)(args...)
    return model(Random.GLOBAL_RNG, args...)
end

become type-piracy.


"""
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults}
name::Symbol
Expand Down Expand Up @@ -32,8 +34,7 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition
Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
```
"""
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <:
AbstractProbabilisticProgram
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel
name::Symbol
f::F
args::NamedTuple{argnames,Targs}
Expand Down Expand Up @@ -82,7 +83,7 @@ Sample from the `model` using the `sampler` with random number generator `rng` a
The method resets the log joint probability of `varinfo` and increases the evaluation
number of `sampler`.
"""
function (model::Model)(
function (model::AbstractModel)(
rng::Random.AbstractRNG,
varinfo::AbstractVarInfo=VarInfo(),
sampler::AbstractSampler=SampleFromPrior(),
Expand All @@ -91,26 +92,26 @@ function (model::Model)(
return model(varinfo, SamplingContext(rng, sampler, context))
end

(model::Model)(context::AbstractContext) = model(VarInfo(), context)
function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext)
(model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context)
function (model::AbstractModel)(varinfo::AbstractVarInfo, context::AbstractContext)
if Threads.nthreads() == 1
return evaluate_threadunsafe(model, varinfo, context)
else
return evaluate_threadsafe(model, varinfo, context)
end
end

function (model::Model)(args...)
function (model::AbstractModel)(args...)
return model(Random.GLOBAL_RNG, args...)
end

# without VarInfo
function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...)
function (model::AbstractModel)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...)
return model(rng, VarInfo(), sampler, args...)
end

# without VarInfo and without AbstractSampler
function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext)
function (model::AbstractModel)(rng::Random.AbstractRNG, context::AbstractContext)
return model(rng, VarInfo(), SampleFromPrior(), context)
end

Expand Down
Loading