From f9e753ac2ede7d511b46c4bfb833b9dbb19a1481 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 23 Jul 2021 17:59:38 +0100 Subject: [PATCH 01/12] initial work --- src/context_implementations.jl | 104 +++++++++++++++------------------ src/contexts.jl | 13 +++++ 2 files changed, 59 insertions(+), 58 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3d492f5b1..94f5f3014 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -35,12 +35,25 @@ function tilde_assume(context::SamplingContext, right, vn, inds, vi) end # Leaf contexts -tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, vi) -function tilde_assume( - rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi -) +function tilde_assume(context::AbstractContext, args...) + return tilde_assume(NodeTrait(tilde_assume, context), context, args...) +end +function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vinds, vi) + return assume(right, vn, vi) +end +function tilde_assume(::IsParent, context::AbstractContext, args...) + return tilde_assume(childcontext(context), args...) +end + +function tilde_assume(rng, context::AbstractContext, args...) + return tilde_assume(NodeTrait(tilde_assume, context), rng, context, args...) +end +function tilde_assume(::IsLeaf, rng, context::AbstractContext, sampler, right, vn, vinds, vi) return assume(rng, sampler, right, vn, vi) end +function tilde_assume(::IsParent, rng, context::AbstractContext, args...) + return tilde_assume(rng, childcontext(context), args...) +end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) @@ -64,12 +77,6 @@ function tilde_assume( end return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) end -function tilde_assume(::PriorContext, right, vn, inds, vi) - return assume(right, vn, vi) -end -function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) - return assume(rng, sampler, right, vn, vi) -end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) @@ -102,18 +109,9 @@ function tilde_assume( return assume(rng, sampler, NoDist(right), vn, vi) end -function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) - return tilde_assume(context.context, right, vn, inds, vi) -end - -function tilde_assume(rng, context::MiniBatchContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) -end - function tilde_assume(context::PrefixContext, right, vn, inds, vi) return tilde_assume(context.context, right, prefix(context, vn), inds, vi) end - function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) end @@ -162,16 +160,16 @@ function tilde_observe(context::SamplingContext, right, left, vi) end # Leaf contexts -tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) -function tilde_observe(::DefaultContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +function tilde_observe(context::AbstractContext, args...) + return tilde_observe(NodeTrait(tilde_observe, context), context, args...) +end +tilde_observe(::IsLeaf, context::AbstractContext, args...) = observe(args...) +function tilde_observe(::IsParent, context::AbstractContext, args...) + return tilde_observe(childcontext(context), args...) end + tilde_observe(::PriorContext, right, left, vi) = 0 tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) -function tilde_observe(::LikelihoodContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) -end # `MiniBatchContext` function tilde_observe(context::MiniBatchContext, right, left, vi) @@ -185,9 +183,6 @@ end function tilde_observe(context::PrefixContext, right, left, vname, vi) return tilde_observe(context.context, right, left, prefix(context, vname), vi) end -function tilde_observe(context::PrefixContext, right, left, vi) - return tilde_observe(context.context, right, left, vi) -end """ tilde_observe!(context, right, left, vname, vinds, vi) @@ -291,9 +286,26 @@ function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) end # `DefaultContext` -function dot_tilde_assume(::DefaultContext, right, left, vns, inds, vi) +function dot_tilde_assume(context::AbstractContext, args...) + return dot_tilde_assume(NodeTrait(dot_tilde_assume, context), context, args...) +end +function dot_tilde_assume(rng, context::AbstractContext, args...) + return dot_tilde_assume(rng, NodeTrait(dot_tilde_assume, context), context, args...) +end + +function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, inds, vi) return dot_assume(right, left, vns, vi) end +function dot_tilde_assume(::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, inds, vi) + return dot_assume(rng, sampler, right, vns, left, vi) +end + +function dot_tilde_assume(::IsParent, context::AbstractContext, args...) + return dot_tilde_assume(childcontext(context), args...) +end +function dot_tilde_assume(rng, ::IsParent, context::AbstractContext, args...) + return dot_tilde_assume(rng, childcontext(context), args...) +end function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(rng, sampler, right, vns, left, vi) @@ -374,25 +386,6 @@ function dot_tilde_assume( dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) end end -function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) - return dot_assume(right, left, vn, vi) -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi -) - return dot_assume(rng, sampler, right, vn, left, vi) -end - -# `MiniBatchContext` -function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, left, vn, inds, vi) -end - -function dot_tilde_assume( - rng, context::MiniBatchContext, sampler, right, left, vn, inds, vi -) - return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) -end # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) @@ -588,18 +581,13 @@ function dot_tilde_observe(context::SamplingContext, right, left, vi) end # Leaf contexts -dot_tilde_observe(::DefaultContext, right, left, vi) = dot_observe(right, left, vi) -function dot_tilde_observe(::DefaultContext, sampler, right, left, vi) - return dot_observe(sampler, right, left, vi) +dot_tilde_observe(::IsLeaf, ::AbstractContext, args...) = dot_observe(args...) +function dot_tilde_observe(::IsParent, context::AbstractContext, args...) + return dot_tilde_observe(childcontext(context), args...) end + dot_tilde_observe(::PriorContext, right, left, vi) = 0 dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -function dot_tilde_observe(context::LikelihoodContext, right, left, vi) - return dot_observe(right, left, vi) -end -function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) - return dot_observe(sampler, right, left, vi) -end # `MiniBatchContext` function dot_tilde_observe(context::MiniBatchContext, right, left, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 05ad8df0d..b27258932 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,3 +1,11 @@ +# Fallback traits +# TODO: Should this instead be `NoChildren()`, `HasChild()`, etc. so we allow plural too, e.g. `HasChildren()`? +struct IsLeaf end +struct IsParent end + +NodeTrait(::Any, context) = NodeTrait(context) + +# Contexts """ SamplingContext(rng, sampler, context) @@ -11,6 +19,7 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte sampler::S context::C end +NodeTrait(context::SamplingContext) = IsParent() """ struct DefaultContext <: AbstractContext end @@ -19,6 +28,7 @@ The `DefaultContext` is used by default to compute log the joint probability of and parameters when running the model. """ struct DefaultContext <: AbstractContext end +NodeTrait(context::DefaultContext) = IsLeaf() """ struct PriorContext{Tvars} <: AbstractContext @@ -32,6 +42,7 @@ struct PriorContext{Tvars} <: AbstractContext vars::Tvars end PriorContext() = PriorContext(nothing) +NodeTrait(context::PriorContext) = IsLeaf() """ struct LikelihoodContext{Tvars} <: AbstractContext @@ -46,6 +57,7 @@ struct LikelihoodContext{Tvars} <: AbstractContext vars::Tvars end LikelihoodContext() = LikelihoodContext(nothing) +NodeTrait(context::LikelihoodContext) = IsLeaf() """ struct MiniBatchContext{Tctx, T} <: AbstractContext @@ -66,6 +78,7 @@ end function MiniBatchContext(context=DefaultContext(); batch_size, npoints) return MiniBatchContext(context, npoints / batch_size) end +NodeTrait(context::MiniBatchContext) = IsParent() """ PrefixContext{Prefix}(context) From f090ff50ec04d9b8eaa0b84efe7eb4479e9ae142 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 05:15:18 +0100 Subject: [PATCH 02/12] added some missing implementations --- src/contexts.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index b27258932..78759de61 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -3,7 +3,13 @@ struct IsLeaf end struct IsParent end -NodeTrait(::Any, context) = NodeTrait(context) +""" + NodeTrait(context) + NodeTrait(f, context) + +Specifies the role of `context` in the context-tree. +""" +NodeTrait(_, context) = NodeTrait(context) # Contexts """ @@ -20,6 +26,7 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte context::C end NodeTrait(context::SamplingContext) = IsParent() +childcontext(context::SamplingContext) = context.context """ struct DefaultContext <: AbstractContext end @@ -79,6 +86,7 @@ function MiniBatchContext(context=DefaultContext(); batch_size, npoints) return MiniBatchContext(context, npoints / batch_size) end NodeTrait(context::MiniBatchContext) = IsParent() +childcontext(context::MiniBatchContext) = context.context """ PrefixContext{Prefix}(context) @@ -98,6 +106,9 @@ function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} return PrefixContext{Prefix,typeof(context)}(context) end +NodeTrait(context::PrefixContext) = IsParent() +childcontext(context::PrefixContext) = context.context + const PREFIX_SEPARATOR = Symbol(".") function PrefixContext{PrefixInner}( From 4e36c55fb7ae028177763dca51cf8dff68d98601 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 05:23:22 +0100 Subject: [PATCH 03/12] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 94f5f3014..0d35aba45 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -48,7 +48,9 @@ end function tilde_assume(rng, context::AbstractContext, args...) return tilde_assume(NodeTrait(tilde_assume, context), rng, context, args...) end -function tilde_assume(::IsLeaf, rng, context::AbstractContext, sampler, right, vn, vinds, vi) +function tilde_assume( + ::IsLeaf, rng, context::AbstractContext, sampler, right, vn, vinds, vi +) return assume(rng, sampler, right, vn, vi) end function tilde_assume(::IsParent, rng, context::AbstractContext, args...) @@ -296,7 +298,9 @@ end function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, inds, vi) return dot_assume(right, left, vns, vi) end -function dot_tilde_assume(::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, inds, vi) +function dot_tilde_assume( + ::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, inds, vi +) return dot_assume(rng, sampler, right, vns, left, vi) end From ee035cb388d43e38a7130182302ef43f0775dcb9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Jul 2021 18:59:48 +0100 Subject: [PATCH 04/12] added some more functionality for context traits --- src/contexts.jl | 112 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/src/contexts.jl b/src/contexts.jl index 78759de61..3653c68d3 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -8,9 +8,107 @@ struct IsParent end NodeTrait(f, context) Specifies the role of `context` in the context-tree. + +The officially supported traits are: +- `IsLeaf`: `context` does not have any decendants. +- `IsParent`: `context` has a child context to which we often defer. + Expects the following methods to be implemented: + - [`childcontext`](@ref) + - [`setchildcontext`](@ref) """ NodeTrait(_, context) = NodeTrait(context) +""" + childcontext(context) + +Return the descendant context of `context`. +""" +childcontext + +""" + setchildcontext(parent::AbstractContext, child::AbstractContext) + +Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), +effectively updating the child context. + +# Examples +```jldoctest +julia> ctx = SamplingContext(); + +julia> DynamicPPL.childcontext(ctx) +DefaultContext() + +julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior + +julia> DynamicPPL.childcontext(ctx_prior) +PriorContext{Nothing}(nothing) +``` +""" +setchildcontext + +""" + leafcontext(context) + +Return the leaf of `context`, i.e. the first descendant context that `IsLeaf`. +""" +leafcontext(context) = leafcontext(NodeTrait(leafcontext, context), context) +leafcontext(::IsLeaf, context) = context +leafcontext(::IsParent, context) = leafcontext(childcontext(context)) + +""" + setleafcontext(left, right) + +Return `left` but now with its leaf context replaced by `right`. + +Note that this also works even if `right` is not a leaf context, +in which case effectively append `right` to `left`, dropping the +original leaf context of `left`. + +# Examples +```jldoctest +julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext + +julia> struct ParentContext{C} + context::C + end + +julia> DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() + +julia> DynamicPPL.childcontext(context::ParentContext) = context.context + +julia> DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) + +julia> Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") + +julia> ctx = ParentContext(ParentContext(DefaultContext())) +ParentContext(ParentContext(DefaultContext())) + +julia> # Replace the leaf context with another leaf. + leafcontext(setleafcontext(ctx, PriorContext())) +PriorContext{Nothing}(nothing) + +julia> # Append another parent context. + setleafcontext(ctx, ParentContext(DefaultContext())) +ParentContext(ParentContext(ParentContext(DefaultContext()))) +``` +""" +function setleafcontext(left, right) + return setleafcontext( + NodeTrait(setleafcontext, left), + NodeTrait(setleafcontext, right), + left, + right + ) +end +function setleafcontext(::IsParent, ::IsParent, left, right) + return setchildcontext(left, setleafcontext(childcontext(left), right)) +end +function setleafcontext(::IsParent, ::IsLeaf, left, right) + return setchildcontext(left, setleafcontext(childcontext(left), right)) +end +setleafcontext(::IsLeaf, ::IsParent, left, right) = right +setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right + # Contexts """ SamplingContext(rng, sampler, context) @@ -25,8 +123,16 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte sampler::S context::C end +SamplingContext(sampler, context) = SamplingContext(Random.GLOBAL_RNG, sampler, context) +SamplingContext(context::AbstractContext) = SamplingContext(SampleFromPrior(), context) +SamplingContext(sampler::AbstractSampler) = SamplingContext(sampler, DefaultContext()) +SamplingContext() = SamplingContext(SampleFromPrior()) + NodeTrait(context::SamplingContext) = IsParent() childcontext(context::SamplingContext) = context.context +function setchildcontext(parent::SamplingContext, child) + return SamplingContext(parent.rng, parent.sampler, child) +end """ struct DefaultContext <: AbstractContext end @@ -87,6 +193,9 @@ function MiniBatchContext(context=DefaultContext(); batch_size, npoints) end NodeTrait(context::MiniBatchContext) = IsParent() childcontext(context::MiniBatchContext) = context.context +function setchildcontext(parent::MiniBatchContext, child) + return MiniBatchContext(child, parent.loglike_scalar) +end """ PrefixContext{Prefix}(context) @@ -108,6 +217,9 @@ end NodeTrait(context::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context +function setchildcontext(parent::PrefixContext{Prefix}, child) where {Prefix} + return PrefixContext{Prefix}(child) +end const PREFIX_SEPARATOR = Symbol(".") From b7998bc84ddab1f99d84b7fc9e79806082976574 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:33:15 +0100 Subject: [PATCH 05/12] fixed PointwiseLikelihood --- src/context_implementations.jl | 3 +++ src/loglikelihoods.jl | 14 ++++++-------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 08ec7144d..d15966fb6 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -583,6 +583,9 @@ function dot_tilde_observe(context::SamplingContext, right, left, vi) end # Leaf contexts +function dot_tilde_observe(context::AbstractContext, args...) + return dot_tilde_observe(NodeTrait(tilde_observe, context), context, args...) +end dot_tilde_observe(::IsLeaf, ::AbstractContext, args...) = dot_observe(args...) function dot_tilde_observe(::IsParent, context::AbstractContext, args...) return dot_tilde_observe(childcontext(context), args...) diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 2901432d1..0cac29219 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -13,6 +13,12 @@ function PointwiseLikelihoodContext( ) end +NodeTrait(::PointwiseLikelihoodContext) = IsParent() +childcontext(context::PointwiseLikelihoodContext) = context.context +function setchildcontext(context::PointwiseLikelihoodContext, child) + return PointwiseLikelihoodContext(context.loglikelihoods, child) +end + function Base.push!( context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, vn::VarName, @@ -61,14 +67,6 @@ function Base.push!( return context.loglikelihoods[vn] = logp end -function tilde_assume(context::PointwiseLikelihoodContext, right, vn, inds, vi) - return tilde_assume(context.context, right, vn, inds, vi) -end - -function dot_tilde_assume(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, left, vn, inds, vi) -end - function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) # Defer literal `observe` to child-context. return tilde_observe!(context.context, right, left, vi) From 4e3e08fae8c9a73db5ed777535234a6d02e021e4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:44:30 +0100 Subject: [PATCH 06/12] fixed a doctest --- src/contexts.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 3653c68d3..1345e1c93 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -66,9 +66,9 @@ original leaf context of `left`. # Examples ```jldoctest -julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext +julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext -julia> struct ParentContext{C} +julia> struct ParentContext{C} <: AbstractContext context::C end From 9e23d4d3ecf405646d06076031611244813ee756 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:45:52 +0100 Subject: [PATCH 07/12] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/contexts.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 1345e1c93..4bda07009 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -94,10 +94,7 @@ ParentContext(ParentContext(ParentContext(DefaultContext()))) """ function setleafcontext(left, right) return setleafcontext( - NodeTrait(setleafcontext, left), - NodeTrait(setleafcontext, right), - left, - right + NodeTrait(setleafcontext, left), NodeTrait(setleafcontext, right), left, right ) end function setleafcontext(::IsParent, ::IsParent, left, right) From 7c9dc5e9b7dd4ff16bfdf42b08d4cd5c35345dae Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 02:20:15 +0100 Subject: [PATCH 08/12] make NodeTrait an abstract type --- src/contexts.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 4bda07009..9b8f7ac07 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,7 +1,5 @@ # Fallback traits # TODO: Should this instead be `NoChildren()`, `HasChild()`, etc. so we allow plural too, e.g. `HasChildren()`? -struct IsLeaf end -struct IsParent end """ NodeTrait(context) @@ -16,8 +14,22 @@ The officially supported traits are: - [`childcontext`](@ref) - [`setchildcontext`](@ref) """ +abstract type NodeTrait end NodeTrait(_, context) = NodeTrait(context) +""" + IsLeaf + +Specifies that the context is a leaf in the context-tree. +""" +struct IsLeaf <: NodeTrait end +""" + IsParent + +Specifies that the context is a parent in the context-tree. +""" +struct IsParent <: NodeTrait end + """ childcontext(context) From e67089f4ca262ca34d948947020eaf2b50a872e3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 02:21:28 +0100 Subject: [PATCH 09/12] make matchingvalue work nicely with contexts --- src/compiler.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 6de2f0945..d344717aa 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -501,8 +501,14 @@ function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType) end function matchingvalue(context::AbstractContext, vi, value) + return matchingvalue(NodeTrait(matchingvalue, context), context, vi, value) +end +function matchingvalue(::IsLeaf, context::AbstractContext, vi, value) return matchingvalue(SampleFromPrior(), vi, value) end +function matchingvalue(::IsParent, context::AbstractContext, vi, value) + return matchingvalue(childcontext(context), vi, value) +end function matchingvalue(context::SamplingContext, vi, value) return matchingvalue(context.sampler, vi, value) end From 9479ce474afe5e52ef3e1909bf8a974c58db4faf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 02:21:40 +0100 Subject: [PATCH 10/12] added a bunch of tests for the new trait system for contexts --- test/contexts.jl | 94 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/test/contexts.jl b/test/contexts.jl index d9bcd2ef9..ae1ef100b 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,4 +1,98 @@ +using Test, DynamicPPL +using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, NodeTrait, IsLeaf, IsParent, PointwiseLikelihoodContext + +struct ParentContext{C<:AbstractContext} <: AbstractContext + context::C +end +ParentContext() = ParentContext(DefaultContext()) +DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(context::ParentContext) = context.context +DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) +Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") + @testset "contexts.jl" begin + child_contexts = [ + DefaultContext(), + PriorContext(), + LikelihoodContext(), + ] + + parent_contexts = [ + ParentContext(DefaultContext()), + SamplingContext(), + MiniBatchContext(DefaultContext(), 0.0), + PrefixContext{:x}(DefaultContext()), + PointwiseLikelihoodContext() + ] + + contexts = vcat(child_contexts, parent_contexts) + + @testset "NodeTrait" begin + @testset "$context" for context in contexts + # Every `context` should have a `NodeTrait`. + @test NodeTrait(context) isa NodeTrait + end + end + + @testset "leafcontext" begin + @testset "$context" for context in child_contexts + @test leafcontext(context) === context + end + + @testset "$context" for context in parent_contexts + @test NodeTrait(leafcontext(context)) isa IsLeaf + end + end + + @testset "setleafcontext" begin + @testset "$context" for context in child_contexts + # Setting to itself should return itself. + @test setleafcontext(context, context) === context + + # Setting to a different context should return that context. + new_leaf = context isa DefaultContext ? PriorContext() : DefaultContext() + @test setleafcontext(context, new_leaf) === new_leaf + + # Also works for parent contexts. + new_leaf = ParentContext(context) + @test setleafcontext(context, new_leaf) === new_leaf + end + + @testset "$context" for context in parent_contexts + # Leaf contexts. + new_leaf = leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() + @test leafcontext(setleafcontext(context, new_leaf)) === new_leaf + + # Setting parent contexts as "leaf" means that the new leaf should be + # the leaf of the parent context we just set as the leaf. + new_leaf = ParentContext((leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext())) + @test leafcontext(setleafcontext(context, new_leaf)) === leafcontext(new_leaf) + end + end + + # `IsParent` interface. + @testset "childcontext" begin + @testset "$context" for context in parent_contexts + @test childcontext(context) isa AbstractContext + end + end + + @testset "setchildcontext" begin + @testset "nested contexts" begin + # Both of the following should result in the same context. + context1 = ParentContext(ParentContext(ParentContext())) + context2 = setchildcontext(ParentContext(), setchildcontext(ParentContext(), ParentContext())) + @test context1 === context2 + end + + @testset "$context" for context in parent_contexts + # Setting the child context to a leaf should now change the `leafcontext` accordingly. + new_leaf = leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() + new_context = setchildcontext(context, new_leaf) + @test childcontext(new_context) === leafcontext(new_context) === new_leaf + end + end + @testset "PrefixContext" begin ctx = @inferred PrefixContext{:f}( PrefixContext{:e}( From 1315d886d56328f52c852410803d5da859d2e8ac Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 02:22:39 +0100 Subject: [PATCH 11/12] formatting --- test/contexts.jl | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/test/contexts.jl b/test/contexts.jl index ae1ef100b..7793dfc74 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,5 +1,14 @@ using Test, DynamicPPL -using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, NodeTrait, IsLeaf, IsParent, PointwiseLikelihoodContext +using DynamicPPL: + leafcontext, + setleafcontext, + childcontext, + setchildcontext, + AbstractContext, + NodeTrait, + IsLeaf, + IsParent, + PointwiseLikelihoodContext struct ParentContext{C<:AbstractContext} <: AbstractContext context::C @@ -11,18 +20,14 @@ DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") @testset "contexts.jl" begin - child_contexts = [ - DefaultContext(), - PriorContext(), - LikelihoodContext(), - ] + child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] parent_contexts = [ ParentContext(DefaultContext()), SamplingContext(), MiniBatchContext(DefaultContext(), 0.0), PrefixContext{:x}(DefaultContext()), - PointwiseLikelihoodContext() + PointwiseLikelihoodContext(), ] contexts = vcat(child_contexts, parent_contexts) @@ -60,18 +65,21 @@ Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c @testset "$context" for context in parent_contexts # Leaf contexts. - new_leaf = leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() + new_leaf = + leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() @test leafcontext(setleafcontext(context, new_leaf)) === new_leaf # Setting parent contexts as "leaf" means that the new leaf should be # the leaf of the parent context we just set as the leaf. - new_leaf = ParentContext((leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext())) + new_leaf = ParentContext(( + leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() + )) @test leafcontext(setleafcontext(context, new_leaf)) === leafcontext(new_leaf) end end # `IsParent` interface. - @testset "childcontext" begin + @testset "childcontext" begin @testset "$context" for context in parent_contexts @test childcontext(context) isa AbstractContext end @@ -81,13 +89,16 @@ Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c @testset "nested contexts" begin # Both of the following should result in the same context. context1 = ParentContext(ParentContext(ParentContext())) - context2 = setchildcontext(ParentContext(), setchildcontext(ParentContext(), ParentContext())) + context2 = setchildcontext( + ParentContext(), setchildcontext(ParentContext(), ParentContext()) + ) @test context1 === context2 end @testset "$context" for context in parent_contexts # Setting the child context to a leaf should now change the `leafcontext` accordingly. - new_leaf = leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() + new_leaf = + leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() new_context = setchildcontext(context, new_leaf) @test childcontext(new_context) === leafcontext(new_context) === new_leaf end From 59b39f1852f217739633922799246ebe949c4712 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 20:34:28 +0100 Subject: [PATCH 12/12] bump major version for Turing --- test/turing/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/turing/Project.toml b/test/turing/Project.toml index 05e6fb55e..8edbb9389 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -6,5 +6,5 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] DynamicPPL = "0.13" -Turing = "0.15, 0.16" +Turing = "0.17" julia = "1.3"