diff --git a/Project.toml b/Project.toml index 0b92e07d2..b6b66304e 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.13.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -15,6 +16,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "2, 3.0" AbstractPPL = "0.2" +BangBang = "0.3" Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9" ChainRulesCore = "0.9.7, 0.10" Distributions = "0.23.8, 0.24, 0.25" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..ac2734b47 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -9,6 +9,7 @@ using AbstractMCMC: AbstractMCMC using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules +using BangBang: BangBang using Random: Random @@ -81,6 +82,7 @@ export AbstractVarInfo, PriorContext, MiniBatchContext, PrefixContext, + ConditionContext, assume, dot_assume, observe, @@ -99,6 +101,8 @@ export AbstractVarInfo, logprior, logjoint, pointwise_loglikelihoods, + condition, + decondition, # Convenience macros @addlogprob!, @submodel diff --git a/src/compiler.jl b/src/compiler.jl index d344717aa..7880b9f1c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -20,19 +20,66 @@ function isassumption(expr::Union{Symbol,Expr}) return quote let $vn = $(varname(expr)) - # This branch should compile nicely in all cases except for partial missing data - # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` - if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) - true + if $(DynamicPPL.contextual_isassumption)(__context__, $vn) + # Considered an assumption by `__context__` which means either: + # 1. We hit the default implementation, e.g. using `DefaultContext`, + # which in turn means that we haven't considered if it's one of + # the model arguments, hence we need to check this. + # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, + # i.e. we're trying to condition one of the latent variables. + # In this case, the below will return `true` since the first branch + # will be hit. + # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, + # i.e. we're trying to override the value. This is currently NOT supported. + # TODO: Support by adding context to model, and use `model.args` + # as the default conditioning. Then we no longer need to check `inargnames` + # since it will all be handled by `contextual_isassumption`. + if !($(DynamicPPL.inargnames)($vn, __model__)) || + $(DynamicPPL.inmissings)($vn, __model__) + true + else + $(maybe_view(expr)) === missing + end else - # Evaluate the LHS - $(maybe_view(expr)) === missing + false end end end end +""" + contextual_isassumption(context, vn) + +Return `true` if `vn` is considered an assumption by `context`. + +The default implementation for `AbstractContext` always returns `true`. +""" +contextual_isassumption(::IsLeaf, context, vn) = true +function contextual_isassumption(::IsParent, context, vn) + return contextual_isassumption(childcontext(context), vn) +end +function contextual_isassumption(context::AbstractContext, vn) + return contextual_isassumption(NodeTrait(context), context, vn) +end +function contextual_isassumption(context::ConditionContext, vn) + if hasvalue(context, vn) + val = getvalue(context, vn) + # TODO: Do we even need the `>: Missing` to help the compiler? + if eltype(val) >: Missing && val === missing + return true + else + return false + end + end + + # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}` + # so we defer to `childcontext` if we haven't concluded that anything yet. + return contextual_isassumption(childcontext(context), vn) +end +function contextual_isassumption(context::PrefixContext, vn) + return contextual_isassumption(childcontext(context), prefix(context, vn)) +end + # failsafe: a literal is never an assumption isassumption(expr) = :(false) @@ -351,6 +398,11 @@ function generate_tilde(left, right) __varinfo__, ) else + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. + if !$(DynamicPPL.inargnames)($vn, __model__) + $left = $(DynamicPPL.getvalue_nested)(__context__, $vn) + end + $(DynamicPPL.tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), @@ -395,6 +447,11 @@ function generate_dot_tilde(left, right) __varinfo__, ) else + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. + if !$(DynamicPPL.inargnames)($vn, __model__) + $left .= $(DynamicPPL.getvalue_nested)(__context__, $vn) + end + $(DynamicPPL.dot_tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d15966fb6..54a198acc 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,8 +14,17 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -_getindex(x, inds::Tuple) = _getindex(view(x, first(inds)...), Base.tail(inds)) +_getindex(x, inds::Tuple) = _getindex(Base.maybeview(x, first(inds)...), Base.tail(inds)) _getindex(x, inds::Tuple{}) = x +_getvalue(x, vn::VarName{sym}) where {sym} = _getindex(getproperty(x, sym), vn.indexing) +function _getvalue(x, vns::AbstractVector{<:VarName{sym}}) where {sym} + val = getproperty(x, sym) + + # This should work with both cartesian and linear indexing. + return map(vns) do vn + _getindex(val, vn) + end +end # assume """ @@ -162,6 +171,8 @@ function tilde_observe(context::SamplingContext, right, left, vi) end # Leaf contexts +# TODO: Should we maybe not do `args...` here but instead be explicit? +# Could help avoid stealthy bugs. function tilde_observe(context::AbstractContext, args...) return tilde_observe(NodeTrait(tilde_observe, context), context, args...) end @@ -177,13 +188,14 @@ tilde_observe(::PriorContext, sampler, right, left, vi) = 0 function tilde_observe(context::MiniBatchContext, right, left, vi) return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end -function tilde_observe(context::MiniBatchContext, right, left, vname, vi) - return context.loglike_scalar * tilde_observe(context.context, right, left, vname, vi) +function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) + return context.loglike_scalar * + tilde_observe(context.context, sampler, right, left, vname, vi) end # `PrefixContext` -function tilde_observe(context::PrefixContext, right, left, vname, vi) - return tilde_observe(context.context, right, left, prefix(context, vname), vi) +function tilde_observe(context::PrefixContext, right, left, vi) + return tilde_observe(context.context, right, left, vi) end """ diff --git a/src/contexts.jl b/src/contexts.jl index 9b8f7ac07..55b070c6d 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -251,3 +251,267 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end + +struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext + values::Values + context::Ctx + + function ConditionContext{Values}( + values::Values, context::AbstractContext + ) where {names,Values<:NamedTuple{names}} + return new{names,typeof(values),typeof(context)}(values, context) + end +end + +function ConditionContext(values::NamedTuple) + return ConditionContext(values, DefaultContext()) +end +function ConditionContext(values::NamedTuple, context::AbstractContext) + return ConditionContext{typeof(values)}(values, context) +end + +# Try to avoid nested `ConditionContext`. +function ConditionContext( + values::NamedTuple{Names}, context::ConditionContext +) where {Names} + # Note that this potentially overrides values from `context`, thus giving + # precedence to the outmost `ConditionContext`. + return ConditionContext(merge(context.values, values), childcontext(context)) +end + +function Base.show(io::IO, context::ConditionContext) + return print(io, "ConditionContext($(context.values), $(childcontext(context)))") +end + +NodeTrait(context::ConditionContext) = IsParent() +childcontext(context::ConditionContext) = context.context +setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) + +""" + hasvalue(context, vn) + +Return `true` if `vn` is found in `context`. +""" +hasvalue(context, vn) = false + +function hasvalue(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} + return sym in vars +end +function hasvalue( + context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}} +) where {vars,sym} + return sym in vars +end + +""" + getvalue(context, vn) + +Return value of `vn` in `context`. +""" +function getvalue(context::AbstractContext, vn) + return error("context $(context) does not contain value for $vn") +end +getvalue(context::ConditionContext, vn) = _getvalue(context.values, vn) + +""" + hasvalue_nested(context, vn) + +Return `true` if `vn` is found in `context` or any of its descendants. + +This is contrast to [`hasvalue`](@ref) which only checks for `vn` in `context`, +not recursively checking if `vn` is in any of its descendants. +""" +function hasvalue_nested(context::AbstractContext, vn) + return hasvalue_nested(NodeTrait(hasvalue_nested, context), context, vn) +end +hasvalue_nested(::IsLeaf, context, vn) = hasvalue(context, vn) +function hasvalue_nested(::IsParent, context, vn) + return hasvalue(context, vn) || hasvalue_nested(childcontext(context), vn) +end +function hasvalue_nested(context::PrefixContext, vn) + return hasvalue_nested(childcontext(context), prefix(context, vn)) +end + +""" + getvalue_nested(context, vn) + +Return the value of the parameter corresponding to `vn` from `context` or its descendants. + +This is contrast to [`getvalue`](@ref) which only returns the value `vn` in `context`, +not recursively looking into its descendants. +""" +function getvalue_nested(context::AbstractContext, vn) + return getvalue_nested(NodeTrait(getvalue_nested, context), context, vn) +end +function getvalue_nested(::IsLeaf, context, vn) + return error("context $(context) does not contain value for $vn") +end +function getvalue_nested(context::PrefixContext, vn) + return getvalue_nested(childcontext(context), prefix(context, vn)) +end +function getvalue_nested(::IsParent, context, vn) + return if hasvalue(context, vn) + getvalue(context, vn) + else + getvalue_nested(childcontext(context), vn) + end +end + +""" + context([context::AbstractContext,] values::NamedTuple) + context([context::AbstractContext]; values...) + +Return `ConditionContext` with `values` and `context` if `values` is non-empty, +otherwise return `context` which is [`DefaultContext`](@ref) by default. + +See also: [`decondition`](@ref) +""" +condition(; values...) = condition(DefaultContext(), (; values...)) +condition(values::NamedTuple) = condition(DefaultContext(), values) +condition(context::AbstractContext, values::NamedTuple{()}) = context +condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context) +condition(context::AbstractContext; values...) = condition(context, (; values...)) +""" + decondition(context::AbstractContext, syms...) + +Return `context` but with `syms` no longer conditioned on. + +Note that this recursively traverses contexts, deconditioning all along the way. + +See also: [`condition`](@ref) + +# Examples +```jldoctest +julia> using DynamicPPL: AbstractContext, leafcontext, setleafcontext, childcontext, setchildcontext + +julia> struct ParentContext{C} <: AbstractContext + context::C + end + +julia> DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() + +julia> DynamicPPL.childcontext(context::ParentContext) = context.context + +julia> DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) + +julia> Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") + +julia> ctx = DefaultContext() +DefaultContext() + +julia> decondition(ctx) === ctx # this is a no-op +true + +julia> ctx = condition(x = 1.0) # default "constructor" for `ConditionContext` +ConditionContext((x = 1.0,), DefaultContext()) + +julia> decondition(ctx) # `decondition` without arguments drops all conditioning +DefaultContext() + +julia> # Nested conditioning is supported. + ctx_nested = condition(ParentContext(condition(y=2.0)), x=1.0) +ConditionContext((x = 1.0,), ParentContext(ConditionContext((y = 2.0,), DefaultContext()))) + +julia> # We can also specify which variables to drop. + decondition(ctx_nested, :x) +ParentContext(ConditionContext((y = 2.0,), DefaultContext())) + +julia> # No matter the nested level. + decondition(ctx_nested, :y) +ConditionContext((x = 1.0,), ParentContext(DefaultContext())) + +julia> # Or specify multiple at in one call. + decondition(ctx_nested, :x, :y) +ParentContext(DefaultContext()) + +julia> decondition(ctx_nested) +ParentContext(DefaultContext()) + +julia> # `Val` is also supported. + decondition(ctx_nested, Val(:x)) +ParentContext(ConditionContext((y = 2.0,), DefaultContext())) + +julia> decondition(ctx_nested, Val(:y)) +ConditionContext((x = 1.0,), ParentContext(DefaultContext())) + +julia> decondition(ctx_nested, Val(:x), Val(:y)) +ParentContext(DefaultContext()) +``` +""" +decondition(::IsLeaf, context, args...) = context +function decondition(::IsParent, context, args...) + return setchildcontext(context, decondition(childcontext(context), args...)) +end +decondition(context, args...) = decondition(NodeTrait(context), context, args...) +function decondition(context::ConditionContext) + return decondition(childcontext(context)) +end +function decondition(context::ConditionContext, sym) + return condition( + decondition(childcontext(context), sym), BangBang.delete!!(context.values, sym) + ) +end +function decondition(context::ConditionContext, sym, syms...) + return decondition( + condition( + decondition(childcontext(context), syms...), + BangBang.delete!!(context.values, sym), + ), + syms..., + ) +end + +""" + conditioned(model::Model) + conditioned(context::AbstractContext) + +Return `NamedTuple` of values that are conditioned on under `model`/`context`. + +# Examples +```jldoctest +julia> @model function demo() + m ~ Normal() + x ~ Normal(m, 1) + end +demo (generic function with 1 methods) + +julia> m = demo(); + +julia> # Returns all the variables we have conditioned on + their values. + conditioned(condition(m, x=100.0, m=1.0)) +(x = 100.0, m = 1.0) + +julia> # Nested ones also work (note that `PrefixContext` does nothing to the result). + cm = condition(contextualize(m, PrefixContext{:a}(condition(m=1.0))), x=100.0); + +julia> conditioned(cm) +(x = 100.0, m = 1.0) + +julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, + # `a.m` is treated as a random variable. + keys(VarInfo(cm)) +1-element Vector{VarName{Symbol("a.m"), Tuple{}}}: + a.m + +julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation. + cm = condition(contextualize(m, PrefixContext{:a}(condition(var"a.m"=1.0))), x=100.0); + +julia> conditioned(cm) +(x = 100.0, a.m = 1.0) + +julia> keys(VarInfo(cm)) # <= no variables are sampled +Any[] +``` +""" +function conditioned(context::AbstractContext) + return conditioned(NodeTrait(conditioned, context), context) +end +conditioned(::IsLeaf, context) = () +conditioned(::IsParent, context) = conditioned(childcontext(context)) +function conditioned(context::ConditionContext) + # Note the order of arguments to `merge`. The behavior of the rest of DPPL + # is that the outermost `context` takes precendence, hence when resolving + # the `conditioned` variables we need to ensure that `context.values` takes + # precedence over decendants of `context`. + return merge(context.values, conditioned(childcontext(context))) +end diff --git a/src/model.jl b/src/model.jl index 448ee1111..9f163910d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -32,12 +32,13 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: AbstractProbabilisticProgram name::Symbol f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} + context::Ctx @doc """ Model{missings}(name::Symbol, f, args::NamedTuple, defaults::NamedTuple) @@ -50,9 +51,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults}( - name, f, args, defaults + context::Ctx=DefaultContext(), + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}( + name, f, args, defaults, context ) end end @@ -67,12 +69,39 @@ Default arguments `defaults` are used internally when constructing instances of model with different arguments. """ @generated function Model( - name::Symbol, f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple=NamedTuple() + name::Symbol, + f::F, + args::NamedTuple{argnames,Targs}, + defaults::NamedTuple=NamedTuple(), + context::AbstractContext=DefaultContext(), ) where {F,argnames,Targs} missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing) - return :(Model{$missings}(name, f, args, defaults)) + return :(Model{$missings}(name, f, args, defaults, context)) +end + +function contextualize(model::Model, context::AbstractContext) + return Model(model.name, model.f, model.args, model.defaults, context) +end + +Base.:|(model::Model, values) = condition(model, values) + +condition(model::Model; values...) = condition(model, (; values...)) +function condition(model::Model, values) + return contextualize(model, condition(model.context, values)) +end + +function decondition(model::Model, syms...) + return contextualize(model, decondition(model.context, syms...)) end +""" + observations(model::Model) + +Alias for [`conditioned`](@ref). +""" +observations(model::Model) = conditioned(model) +conditioned(model::Model) = conditioned(model.context) + """ (model::Model)([rng, varinfo, sampler, context]) @@ -156,8 +185,22 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf @generated function _evaluate( model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} - unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] - return :(model.f(model, varinfo, context, $(unwrap_args...))) + unwrap_args = [ + :($matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames + ] + # We want to give `context` precedence over `model.context` while also + # preserving the leaf context of `context`. We can do this by + # 1. Set the leaf context of `model.context` to `leafcontext(context)`. + # 2. Set leaf context of `context` to the context resulting from (1). + # The result is: + # `context` -> `childcontext(context)` -> ... -> `model.context` + # -> `childcontext(model.context)` -> ... -> `leafcontext(context)` + return quote + context_new = setleafcontext( + context, setleafcontext(model.context, leafcontext(context)) + ) + model.f(model, varinfo, context_new, $(unwrap_args...)) + end end """ diff --git a/test/contexts.jl b/test/contexts.jl index 7793dfc74..fa96a0ec4 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -8,8 +8,15 @@ using DynamicPPL: NodeTrait, IsLeaf, IsParent, - PointwiseLikelihoodContext - + PointwiseLikelihoodContext, + contextual_isassumption, + ConditionContext, + hasvalue, + getvalue, + hasvalue_nested, + getvalue_nested + +# Dummy context to test nested behaviors. struct ParentContext{C<:AbstractContext} <: AbstractContext context::C end @@ -19,6 +26,50 @@ DynamicPPL.childcontext(context::ParentContext) = context.context DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") +# TODO: Should we maybe put this in DPPL itself? +function Base.iterate(context::AbstractContext) + if NodeTrait(context) isa IsLeaf + return nothing + end + + return context, context +end +function Base.iterate(_::AbstractContext, context::AbstractContext) + return _iterate(NodeTrait(context), context) +end +_iterate(::IsLeaf, context) = nothing +function _iterate(::IsParent, context) + child = childcontext(context) + return child, child +end + +Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() +Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() + +""" + remove_prefix(vn::VarName) + +Return `vn` but now with the prefix removed. +""" +remove_prefix(vn::VarName) = VarName{Symbol(split(string(vn), ".")[end])}(vn.indexing) + +""" + varnames(vn::VarName, val) + +Return iterator over all varnames that are represented by `vn` on `val`, +e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. +""" +varnames(vn::VarName, val::Real) = [vn] +function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) + return (VarName(vn, (vn.indexing..., Tuple(I))) for I in CartesianIndices(val)) +end +function varnames(vn::VarName, val::AbstractArray) + return Iterators.flatten( + varnames(VarName(vn, (vn.indexing..., Tuple(I))), val[I]) for + I in CartesianIndices(val) + ) +end + @testset "contexts.jl" begin child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] @@ -28,6 +79,10 @@ Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c MiniBatchContext(DefaultContext(), 0.0), PrefixContext{:x}(DefaultContext()), PointwiseLikelihoodContext(), + ConditionContext((x=1.0,)), + ConditionContext((x=1.0,), ParentContext(ConditionContext((y=2.0,)))), + ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))), + ConditionContext((x=[1.0, missing],)), ] contexts = vcat(child_contexts, parent_contexts) @@ -104,6 +159,69 @@ Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c end end + @testset "contextual_isassumption" begin + @testset "$context" for context in contexts + # Any `context` should return `true` by default. + @test contextual_isassumption(context, VarName{gensym(:x)}()) + + if any(Base.Fix2(isa, ConditionContext), context) + # We have a `ConditionContext` among us. + # Let's first extract the conditioned variables. + conditioned_values = DynamicPPL.conditioned(context) + + for (sym, val) in pairs(conditioned_values) + vn = VarName{sym}() + + # We need to drop the prefix of `var` since in `contextual_isassumption` + # it will be threaded through the `PrefixContext` before it reaches + # `ConditionContext` with the conditioned variable. + vn_without_prefix = remove_prefix(vn) + + # Let's check elementwise. + for vn_child in varnames(vn_without_prefix, val) + if DynamicPPL._getindex(val, vn_child.indexing) === missing + @test contextual_isassumption(context, vn_child) + else + @test !contextual_isassumption(context, vn_child) + end + end + end + end + end + end + + @testset "getvalue_nested & hasvalue_nested" begin + @testset "$context" for context in contexts + fake_vn = VarName{gensym(:x)}() + @test !hasvalue_nested(context, fake_vn) + @test_throws ErrorException getvalue_nested(context, fake_vn) + + if any(Base.Fix2(isa, ConditionContext), context) + # `ConditionContext` specific. + + # Let's first extract the conditioned variables. + conditioned_values = DynamicPPL.conditioned(context) + + for (sym, val) in pairs(conditioned_values) + vn = VarName{sym}() + + # We need to drop the prefix of `var` since in `contextual_isassumption` + # it will be threaded through the `PrefixContext` before it reaches + # `ConditionContext` with the conditioned variable. + vn_without_prefix = remove_prefix(vn) + + for vn_child in varnames(vn_without_prefix, val) + # `vn_child` should be in `context`. + @test hasvalue_nested(context, vn_child) + # Value should be the same as extracted above. + @test getvalue_nested(context, vn_child) === + DynamicPPL._getindex(val, vn_child.indexing) + end + end + end + end + end + @testset "PrefixContext" begin ctx = @inferred PrefixContext{:f}( PrefixContext{:e}(