diff --git a/src/compiler.jl b/src/compiler.jl index 6de2f0945..d344717aa 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -501,8 +501,14 @@ function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType) end function matchingvalue(context::AbstractContext, vi, value) + return matchingvalue(NodeTrait(matchingvalue, context), context, vi, value) +end +function matchingvalue(::IsLeaf, context::AbstractContext, vi, value) return matchingvalue(SampleFromPrior(), vi, value) end +function matchingvalue(::IsParent, context::AbstractContext, vi, value) + return matchingvalue(childcontext(context), vi, value) +end function matchingvalue(context::SamplingContext, vi, value) return matchingvalue(context.sampler, vi, value) end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index cd7a92535..d15966fb6 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -35,12 +35,27 @@ function tilde_assume(context::SamplingContext, right, vn, inds, vi) end # Leaf contexts -tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, vi) +function tilde_assume(context::AbstractContext, args...) + return tilde_assume(NodeTrait(tilde_assume, context), context, args...) +end +function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vinds, vi) + return assume(right, vn, vi) +end +function tilde_assume(::IsParent, context::AbstractContext, args...) + return tilde_assume(childcontext(context), args...) +end + +function tilde_assume(rng, context::AbstractContext, args...) + return tilde_assume(NodeTrait(tilde_assume, context), rng, context, args...) +end function tilde_assume( - rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi + ::IsLeaf, rng, context::AbstractContext, sampler, right, vn, vinds, vi ) return assume(rng, sampler, right, vn, vi) end +function tilde_assume(::IsParent, rng, context::AbstractContext, args...) + return tilde_assume(rng, childcontext(context), args...) +end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) @@ -64,12 +79,6 @@ function tilde_assume( end return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) end -function tilde_assume(::PriorContext, right, vn, inds, vi) - return assume(right, vn, vi) -end -function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) - return assume(rng, sampler, right, vn, vi) -end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) @@ -102,18 +111,9 @@ function tilde_assume( return assume(rng, sampler, NoDist(right), vn, vi) end -function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) - return tilde_assume(context.context, right, vn, inds, vi) -end - -function tilde_assume(rng, context::MiniBatchContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) -end - function tilde_assume(context::PrefixContext, right, vn, inds, vi) return tilde_assume(context.context, right, prefix(context, vn), inds, vi) end - function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) end @@ -162,16 +162,16 @@ function tilde_observe(context::SamplingContext, right, left, vi) end # Leaf contexts -tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) -function tilde_observe(::DefaultContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +function tilde_observe(context::AbstractContext, args...) + return tilde_observe(NodeTrait(tilde_observe, context), context, args...) +end +tilde_observe(::IsLeaf, context::AbstractContext, args...) = observe(args...) +function tilde_observe(::IsParent, context::AbstractContext, args...) + return tilde_observe(childcontext(context), args...) end + tilde_observe(::PriorContext, right, left, vi) = 0 tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) -function tilde_observe(::LikelihoodContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) -end # `MiniBatchContext` function tilde_observe(context::MiniBatchContext, right, left, vi) @@ -185,9 +185,6 @@ end function tilde_observe(context::PrefixContext, right, left, vname, vi) return tilde_observe(context.context, right, left, prefix(context, vname), vi) end -function tilde_observe(context::PrefixContext, right, left, vi) - return tilde_observe(context.context, right, left, vi) -end """ tilde_observe!(context, right, left, vname, vinds, vi) @@ -288,9 +285,28 @@ function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) end # `DefaultContext` -function dot_tilde_assume(::DefaultContext, right, left, vns, inds, vi) +function dot_tilde_assume(context::AbstractContext, args...) + return dot_tilde_assume(NodeTrait(dot_tilde_assume, context), context, args...) +end +function dot_tilde_assume(rng, context::AbstractContext, args...) + return dot_tilde_assume(rng, NodeTrait(dot_tilde_assume, context), context, args...) +end + +function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, inds, vi) return dot_assume(right, left, vns, vi) end +function dot_tilde_assume( + ::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, inds, vi +) + return dot_assume(rng, sampler, right, vns, left, vi) +end + +function dot_tilde_assume(::IsParent, context::AbstractContext, args...) + return dot_tilde_assume(childcontext(context), args...) +end +function dot_tilde_assume(rng, ::IsParent, context::AbstractContext, args...) + return dot_tilde_assume(rng, childcontext(context), args...) +end function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(rng, sampler, right, vns, left, vi) @@ -371,25 +387,6 @@ function dot_tilde_assume( dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) end end -function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) - return dot_assume(right, left, vn, vi) -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi -) - return dot_assume(rng, sampler, right, vn, left, vi) -end - -# `MiniBatchContext` -function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, left, vn, inds, vi) -end - -function dot_tilde_assume( - rng, context::MiniBatchContext, sampler, right, left, vn, inds, vi -) - return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) -end # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) @@ -586,18 +583,16 @@ function dot_tilde_observe(context::SamplingContext, right, left, vi) end # Leaf contexts -dot_tilde_observe(::DefaultContext, right, left, vi) = dot_observe(right, left, vi) -function dot_tilde_observe(::DefaultContext, sampler, right, left, vi) - return dot_observe(sampler, right, left, vi) +function dot_tilde_observe(context::AbstractContext, args...) + return dot_tilde_observe(NodeTrait(tilde_observe, context), context, args...) +end +dot_tilde_observe(::IsLeaf, ::AbstractContext, args...) = dot_observe(args...) +function dot_tilde_observe(::IsParent, context::AbstractContext, args...) + return dot_tilde_observe(childcontext(context), args...) end + dot_tilde_observe(::PriorContext, right, left, vi) = 0 dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -function dot_tilde_observe(context::LikelihoodContext, right, left, vi) - return dot_observe(right, left, vi) -end -function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) - return dot_observe(sampler, right, left, vi) -end # `MiniBatchContext` function dot_tilde_observe(context::MiniBatchContext, right, left, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 05ad8df0d..9b8f7ac07 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,3 +1,124 @@ +# Fallback traits +# TODO: Should this instead be `NoChildren()`, `HasChild()`, etc. so we allow plural too, e.g. `HasChildren()`? + +""" + NodeTrait(context) + NodeTrait(f, context) + +Specifies the role of `context` in the context-tree. + +The officially supported traits are: +- `IsLeaf`: `context` does not have any decendants. +- `IsParent`: `context` has a child context to which we often defer. + Expects the following methods to be implemented: + - [`childcontext`](@ref) + - [`setchildcontext`](@ref) +""" +abstract type NodeTrait end +NodeTrait(_, context) = NodeTrait(context) + +""" + IsLeaf + +Specifies that the context is a leaf in the context-tree. +""" +struct IsLeaf <: NodeTrait end +""" + IsParent + +Specifies that the context is a parent in the context-tree. +""" +struct IsParent <: NodeTrait end + +""" + childcontext(context) + +Return the descendant context of `context`. +""" +childcontext + +""" + setchildcontext(parent::AbstractContext, child::AbstractContext) + +Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), +effectively updating the child context. + +# Examples +```jldoctest +julia> ctx = SamplingContext(); + +julia> DynamicPPL.childcontext(ctx) +DefaultContext() + +julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior + +julia> DynamicPPL.childcontext(ctx_prior) +PriorContext{Nothing}(nothing) +``` +""" +setchildcontext + +""" + leafcontext(context) + +Return the leaf of `context`, i.e. the first descendant context that `IsLeaf`. +""" +leafcontext(context) = leafcontext(NodeTrait(leafcontext, context), context) +leafcontext(::IsLeaf, context) = context +leafcontext(::IsParent, context) = leafcontext(childcontext(context)) + +""" + setleafcontext(left, right) + +Return `left` but now with its leaf context replaced by `right`. + +Note that this also works even if `right` is not a leaf context, +in which case effectively append `right` to `left`, dropping the +original leaf context of `left`. + +# Examples +```jldoctest +julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext + +julia> struct ParentContext{C} <: AbstractContext + context::C + end + +julia> DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() + +julia> DynamicPPL.childcontext(context::ParentContext) = context.context + +julia> DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) + +julia> Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") + +julia> ctx = ParentContext(ParentContext(DefaultContext())) +ParentContext(ParentContext(DefaultContext())) + +julia> # Replace the leaf context with another leaf. + leafcontext(setleafcontext(ctx, PriorContext())) +PriorContext{Nothing}(nothing) + +julia> # Append another parent context. + setleafcontext(ctx, ParentContext(DefaultContext())) +ParentContext(ParentContext(ParentContext(DefaultContext()))) +``` +""" +function setleafcontext(left, right) + return setleafcontext( + NodeTrait(setleafcontext, left), NodeTrait(setleafcontext, right), left, right + ) +end +function setleafcontext(::IsParent, ::IsParent, left, right) + return setchildcontext(left, setleafcontext(childcontext(left), right)) +end +function setleafcontext(::IsParent, ::IsLeaf, left, right) + return setchildcontext(left, setleafcontext(childcontext(left), right)) +end +setleafcontext(::IsLeaf, ::IsParent, left, right) = right +setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right + +# Contexts """ SamplingContext(rng, sampler, context) @@ -11,6 +132,16 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte sampler::S context::C end +SamplingContext(sampler, context) = SamplingContext(Random.GLOBAL_RNG, sampler, context) +SamplingContext(context::AbstractContext) = SamplingContext(SampleFromPrior(), context) +SamplingContext(sampler::AbstractSampler) = SamplingContext(sampler, DefaultContext()) +SamplingContext() = SamplingContext(SampleFromPrior()) + +NodeTrait(context::SamplingContext) = IsParent() +childcontext(context::SamplingContext) = context.context +function setchildcontext(parent::SamplingContext, child) + return SamplingContext(parent.rng, parent.sampler, child) +end """ struct DefaultContext <: AbstractContext end @@ -19,6 +150,7 @@ The `DefaultContext` is used by default to compute log the joint probability of and parameters when running the model. """ struct DefaultContext <: AbstractContext end +NodeTrait(context::DefaultContext) = IsLeaf() """ struct PriorContext{Tvars} <: AbstractContext @@ -32,6 +164,7 @@ struct PriorContext{Tvars} <: AbstractContext vars::Tvars end PriorContext() = PriorContext(nothing) +NodeTrait(context::PriorContext) = IsLeaf() """ struct LikelihoodContext{Tvars} <: AbstractContext @@ -46,6 +179,7 @@ struct LikelihoodContext{Tvars} <: AbstractContext vars::Tvars end LikelihoodContext() = LikelihoodContext(nothing) +NodeTrait(context::LikelihoodContext) = IsLeaf() """ struct MiniBatchContext{Tctx, T} <: AbstractContext @@ -66,6 +200,11 @@ end function MiniBatchContext(context=DefaultContext(); batch_size, npoints) return MiniBatchContext(context, npoints / batch_size) end +NodeTrait(context::MiniBatchContext) = IsParent() +childcontext(context::MiniBatchContext) = context.context +function setchildcontext(parent::MiniBatchContext, child) + return MiniBatchContext(child, parent.loglike_scalar) +end """ PrefixContext{Prefix}(context) @@ -85,6 +224,12 @@ function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} return PrefixContext{Prefix,typeof(context)}(context) end +NodeTrait(context::PrefixContext) = IsParent() +childcontext(context::PrefixContext) = context.context +function setchildcontext(parent::PrefixContext{Prefix}, child) where {Prefix} + return PrefixContext{Prefix}(child) +end + const PREFIX_SEPARATOR = Symbol(".") function PrefixContext{PrefixInner}( diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 2901432d1..0cac29219 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -13,6 +13,12 @@ function PointwiseLikelihoodContext( ) end +NodeTrait(::PointwiseLikelihoodContext) = IsParent() +childcontext(context::PointwiseLikelihoodContext) = context.context +function setchildcontext(context::PointwiseLikelihoodContext, child) + return PointwiseLikelihoodContext(context.loglikelihoods, child) +end + function Base.push!( context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, vn::VarName, @@ -61,14 +67,6 @@ function Base.push!( return context.loglikelihoods[vn] = logp end -function tilde_assume(context::PointwiseLikelihoodContext, right, vn, inds, vi) - return tilde_assume(context.context, right, vn, inds, vi) -end - -function dot_tilde_assume(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, left, vn, inds, vi) -end - function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) # Defer literal `observe` to child-context. return tilde_observe!(context.context, right, left, vi) diff --git a/test/contexts.jl b/test/contexts.jl index d9bcd2ef9..7793dfc74 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,4 +1,109 @@ +using Test, DynamicPPL +using DynamicPPL: + leafcontext, + setleafcontext, + childcontext, + setchildcontext, + AbstractContext, + NodeTrait, + IsLeaf, + IsParent, + PointwiseLikelihoodContext + +struct ParentContext{C<:AbstractContext} <: AbstractContext + context::C +end +ParentContext() = ParentContext(DefaultContext()) +DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(context::ParentContext) = context.context +DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) +Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") + @testset "contexts.jl" begin + child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] + + parent_contexts = [ + ParentContext(DefaultContext()), + SamplingContext(), + MiniBatchContext(DefaultContext(), 0.0), + PrefixContext{:x}(DefaultContext()), + PointwiseLikelihoodContext(), + ] + + contexts = vcat(child_contexts, parent_contexts) + + @testset "NodeTrait" begin + @testset "$context" for context in contexts + # Every `context` should have a `NodeTrait`. + @test NodeTrait(context) isa NodeTrait + end + end + + @testset "leafcontext" begin + @testset "$context" for context in child_contexts + @test leafcontext(context) === context + end + + @testset "$context" for context in parent_contexts + @test NodeTrait(leafcontext(context)) isa IsLeaf + end + end + + @testset "setleafcontext" begin + @testset "$context" for context in child_contexts + # Setting to itself should return itself. + @test setleafcontext(context, context) === context + + # Setting to a different context should return that context. + new_leaf = context isa DefaultContext ? PriorContext() : DefaultContext() + @test setleafcontext(context, new_leaf) === new_leaf + + # Also works for parent contexts. + new_leaf = ParentContext(context) + @test setleafcontext(context, new_leaf) === new_leaf + end + + @testset "$context" for context in parent_contexts + # Leaf contexts. + new_leaf = + leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() + @test leafcontext(setleafcontext(context, new_leaf)) === new_leaf + + # Setting parent contexts as "leaf" means that the new leaf should be + # the leaf of the parent context we just set as the leaf. + new_leaf = ParentContext(( + leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() + )) + @test leafcontext(setleafcontext(context, new_leaf)) === leafcontext(new_leaf) + end + end + + # `IsParent` interface. + @testset "childcontext" begin + @testset "$context" for context in parent_contexts + @test childcontext(context) isa AbstractContext + end + end + + @testset "setchildcontext" begin + @testset "nested contexts" begin + # Both of the following should result in the same context. + context1 = ParentContext(ParentContext(ParentContext())) + context2 = setchildcontext( + ParentContext(), setchildcontext(ParentContext(), ParentContext()) + ) + @test context1 === context2 + end + + @testset "$context" for context in parent_contexts + # Setting the child context to a leaf should now change the `leafcontext` accordingly. + new_leaf = + leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() + new_context = setchildcontext(context, new_leaf) + @test childcontext(new_context) === leafcontext(new_context) === new_leaf + end + end + @testset "PrefixContext" begin ctx = @inferred PrefixContext{:f}( PrefixContext{:e}( diff --git a/test/turing/Project.toml b/test/turing/Project.toml index 05e6fb55e..8edbb9389 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -6,5 +6,5 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] DynamicPPL = "0.13" -Turing = "0.15, 0.16" +Turing = "0.17" julia = "1.3"