Skip to content
Closed
6 changes: 6 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,14 @@ function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType)
end

function matchingvalue(context::AbstractContext, vi, value)
return matchingvalue(NodeTrait(matchingvalue, context), context, vi, value)
end
function matchingvalue(::IsLeaf, context::AbstractContext, vi, value)
return matchingvalue(SampleFromPrior(), vi, value)
end
function matchingvalue(::IsParent, context::AbstractContext, vi, value)
return matchingvalue(childcontext(context), vi, value)
end
Comment on lines +504 to +511
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also very nice as it allows us to define something like a CUDAContext in the future which will ensure that all the arguments are moved to the GPU before execution.

function matchingvalue(context::SamplingContext, vi, value)
return matchingvalue(context.sampler, vi, value)
end
Expand Down
107 changes: 51 additions & 56 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,27 @@ function tilde_assume(context::SamplingContext, right, vn, inds, vi)
end

# Leaf contexts
tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, vi)
function tilde_assume(context::AbstractContext, args...)
return tilde_assume(NodeTrait(tilde_assume, context), context, args...)
end
function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vinds, vi)
return assume(right, vn, vi)
end
function tilde_assume(::IsParent, context::AbstractContext, args...)
return tilde_assume(childcontext(context), args...)
end

function tilde_assume(rng, context::AbstractContext, args...)
return tilde_assume(NodeTrait(tilde_assume, context), rng, context, args...)
end
function tilde_assume(
rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi
::IsLeaf, rng, context::AbstractContext, sampler, right, vn, vinds, vi
)
return assume(rng, sampler, right, vn, vi)
end
function tilde_assume(::IsParent, rng, context::AbstractContext, args...)
return tilde_assume(rng, childcontext(context), args...)
end

function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi)
if haskey(context.vars, getsym(vn))
Expand All @@ -64,12 +79,6 @@ function tilde_assume(
end
return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi)
end
function tilde_assume(::PriorContext, right, vn, inds, vi)
return assume(right, vn, vi)
end
function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi)
return assume(rng, sampler, right, vn, vi)
end

function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi)
if haskey(context.vars, getsym(vn))
Expand Down Expand Up @@ -102,18 +111,9 @@ function tilde_assume(
return assume(rng, sampler, NoDist(right), vn, vi)
end

function tilde_assume(context::MiniBatchContext, right, vn, inds, vi)
return tilde_assume(context.context, right, vn, inds, vi)
end

function tilde_assume(rng, context::MiniBatchContext, sampler, right, vn, inds, vi)
return tilde_assume(rng, context.context, sampler, right, vn, inds, vi)
end

function tilde_assume(context::PrefixContext, right, vn, inds, vi)
return tilde_assume(context.context, right, prefix(context, vn), inds, vi)
end

function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi)
return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi)
end
Expand Down Expand Up @@ -162,16 +162,16 @@ function tilde_observe(context::SamplingContext, right, left, vi)
end

# Leaf contexts
tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi)
function tilde_observe(::DefaultContext, sampler, right, left, vi)
return observe(sampler, right, left, vi)
function tilde_observe(context::AbstractContext, args...)
return tilde_observe(NodeTrait(tilde_observe, context), context, args...)
end
tilde_observe(::IsLeaf, context::AbstractContext, args...) = observe(args...)
function tilde_observe(::IsParent, context::AbstractContext, args...)
return tilde_observe(childcontext(context), args...)
end

tilde_observe(::PriorContext, right, left, vi) = 0
tilde_observe(::PriorContext, sampler, right, left, vi) = 0
tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi)
function tilde_observe(::LikelihoodContext, sampler, right, left, vi)
return observe(sampler, right, left, vi)
end

# `MiniBatchContext`
function tilde_observe(context::MiniBatchContext, right, left, vi)
Expand All @@ -185,9 +185,6 @@ end
function tilde_observe(context::PrefixContext, right, left, vname, vi)
return tilde_observe(context.context, right, left, prefix(context, vname), vi)
end
function tilde_observe(context::PrefixContext, right, left, vi)
return tilde_observe(context.context, right, left, vi)
end

