Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
cd1c46d
added ConditionContext and ContextualModel
torfjelde Jul 15, 2021
b877656
Merge branch 'master' into tor/conditioning
torfjelde Jul 15, 2021
b1106ee
removed redundant definition
torfjelde Jul 15, 2021
0f00771
return condition model by default
torfjelde Jul 15, 2021
c754291
formatting
torfjelde Jul 15, 2021
2d3f94c
forgot to include contextual model in previous commit
torfjelde Jul 15, 2021
3e5a79f
fixed typos
torfjelde Jul 15, 2021
d4e4238
added some niceties to ConditionContext
torfjelde Jul 15, 2021
85a47eb
added support for vectors of VarName
torfjelde Jul 16, 2021
c7dae8d
Update src/contexts.jl
torfjelde Jul 16, 2021
14f5f57
Merge branch 'master' into tor/conditioning
yebai Jul 16, 2021
d3b6485
Merge branch 'master' into tor/conditioning
yebai Jul 18, 2021
75680b3
upper-bound Distributions.jl in tests
torfjelde Jul 19, 2021
1692c03
Merge branch 'master' into tor/conditioning
torfjelde Jul 19, 2021
262c86b
Merge branch 'tor/conditioning' of github.com:TuringLang/DynamicPPL.j…
torfjelde Jul 19, 2021
f0ae744
make the isassumption check using context extensible and nicer
torfjelde Jul 19, 2021
22fcae8
Merge branch 'tor/upper-bound-distributions' into tor/conditioning
torfjelde Jul 19, 2021
b990bb0
renamed type-parameter for ConditionContext
torfjelde Jul 19, 2021
8994cd7
Update src/contexts.jl
torfjelde Jul 19, 2021
8da41c8
Merge branch 'master' into tor/conditioning
torfjelde Jul 21, 2021
94da453
introduced convenient _getvalue method
torfjelde Jul 21, 2021
4e74cf8
overload tilde_assume rather than tilde_assume! and others for Condit…
torfjelde Jul 21, 2021
f9cdfa9
added contextual_isassumption for PrefixContext
torfjelde Jul 21, 2021
835a41e
implemented contextual_isassumption for all contexts
torfjelde Jul 21, 2021
5d110d5
improved the way ConditionContext works signficantly
torfjelde Jul 21, 2021
560ca83
forgot impl of dot_tilde_observe for ConditionContext
torfjelde Jul 21, 2021
5635c3b
Apply suggestions from code review
torfjelde Jul 21, 2021
e78dc65
overload _evaluate rather than the model-call directly
torfjelde Jul 22, 2021
5f0e4a8
Merge branch 'tor/conditioning' of github.com:TuringLang/DynamicPPL.j…
torfjelde Jul 22, 2021
3a408cf
Merge branch 'tor/context-traits' into tor/conditioning-with-traits
torfjelde Jul 24, 2021
1d3b11e
drop now unneceesary impls for tilds for ConditionContext
torfjelde Jul 24, 2021
e1a7d38
formatting
torfjelde Jul 24, 2021
b42c34f
address issues using traits
torfjelde Jul 24, 2021
c7c60e6
added rewrap for contexts
torfjelde Jul 24, 2021
be67807
do decondition properly
torfjelde Jul 24, 2021
0468297
added some examples and decondition now removes ConditionContext
torfjelde Jul 24, 2021
80e3d5f
improved condition and decondition a bit further
torfjelde Jul 24, 2021
9419e76
use rewrap in _evaluate for ContextualModel
torfjelde Jul 24, 2021
4e566f7
remove the drop_missing as it is no longer needed
torfjelde Jul 24, 2021
d035c23
Merge branch 'tor/conditioning' into tor/conditioning-with-traits
torfjelde Jul 24, 2021
5c1f18e
Merge branch 'master' into tor/conditioning
torfjelde Jul 24, 2021
d6cd4ff
Merge branch 'tor/conditioning' into tor/conditioning-with-traits
torfjelde Jul 24, 2021
b27228a
rename rewrap to setchildcontet
torfjelde Jul 24, 2021
4935d5c
formatting
torfjelde Jul 24, 2021
6196083
made show a bit nicer for ConditionContext
torfjelde Jul 24, 2021
65048fc
use print instead of println in show
torfjelde Jul 24, 2021
48dda72
Merge branch 'tor/conditioning' into tor/conditioning-with-traits
torfjelde Jul 24, 2021
649af29
formatting
torfjelde Jul 24, 2021
ffdee05
Merge branch 'tor/conditioning' into tor/conditioning-with-traits
torfjelde Jul 25, 2021
5b26300
Merge branch 'master' into tor/conditioning
torfjelde Jul 28, 2021
bf39000
Merge branch 'tor/conditioning' into tor/conditioning-with-traits
torfjelde Jul 28, 2021
93184cc
Merge branch 'tor/context-traits' into tor/conditioning-with-traits
torfjelde Jul 28, 2021
21c08e5
dont overload haskey and improved contextual_isassumption check
torfjelde Jul 28, 2021
b201ef3
removed leftovers
torfjelde Jul 28, 2021
cf9d168
Merge branch 'tor/context-traits' into tor/conditioning-with-traits
torfjelde Jul 29, 2021
8cc3193
improved condition and fixed isassumption
torfjelde Jul 29, 2021
26edb2c
added BangBang as dep
torfjelde Jul 29, 2021
1a40c9c
Merge branch 'tor/context-traits' into tor/conditioning-with-traits
torfjelde Jul 29, 2021
ffe896a
Merge branch 'tor/context-traits' into tor/conditioning-with-traits
torfjelde Jul 29, 2021
e2f6fc5
formatting
torfjelde Jul 29, 2021
ce160d6
fixed condition without arguments
torfjelde Jul 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.13.0"
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand All @@ -15,6 +16,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
AbstractMCMC = "2, 3.0"
AbstractPPL = "0.2"
BangBang = "0.3"
Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9"
ChainRulesCore = "0.9.7, 0.10"
Distributions = "0.23.8, 0.24, 0.25"
Expand Down
5 changes: 5 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using AbstractMCMC: AbstractMCMC
using ChainRulesCore: ChainRulesCore
using MacroTools: MacroTools
using ZygoteRules: ZygoteRules
using BangBang: BangBang