"""
tilde_observe!(context, right, left, vname, vinds, vi)
Expand Down Expand Up @@ -288,9 +285,28 @@ function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi)
end

# `DefaultContext`
function dot_tilde_assume(::DefaultContext, right, left, vns, inds, vi)
function dot_tilde_assume(context::AbstractContext, args...)
return dot_tilde_assume(NodeTrait(dot_tilde_assume, context), context, args...)
end
function dot_tilde_assume(rng, context::AbstractContext, args...)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO we should make context always be the first argument, even after unwrapping in SamplingContext. This will further reduce redundancy by a factor of 2.

We only did it this way to attempt to be non-breaking (it wasn't 😅 ), so might was well make the change

return dot_tilde_assume(rng, NodeTrait(dot_tilde_assume, context), context, args...)
end

function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, inds, vi)
return dot_assume(right, left, vns, vi)
end
function dot_tilde_assume(
::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, inds, vi
)
return dot_assume(rng, sampler, right, vns, left, vi)
end

function dot_tilde_assume(::IsParent, context::AbstractContext, args...)
return dot_tilde_assume(childcontext(context), args...)
end
function dot_tilde_assume(rng, ::IsParent, context::AbstractContext, args...)
return dot_tilde_assume(rng, childcontext(context), args...)
end

function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi)
return dot_assume(rng, sampler, right, vns, left, vi)
Expand Down Expand Up @@ -371,25 +387,6 @@ function dot_tilde_assume(
dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi)
end
end
function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi)
return dot_assume(right, left, vn, vi)
end
function dot_tilde_assume(
rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi
)
return dot_assume(rng, sampler, right, vn, left, vi)
end

# `MiniBatchContext`
function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi)
return dot_tilde_assume(context.context, right, left, vn, inds, vi)
end

function dot_tilde_assume(
rng, context::MiniBatchContext, sampler, right, left, vn, inds, vi
)
return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi)
end

# `PrefixContext`
function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi)
Expand Down Expand Up @@ -586,18 +583,16 @@ function dot_tilde_observe(context::SamplingContext, right, left, vi)
end

# Leaf contexts
dot_tilde_observe(::DefaultContext, right, left, vi) = dot_observe(right, left, vi)
function dot_tilde_observe(::DefaultContext, sampler, right, left, vi)
return dot_observe(sampler, right, left, vi)
function dot_tilde_observe(context::AbstractContext, args...)
return dot_tilde_observe(NodeTrait(tilde_observe, context), context, args...)
end
dot_tilde_observe(::IsLeaf, ::AbstractContext, args...) = dot_observe(args...)
function dot_tilde_observe(::IsParent, context::AbstractContext, args...)
return dot_tilde_observe(childcontext(context), args...)
end

dot_tilde_observe(::PriorContext, right, left, vi) = 0
dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0
function dot_tilde_observe(context::LikelihoodContext, right, left, vi)
return dot_observe(right, left, vi)
end
function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi)
return dot_observe(sampler, right, left, vi)
end

# `MiniBatchContext`
function dot_tilde_observe(context::MiniBatchContext, right, left, vi)
Expand Down
145 changes: 145 additions & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,124 @@
# Fallback traits
# TODO: Should this instead be `NoChildren()`, `HasChild()`, etc. so we allow plural too, e.g. `HasChildren()`?

"""
NodeTrait(context)
NodeTrait(f, context)

Specifies the role of `context` in the context-tree.

The officially supported traits are:
- `IsLeaf`: `context` does not have any decendants.
- `IsParent`: `context` has a child context to which we often defer.
Expects the following methods to be implemented:
- [`childcontext`](@ref)
- [`setchildcontext`](@ref)
"""
abstract type NodeTrait end
NodeTrait(_, context) = NodeTrait(context)

"""
IsLeaf

Specifies that the context is a leaf in the context-tree.
"""
struct IsLeaf <: NodeTrait end
"""
IsParent

Specifies that the context is a parent in the context-tree.
"""
struct IsParent <: NodeTrait end

"""
childcontext(context)

Return the descendant context of `context`.
"""
childcontext

"""
setchildcontext(parent::AbstractContext, child::AbstractContext)

Reconstruct `parent` but now using `child` is its [`childcontext`](@ref),
effectively updating the child context.

# Examples
```jldoctest
julia> ctx = SamplingContext();

julia> DynamicPPL.childcontext(ctx)
DefaultContext()

julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior

julia> DynamicPPL.childcontext(ctx_prior)
PriorContext{Nothing}(nothing)
```
"""
setchildcontext

"""
leafcontext(context)