using Random: Random

Expand Down Expand Up @@ -67,6 +68,7 @@ export AbstractVarInfo,
vectorize,
# Model
Model,
ContextualModel,
getmissings,
getargnames,
generated_quantities,
Expand All @@ -81,6 +83,7 @@ export AbstractVarInfo,
PriorContext,
MiniBatchContext,
PrefixContext,
ConditionContext,
assume,
dot_assume,
observe,
Expand All @@ -99,6 +102,8 @@ export AbstractVarInfo,
logprior,
logjoint,
pointwise_loglikelihoods,
condition,
decondition,
# Convenience macros
@addlogprob!,
@submodel
Expand Down
69 changes: 62 additions & 7 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,19 +20,64 @@ function isassumption(expr::Union{Symbol,Expr})

return quote
let $vn = $(varname(expr))
# This branch should compile nicely in all cases except for partial missing data
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
if !$(DynamicPPL.inargnames)($vn, __model__) ||
$(DynamicPPL.inmissings)($vn, __model__)
true
if $(DynamicPPL.contextual_isassumption)(__context__, $vn)
# Considered an assumption by `__context__` which means either:
# 1. We hit the default implementation, e.g. using `DefaultContext`,
# which in turn means that we haven't considered if it's one of
# the model arguments, hence we need to check this.
# 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments,
# i.e. we're trying to condition one of the latent variables.
# In this case, the below will return `true` since the first branch
# will be hit.
# 3. We are working with a `ConditionContext` _and_ it's in the model arguments,
# i.e. we're trying to override the value. This is currently NOT supported.
# TODO: Support by adding context to model, and use `model.args`
# as the default conditioning. Then we no longer need to check `inargnames`
# since it will all be handled by `contextual_isassumption`.
if !($(DynamicPPL.inargnames)($vn, __model__)) ||
$(DynamicPPL.inmissings)($vn, __model__)
true
else
$expr === missing
end
else
# Evaluate the LHS
$(maybe_view(expr)) === missing
false
end
end
end
end

"""
contextual_isassumption(context, vn)

Return `true` if `vn` is considered an assumption by `context`.

The default implementation for `AbstractContext` always returns `true`.
"""
contextual_isassumption(::IsLeaf, context, vn) = true
function contextual_isassumption(::IsParent, context, vn)
return contextual_isassumption(childcontext(context), vn)
end
function contextual_isassumption(context::AbstractContext, vn)
return contextual_isassumption(NodeTrait(context), context, vn)
end
function contextual_isassumption(context::ConditionContext, vn)
if hasvalue(context, vn)
val = getvalue(context, vn)
# TODO: Do we even need the `>: Missing` to help the compiler?
if eltype(val) >: Missing && val === missing
return true
end
end

# We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`
# so we defer to `childcontext` if we haven't concluded that anything yet.
return contextual_isassumption(childcontext(context), vn)
end
function contextual_isassumption(context::PrefixContext, vn)
return contextual_isassumption(childcontext(context), prefix(context, vn))
end

# failsafe: a literal is never an assumption
isassumption(expr) = :(false)

Expand Down Expand Up @@ -336,6 +381,11 @@ function generate_tilde(left, right)
__varinfo__,
)
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left = $(DynamicPPL.getvalue)(__context__, $vn)
end

$(DynamicPPL.tilde_observe!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
Expand Down Expand Up @@ -380,6 +430,11 @@ function generate_dot_tilde(left, right)
__varinfo__,
)
else
# If `vn` is not in `argnames`, we need to make sure that the variable is defined.
if !$(DynamicPPL.inargnames)($vn, __model__)
$left .= $(DynamicPPL.getvalue)(__context__, $vn)
end

$(DynamicPPL.dot_tilde_observe!)(
__context__,
$(DynamicPPL.check_tilde_rhs)($right),
Expand Down
11 changes: 10 additions & 1 deletion src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,17 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
require_gradient(spl::Sampler) = false
require_particles(spl::Sampler) = false

_getindex(x, inds::Tuple) = _getindex(view(x, first(inds)...), Base.tail(inds))
_getindex(x, inds::Tuple) = _getindex(Base.maybeview(x, first(inds)...), Base.tail(inds))
_getindex(x, inds::Tuple{}) = x
_getvalue(x, vn::VarName{sym}) where {sym} = _getindex(getproperty(x, sym), vn.indexing)
function _getvalue(x, vns::AbstractVector{<:VarName{sym}}) where {sym}
val = getproperty(x, sym)

# This should work with both cartesian and linear indexing.
return map(vns) do vn
_getindex(val, vn)
end
end

# assume
"""
Expand Down
177 changes: 177 additions & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,180 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing)
end
end

struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext
values::Values
context::Ctx

function ConditionContext{Values}(
values::Values, context::AbstractContext
) where {names,Values<:NamedTuple{names}}
return new{names,typeof(values),typeof(context)}(values, context)
end
end

function ConditionContext(values::NamedTuple)
return ConditionContext(values, DefaultContext())
end
function ConditionContext(values::NamedTuple, context::AbstractContext)
return ConditionContext{typeof(values)}(values, context)
end

# Try to avoid nested `ConditionContext`.
function ConditionContext(
values::NamedTuple{Names}, context::ConditionContext
) where {Names}
# Note that this potentially overrides values from `context`, thus giving
# precedence to the outmost `ConditionContext`.
return ConditionContext(merge(context.values, values), childcontext(context))
end

function Base.show(io::IO, context::ConditionContext)
return print(io, "ConditionContext($(context.values), $(childcontext(context)))")
end

NodeTrait(context::ConditionContext) = IsParent()
childcontext(context::ConditionContext) = context.context
setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child)

"""
getvalue(context, vn)

Return the value of the parameter corresponding to `vn` from `context`.
If `context` does not contain the value for `vn`, then `nothing` is returned,
e.g. [`DefaultContext`](@ref) will always return `nothing`.
"""
getvalue(::IsLeaf, context, vn) = nothing
getvalue(::IsParent, context, vn) = getvalue(childcontext(context), vn)
getvalue(context::AbstractContext, vn) = getvalue(NodeTrait(getvalue, context), context, vn)
getvalue(context::PrefixContext, vn) = getvalue(childcontext(context), prefix(context, vn))

function getvalue(context::ConditionContext, vn)
return if hasvalue(context, vn)
_getvalue(context.values, vn)
else
getvalue(childcontext(context), vn)
end
end

# General implementations of `haskey`.
hasvalue(::IsLeaf, context, vn) = false
hasvalue(::IsParent, context, vn) = hasvalue(childcontext(context), vn)
hasvalue(context::AbstractContext, vn) = hasvalue(NodeTrait(context), context, vn)

# Specific to `ConditionContext`.
function hasvalue(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym}
# TODO: Add possibility of indexed variables, e.g. `x[1]`, etc.
return sym in vars
end

function hasvalue(
context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}}
) where {vars,sym}
# TODO: Add possibility of indexed variables, e.g. `x[1]`, etc.
return sym in vars
end

"""
context([context::AbstractContext,] values::NamedTuple)
context([context::AbstractContext]; values...)

Return `ConditionContext` with `values` and `context` if `values` is non-empty,
otherwise return `context` which is [`DefaultContext`](@ref) by default.

See also: [`decondition`](@ref)
"""
condition(; values...) = condition(DefaultContext(), (; values...))
condition(values::NamedTuple) = condition(DefaultContext(), values)
condition(context::AbstractContext, values::NamedTuple{()}) = context
condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context)
condition(context::AbstractContext; values...) = condition(context, (; values...))
"""
decondition(context::AbstractContext, syms...)

Return `context` but with `syms` no longer conditioned on.

Note that this recursively traverses contexts, deconditioning all along the way.

See also: [`condition`](@ref)

# Examples
```jldoctest
julia> using DynamicPPL: AbstractContext, leafcontext, setleafcontext, childcontext, setchildcontext

julia> struct ParentContext{C} <: AbstractContext
context::C
end

julia> DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent()

julia> DynamicPPL.childcontext(context::ParentContext) = context.context

julia> DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child)

julia> Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")")

julia> ctx = DefaultContext()
DefaultContext()

julia> decondition(ctx) === ctx # this is a no-op
true

julia> ctx = condition(x = 1.0) # default "constructor" for `ConditionContext`
ConditionContext((x = 1.0,), DefaultContext())

julia> decondition(ctx) # `decondition` without arguments drops all conditioning
DefaultContext()

julia> # Nested conditioning is supported.
ctx_nested = condition(ParentContext(condition(y=2.0)), x=1.0)
ConditionContext((x = 1.0,), ParentContext(ConditionContext((y = 2.0,), DefaultContext())))

julia> # We can also specify which variables to drop.
decondition(ctx_nested, :x)
ParentContext(ConditionContext((y = 2.0,), DefaultContext()))

julia> # No matter the nested level.
decondition(ctx_nested, :y)
ConditionContext((x = 1.0,), ParentContext(DefaultContext()))

julia> # Or specify multiple at in one call.
decondition(ctx_nested, :x, :y)
ParentContext(DefaultContext())

julia> decondition(ctx_nested)
ParentContext(DefaultContext())

julia> # `Val` is also supported.
decondition(ctx_nested, Val(:x))
ParentContext(ConditionContext((y = 2.0,), DefaultContext()))

julia> decondition(ctx_nested, Val(:y))
ConditionContext((x = 1.0,), ParentContext(DefaultContext()))

julia> decondition(ctx_nested, Val(:x), Val(:y))
ParentContext(DefaultContext())
```
"""
decondition(::IsLeaf, context, args...) = context
function decondition(::IsParent, context, args...)
return setchildcontext(context, decondition(childcontext(context), args...))
end
decondition(context, args...) = decondition(NodeTrait(context), context, args...)
function decondition(context::ConditionContext)
return decondition(childcontext(context))
end
function decondition(context::ConditionContext, sym)
return condition(
decondition(childcontext(context), sym), BangBang.delete!!(context.values, sym)
)
end
function decondition(context::ConditionContext, sym, syms...)
return decondition(
condition(
decondition(childcontext(context), syms...),
BangBang.delete!!(context.values, sym),
),
syms...,
)
end
27 changes: 27 additions & 0 deletions src/contextual_model.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
struct ContextualModel{Ctx<:AbstractContext,M<:Model} <: AbstractModel
context::Ctx
model::M
end

function contextualize(model::AbstractModel, context::AbstractContext)
return ContextualModel(context, model)
end

# TODO: What do we do for other contexts? Could handle this in general if we had a
# notion of wrapper-, primitive-context, etc.
function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context)
# Wrap `context` in the model-associated `ConditionContext`, but now using `context` as
# `ConditionContext` child.
return _evaluate(cmodel.model, varinfo, setchildcontext(cmodel.context, context))
end

condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values))
condition(model::AbstractModel; values...) = condition(model, (; values...))
function condition(cmodel::ContextualModel{<:ConditionContext}, values)
return contextualize(cmodel.model, ConditionContext(values, cmodel.context))
end

decondition(model::AbstractModel, args...) = model
function decondition(cmodel::ContextualModel{<:ConditionContext}, syms...)
return contextualize(cmodel.model, decondition(cmodel.context, syms...))
end
Loading