Return the leaf of `context`, i.e. the first descendant context that `IsLeaf`.
"""
leafcontext(context) = leafcontext(NodeTrait(leafcontext, context), context)
leafcontext(::IsLeaf, context) = context
leafcontext(::IsParent, context) = leafcontext(childcontext(context))

"""
setleafcontext(left, right)

Return `left` but now with its leaf context replaced by `right`.

Note that this also works even if `right` is not a leaf context,
in which case effectively append `right` to `left`, dropping the
original leaf context of `left`.

# Examples
```jldoctest
julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext

julia> struct ParentContext{C} <: AbstractContext
context::C
end

julia> DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent()

julia> DynamicPPL.childcontext(context::ParentContext) = context.context

julia> DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child)

julia> Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")")

julia> ctx = ParentContext(ParentContext(DefaultContext()))
ParentContext(ParentContext(DefaultContext()))

julia> # Replace the leaf context with another leaf.
leafcontext(setleafcontext(ctx, PriorContext()))
PriorContext{Nothing}(nothing)

julia> # Append another parent context.
setleafcontext(ctx, ParentContext(DefaultContext()))
ParentContext(ParentContext(ParentContext(DefaultContext())))
```
"""
function setleafcontext(left, right)
return setleafcontext(
NodeTrait(setleafcontext, left), NodeTrait(setleafcontext, right), left, right
)
end
function setleafcontext(::IsParent, ::IsParent, left, right)
return setchildcontext(left, setleafcontext(childcontext(left), right))
end
function setleafcontext(::IsParent, ::IsLeaf, left, right)
return setchildcontext(left, setleafcontext(childcontext(left), right))
end
setleafcontext(::IsLeaf, ::IsParent, left, right) = right
setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right

# Contexts
"""
SamplingContext(rng, sampler, context)

Expand All @@ -11,6 +132,16 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte
sampler::S
context::C
end
SamplingContext(sampler, context) = SamplingContext(Random.GLOBAL_RNG, sampler, context)
SamplingContext(context::AbstractContext) = SamplingContext(SampleFromPrior(), context)
SamplingContext(sampler::AbstractSampler) = SamplingContext(sampler, DefaultContext())
SamplingContext() = SamplingContext(SampleFromPrior())

NodeTrait(context::SamplingContext) = IsParent()
childcontext(context::SamplingContext) = context.context
function setchildcontext(parent::SamplingContext, child)
return SamplingContext(parent.rng, parent.sampler, child)
end

"""
struct DefaultContext <: AbstractContext end
Expand All @@ -19,6 +150,7 @@ The `DefaultContext` is used by default to compute log the joint probability of
and parameters when running the model.
"""
struct DefaultContext <: AbstractContext end
NodeTrait(context::DefaultContext) = IsLeaf()

"""
struct PriorContext{Tvars} <: AbstractContext
Expand All @@ -32,6 +164,7 @@ struct PriorContext{Tvars} <: AbstractContext
vars::Tvars
end
PriorContext() = PriorContext(nothing)
NodeTrait(context::PriorContext) = IsLeaf()

"""
struct LikelihoodContext{Tvars} <: AbstractContext
Expand All @@ -46,6 +179,7 @@ struct LikelihoodContext{Tvars} <: AbstractContext
vars::Tvars
end
LikelihoodContext() = LikelihoodContext(nothing)
NodeTrait(context::LikelihoodContext) = IsLeaf()

"""
struct MiniBatchContext{Tctx, T} <: AbstractContext
Expand All @@ -66,6 +200,11 @@ end
function MiniBatchContext(context=DefaultContext(); batch_size, npoints)
return MiniBatchContext(context, npoints / batch_size)
end
NodeTrait(context::MiniBatchContext) = IsParent()
childcontext(context::MiniBatchContext) = context.context
function setchildcontext(parent::MiniBatchContext, child)
return MiniBatchContext(child, parent.loglike_scalar)
end

"""
PrefixContext{Prefix}(context)
Expand All @@ -85,6 +224,12 @@ function PrefixContext{Prefix}(context::AbstractContext) where {Prefix}
return PrefixContext{Prefix,typeof(context)}(context)
end

NodeTrait(context::PrefixContext) = IsParent()
childcontext(context::PrefixContext) = context.context
function setchildcontext(parent::PrefixContext{Prefix}, child) where {Prefix}
return PrefixContext{Prefix}(child)
end

const PREFIX_SEPARATOR = Symbol(".")

function PrefixContext{PrefixInner}(
Expand Down
Loading