diff --git a/Project.toml b/Project.toml index 758e95243..bb8faf6c0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.13.2" +version = "0.14.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -15,6 +16,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "2, 3.0" AbstractPPL = "0.2" +BangBang = "0.3" Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9" ChainRulesCore = "0.9.7, 0.10" Distributions = "0.23.8, 0.24, 0.25" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..ac2734b47 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -9,6 +9,7 @@ using AbstractMCMC: AbstractMCMC using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules +using BangBang: BangBang using Random: Random @@ -81,6 +82,7 @@ export AbstractVarInfo, PriorContext, MiniBatchContext, PrefixContext, + ConditionContext, assume, dot_assume, observe, @@ -99,6 +101,8 @@ export AbstractVarInfo, logprior, logjoint, pointwise_loglikelihoods, + condition, + decondition, # Convenience macros @addlogprob!, @submodel diff --git a/src/compiler.jl b/src/compiler.jl index d344717aa..4594808af 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -20,19 +20,66 @@ function isassumption(expr::Union{Symbol,Expr}) return quote let $vn = $(varname(expr)) - # This branch should compile nicely in all cases except for partial missing data - # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` - if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) - true + if $(DynamicPPL.contextual_isassumption)(__context__, $vn) + # Considered an assumption by `__context__` which means either: + # 1. We hit the default implementation, e.g. using `DefaultContext`, + # which in turn means that we haven't considered if it's one of + # the model arguments, hence we need to check this. + # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, + # i.e. we're trying to condition one of the latent variables. + # In this case, the below will return `true` since the first branch + # will be hit. + # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, + # i.e. we're trying to override the value. This is currently NOT supported. + # TODO: Support by adding context to model, and use `model.args` + # as the default conditioning. Then we no longer need to check `inargnames` + # since it will all be handled by `contextual_isassumption`. + if !($(DynamicPPL.inargnames)($vn, __model__)) || + $(DynamicPPL.inmissings)($vn, __model__) + true + else + $(maybe_view(expr)) === missing + end else - # Evaluate the LHS - $(maybe_view(expr)) === missing + false end end end end +""" + contextual_isassumption(context, vn) + +Return `true` if `vn` is considered an assumption by `context`. + +The default implementation for `AbstractContext` always returns `true`. +""" +contextual_isassumption(::IsLeaf, context, vn) = true +function contextual_isassumption(::IsParent, context, vn) + return contextual_isassumption(childcontext(context), vn) +end +function contextual_isassumption(context::AbstractContext, vn) + return contextual_isassumption(NodeTrait(context), context, vn) +end +function contextual_isassumption(context::ConditionContext, vn) + if hasvalue(context, vn) + val = getvalue(context, vn) + # TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler? + if eltype(val) >: Missing && val === missing + return true + else + return false + end + end + + # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}` + # so we defer to `childcontext` if we haven't concluded that anything yet. + return contextual_isassumption(childcontext(context), vn) +end +function contextual_isassumption(context::PrefixContext, vn) + return contextual_isassumption(childcontext(context), prefix(context, vn)) +end + # failsafe: a literal is never an assumption isassumption(expr) = :(false) @@ -93,7 +140,7 @@ variables. # Example ```jldoctest; setup=:(using Distributions) -julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(1, 1.0), randn(1, 2), @varname(x)); string(vns[end]) +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal([1.0, 1.0], [1.0 0.0; 0.0 1.0]), randn(2, 2), @varname(x)); string(vns[end]) "x[:,2]" julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); string(vns[end]) @@ -351,6 +398,11 @@ function generate_tilde(left, right) __varinfo__, ) else + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. + if !$(DynamicPPL.inargnames)($vn, __model__) + $left = $(DynamicPPL.getvalue_nested)(__context__, $vn) + end + $(DynamicPPL.tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), @@ -395,6 +447,11 @@ function generate_dot_tilde(left, right) __varinfo__, ) else + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. + if !$(DynamicPPL.inargnames)($vn, __model__) + $left .= $(DynamicPPL.getvalue_nested)(__context__, $vn) + end + $(DynamicPPL.dot_tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d15966fb6..834680d39 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,8 +14,17 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -_getindex(x, inds::Tuple) = _getindex(view(x, first(inds)...), Base.tail(inds)) +_getindex(x, inds::Tuple) = _getindex(Base.maybeview(x, first(inds)...), Base.tail(inds)) _getindex(x, inds::Tuple{}) = x +_getvalue(x, vn::VarName{sym}) where {sym} = _getindex(getproperty(x, sym), vn.indexing) +function _getvalue(x, vns::AbstractVector{<:VarName{sym}}) where {sym} + val = getproperty(x, sym) + + # This should work with both cartesian and linear indexing. + return map(vns) do vn + _getindex(val, vn) + end +end # assume """ @@ -177,13 +186,14 @@ tilde_observe(::PriorContext, sampler, right, left, vi) = 0 function tilde_observe(context::MiniBatchContext, right, left, vi) return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end -function tilde_observe(context::MiniBatchContext, right, left, vname, vi) - return context.loglike_scalar * tilde_observe(context.context, right, left, vname, vi) +function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) + return context.loglike_scalar * + tilde_observe(context.context, sampler, right, left, vname, vi) end # `PrefixContext` -function tilde_observe(context::PrefixContext, right, left, vname, vi) - return tilde_observe(context.context, right, left, prefix(context, vname), vi) +function tilde_observe(context::PrefixContext, right, left, vi) + return tilde_observe(context.context, right, left, vi) end """ diff --git a/src/contexts.jl b/src/contexts.jl index 9b8f7ac07..98eb4b85d 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -251,3 +251,176 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end + +struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext + values::Values + context::Ctx + + function ConditionContext{Values}( + values::Values, context::AbstractContext + ) where {names,Values<:NamedTuple{names}} + return new{names,typeof(values),typeof(context)}(values, context) + end +end + +function ConditionContext(values::NamedTuple) + return ConditionContext(values, DefaultContext()) +end +function ConditionContext(values::NamedTuple, context::AbstractContext) + return ConditionContext{typeof(values)}(values, context) +end + +# Try to avoid nested `ConditionContext`. +function ConditionContext( + values::NamedTuple{Names}, context::ConditionContext +) where {Names} + # Note that this potentially overrides values from `context`, thus giving + # precedence to the outmost `ConditionContext`. + return ConditionContext(merge(context.values, values), childcontext(context)) +end + +function Base.show(io::IO, context::ConditionContext) + return print(io, "ConditionContext($(context.values), $(childcontext(context)))") +end + +NodeTrait(context::ConditionContext) = IsParent() +childcontext(context::ConditionContext) = context.context +setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) + +""" + hasvalue(context, vn) + +Return `true` if `vn` is found in `context`. +""" +hasvalue(context, vn) = false + +function hasvalue(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} + return sym in vars +end +function hasvalue( + context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}} +) where {vars,sym} + return sym in vars +end + +""" + getvalue(context, vn) + +Return value of `vn` in `context`. +""" +function getvalue(context::AbstractContext, vn) + return error("context $(context) does not contain value for $vn") +end +getvalue(context::ConditionContext, vn) = _getvalue(context.values, vn) + +""" + hasvalue_nested(context, vn) + +Return `true` if `vn` is found in `context` or any of its descendants. + +This is contrast to [`hasvalue`](@ref) which only checks for `vn` in `context`, +not recursively checking if `vn` is in any of its descendants. +""" +function hasvalue_nested(context::AbstractContext, vn) + return hasvalue_nested(NodeTrait(hasvalue_nested, context), context, vn) +end +hasvalue_nested(::IsLeaf, context, vn) = hasvalue(context, vn) +function hasvalue_nested(::IsParent, context, vn) + return hasvalue(context, vn) || hasvalue_nested(childcontext(context), vn) +end +function hasvalue_nested(context::PrefixContext, vn) + return hasvalue_nested(childcontext(context), prefix(context, vn)) +end + +""" + getvalue_nested(context, vn) + +Return the value of the parameter corresponding to `vn` from `context` or its descendants. + +This is contrast to [`getvalue`](@ref) which only returns the value `vn` in `context`, +not recursively looking into its descendants. +""" +function getvalue_nested(context::AbstractContext, vn) + return getvalue_nested(NodeTrait(getvalue_nested, context), context, vn) +end +function getvalue_nested(::IsLeaf, context, vn) + return error("context $(context) does not contain value for $vn") +end +function getvalue_nested(context::PrefixContext, vn) + return getvalue_nested(childcontext(context), prefix(context, vn)) +end +function getvalue_nested(::IsParent, context, vn) + return if hasvalue(context, vn) + getvalue(context, vn) + else + getvalue_nested(childcontext(context), vn) + end +end + +""" + condition([context::AbstractContext,] values::NamedTuple) + condition([context::AbstractContext]; values...) + +Return `ConditionContext` with `values` and `context` if `values` is non-empty, +otherwise return `context` which is [`DefaultContext`](@ref) by default. + +See also: [`decondition`](@ref) +""" +condition(; values...) = condition(DefaultContext(), NamedTuple(values)) +condition(values::NamedTuple) = condition(DefaultContext(), values) +condition(context::AbstractContext, values::NamedTuple{()}) = context +condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context) +condition(context::AbstractContext; values...) = condition(context, NamedTuple(values)) + +""" + decondition(context::AbstractContext, syms...) + +Return `context` but with `syms` no longer conditioned on. + +Note that this recursively traverses contexts, deconditioning all along the way. + +See also: [`condition`](@ref) +""" +decondition(::IsLeaf, context, args...) = context +function decondition(::IsParent, context, args...) + return setchildcontext(context, decondition(childcontext(context), args...)) +end +decondition(context, args...) = decondition(NodeTrait(context), context, args...) +function decondition(context::ConditionContext) + return decondition(childcontext(context)) +end +function decondition(context::ConditionContext, sym) + return condition( + decondition(childcontext(context), sym), BangBang.delete!!(context.values, sym) + ) +end +function decondition(context::ConditionContext, sym, syms...) + return decondition( + condition( + decondition(childcontext(context), syms...), + BangBang.delete!!(context.values, sym), + ), + syms..., + ) +end + +""" + conditioned(context::AbstractContext) + +Return `NamedTuple` of values that are conditioned on under context`. + +Note that this will recursively traverse the context stack and return +a merged version of the condition values. +""" +function conditioned(context::AbstractContext) + return conditioned(NodeTrait(conditioned, context), context) +end +conditioned(::IsLeaf, context) = () +conditioned(::IsParent, context) = conditioned(childcontext(context)) +function conditioned(context::ConditionContext) + # Note the order of arguments to `merge`. The behavior of the rest of DPPL + # is that the outermost `context` takes precendence, hence when resolving + # the `conditioned` variables we need to ensure that `context.values` takes + # precedence over decendants of `context`. + return merge(context.values, conditioned(childcontext(context))) +end diff --git a/src/model.jl b/src/model.jl index 448ee1111..a081bf831 100644 --- a/src/model.jl +++ b/src/model.jl @@ -32,12 +32,13 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: AbstractProbabilisticProgram name::Symbol f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} + context::Ctx @doc """ Model{missings}(name::Symbol, f, args::NamedTuple, defaults::NamedTuple) @@ -50,9 +51,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults}( - name, f, args, defaults + context::Ctx=DefaultContext(), + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}( + name, f, args, defaults, context ) end end @@ -67,12 +69,302 @@ Default arguments `defaults` are used internally when constructing instances of model with different arguments. """ @generated function Model( - name::Symbol, f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple=NamedTuple() + name::Symbol, + f::F, + args::NamedTuple{argnames,Targs}, + defaults::NamedTuple=NamedTuple(), + context::AbstractContext=DefaultContext(), ) where {F,argnames,Targs} missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing) - return :(Model{$missings}(name, f, args, defaults)) + return :(Model{$missings}(name, f, args, defaults, context)) +end + +function contextualize(model::Model, context::AbstractContext) + return Model(model.name, model.f, model.args, model.defaults, context) end +""" + model | (x = 1.0, ...) + +Return a `Model` which now treats variables on the right-hand side as observations. + +See [`condition`](@ref) for more information and examples. +""" +Base.:|(model::Model, values) = condition(model, values) + +""" + condition(model::Model; values...) + condition(model::Model, values::NamedTuple) + +Return a `Model` which now treats the variables in `values` as observations. + +See also: [`decondition`](@ref), [`conditioned`](@ref) + +# Limitations + +This does currently _not_ work with variables that are +provided to the model as arguments, e.g. `@model function demo(x) ... end` +means that `condition` will not affect the variable `x`. + +Therefore if one wants to make use of `condition` and [`decondition`](@ref) +one should not be specifying any random variables as arguments. + +This is done for the sake of backwards compatibility. + +# Examples +## Simple univariate model +```jldoctest condition +julia> using Distributions; using StableRNGs; rng = StableRNG(42); # For reproducibility. + +julia> @model function demo() + m ~ Normal() + x ~ Normal(m, 1) + return (; m=m, x=x) + end +demo (generic function with 1 method) + +julia> model = demo(); + +julia> model(rng) +(m = -0.6702516921145671, x = -0.22312984965118443) + +julia> # Create a new instance which treats `x` as observed + # with value `100.0`, and similarly for `m=1.0`. + conditioned_model = condition(model, x=100.0, m=1.0); + +julia> conditioned_model(rng) +(m = 1.0, x = 100.0) + +julia> # Let's only condition on `x = 100.0`. + conditioned_model = condition(model, x = 100.0); + +julia> conditioned_model(rng) +(m = 1.3736306979834252, x = 100.0) + +julia> # We can also use the nicer `|` syntax. + conditioned_model = model | (x = 100.0, ); + +julia> conditioned_model(rng) +(m = 1.3095394956381083, x = 100.0) +``` + +## Condition only a part of a multivariate variable + +Not only can be condition on multivariate random variables, but +we can also use the standard mechanism of setting something to `missing` +in the call to `condition` to only condition on a part of the variable. + +```jldoctest condition +julia> @model function demo_mv(::Type{TV}=Float64) where {TV} + m = Vector{TV}(undef, 2) + m[1] ~ Normal() + m[2] ~ Normal() + return m + end +demo_mv (generic function with 2 methods) + +julia> model = demo_mv(); + +julia> conditioned_model = condition(model, m = [missing, 1.0]); + +julia> conditioned_model(rng) # (✓) `m[1]` sampled, `m[2]` is fixed +2-element Vector{Float64}: + 0.12607002180931043 + 1.0 +``` + +Intuitively one might also expect to be able to write `model | (x[1] = 1.0, )`. +Unfortunately this is not supported due to performance. + +```jldoctest condition +julia> condition(model, var"x[2]" = 1.0)(rng) # (×) `x[2]` is not set to 1.0. +2-element Vector{Float64}: + 0.683947930996541 + -1.019202452456547 +``` + +We will likely provide some syntactic sugar for this in the future. + +## Nested models + +`condition` of course also supports the use of nested models through +the use of [`@submodel`](@ref). + +```jldoctest condition +julia> @model demo_inner() = m ~ Normal() +demo_inner (generic function with 1 method) + +julia> @model function demo_outer() + m = @submodel demo_inner() + return m + end +demo_outer (generic function with 1 method) + +julia> model = demo_outer(); + +julia> model(rng) +-0.7935128416361353 + +julia> conditioned_model = model | (m = 1.0, ); + +julia> conditioned_model(rng) +1.0 +``` + +But one needs to be careful when prefixing variables in the nested models: + +```jldoctest condition +julia> @model function demo_outer_prefix() + m = @submodel inner demo_inner() + return m + end +demo_outer_prefix (generic function with 1 method) + +julia> # This doesn't work now! + conditioned_model = demo_outer_prefix() | (m = 1.0, ); + +julia> conditioned_model(rng) +1.7747246334368165 + +julia> # `m` in `demo_inner` is referred to as `inner.m` internally, so we do: + conditioned_model = demo_outer_prefix() | (var"inner.m" = 1.0, ); + +julia> conditioned_model(rng) +1.0 + +julia> # Note that the above `var"..."` is just standard Julia syntax: + keys((var"inner.m" = 1.0, )) +(Symbol("inner.m"),) +``` + +The difference is maybe more obvious once we look at how these different +in their trace/`VarInfo`: + +```jldoctest condition +julia> keys(VarInfo(demo_outer())) +1-element Vector{VarName{:m, Tuple{}}}: + m + +julia> keys(VarInfo(demo_outer_prefix())) +1-element Vector{VarName{Symbol("inner.m"), Tuple{}}}: + inner.m +``` + +From this we can tell what the correct way to condition `m` within `demo_inner` +is in the two different models. + +""" +condition(model::Model; values...) = condition(model, NamedTuple(values)) +function condition(model::Model, values) + return contextualize(model, condition(model.context, values)) +end + +""" + decondition(model::Model) + decondition(model::Model, syms...) + +Return a `Model` for which `syms...` are _not_ considered observations. +If no `syms` are provided, then all variables currently considered observations +will no longer be. + +This is essentially the inverse of [`condition`](@ref). This also means that +it suffers from the same limitiations. + +# Examples +```jldoctest +julia> using Distributions; using StableRNGs; rng = StableRNG(42); # For reproducibility. + +julia> @model function demo() + m ~ Normal() + x ~ Normal(m, 1) + return (; m=m, x=x) + end +demo (generic function with 1 method) + +julia> conditioned_model = condition(demo(), m = 1.0, x = 10.0); + +julia> conditioned_model(rng) +(m = 1.0, x = 10.0) + +julia> model = decondition(conditioned_model, :m); + +julia> model(rng) +(m = -0.6702516921145671, x = 10.0) + +julia> # `decondition` multiple at once: + decondition(model, :m, :x)(rng) +(m = 0.4471218424633827, x = 1.820752540446808) + +julia> # `decondition` without any symbols will `decondition` all variables. + decondition(model)(rng) +(m = 1.3095394956381083, x = 1.4356095174474188) + +julia> # Usage of `Val` to perform `decondition` at compile-time if possible + # is also supported. + model = decondition(conditioned_model, Val{:m}()); + +julia> model(rng) +(m = 0.683947930996541, x = 10.0) +``` +""" +function decondition(model::Model, syms...) + return contextualize(model, decondition(model.context, syms...)) +end + +""" + observations(model::Model) + +Alias for [`conditioned`](@ref). +""" +observations(model::Model) = conditioned(model) + +""" + conditioned(model::Model) + +Return `NamedTuple` of values that are conditioned on under `model`. + +# Examples +```jldoctest +julia> using Distributions + +julia> using DynamicPPL: conditioned, contextualize + +julia> @model function demo() + m ~ Normal() + x ~ Normal(m, 1) + end +demo (generic function with 1 method) + +julia> m = demo(); + +julia> # Returns all the variables we have conditioned on + their values. + conditioned(condition(m, x=100.0, m=1.0)) +(x = 100.0, m = 1.0) + +julia> # Nested ones also work (note that `PrefixContext` does nothing to the result). + cm = condition(contextualize(m, PrefixContext{:a}(condition(m=1.0))), x=100.0); + +julia> conditioned(cm) +(x = 100.0, m = 1.0) + +julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, + # `a.m` is treated as a random variable. + keys(VarInfo(cm)) +1-element Vector{VarName{Symbol("a.m"), Tuple{}}}: + a.m + +julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation. + cm = condition(contextualize(m, PrefixContext{:a}(condition(var"a.m"=1.0))), x=100.0); + +julia> conditioned(cm) +(x = 100.0, a.m = 1.0) + +julia> keys(VarInfo(cm)) # <= no variables are sampled +Any[] +``` +""" +conditioned(model::Model) = conditioned(model.context) + """ (model::Model)([rng, varinfo, sampler, context]) @@ -156,8 +448,22 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf @generated function _evaluate( model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} - unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] - return :(model.f(model, varinfo, context, $(unwrap_args...))) + unwrap_args = [ + :($matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames + ] + # We want to give `context` precedence over `model.context` while also + # preserving the leaf context of `context`. We can do this by + # 1. Set the leaf context of `model.context` to `leafcontext(context)`. + # 2. Set leaf context of `context` to the context resulting from (1). + # The result is: + # `context` -> `childcontext(context)` -> ... -> `model.context` + # -> `childcontext(model.context)` -> ... -> `leafcontext(context)` + return quote + context_new = setleafcontext( + context, setleafcontext(model.context, leafcontext(context)) + ) + model.f(model, varinfo, context_new, $(unwrap_args...)) + end end """ diff --git a/test/contexts.jl b/test/contexts.jl index 7793dfc74..80615cbd9 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -8,8 +8,15 @@ using DynamicPPL: NodeTrait, IsLeaf, IsParent, - PointwiseLikelihoodContext - + PointwiseLikelihoodContext, + contextual_isassumption, + ConditionContext, + hasvalue, + getvalue, + hasvalue_nested, + getvalue_nested + +# Dummy context to test nested behaviors. struct ParentContext{C<:AbstractContext} <: AbstractContext context::C end @@ -19,6 +26,54 @@ DynamicPPL.childcontext(context::ParentContext) = context.context DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") +# TODO: Should we maybe put this in DPPL itself? +function Base.iterate(context::AbstractContext) + if NodeTrait(context) isa IsLeaf + return nothing + end + + return context, context +end +function Base.iterate(_::AbstractContext, context::AbstractContext) + return _iterate(NodeTrait(context), context) +end +_iterate(::IsLeaf, context) = nothing +function _iterate(::IsParent, context) + child = childcontext(context) + return child, child +end + +Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() +Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() + +""" + remove_prefix(vn::VarName) + +Return `vn` but now with the prefix removed. +""" +function remove_prefix(vn::VarName) + return VarName{Symbol(split(string(vn), string(DynamicPPL.PREFIX_SEPARATOR))[end])}( + vn.indexing + ) +end + +""" + varnames(vn::VarName, val) + +Return iterator over all varnames that are represented by `vn` on `val`, +e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. +""" +varnames(vn::VarName, val::Real) = [vn] +function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) + return (VarName(vn, (vn.indexing..., Tuple(I))) for I in CartesianIndices(val)) +end +function varnames(vn::VarName, val::AbstractArray) + return Iterators.flatten( + varnames(VarName(vn, (vn.indexing..., Tuple(I))), val[I]) for + I in CartesianIndices(val) + ) +end + @testset "contexts.jl" begin child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] @@ -28,6 +83,10 @@ Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c MiniBatchContext(DefaultContext(), 0.0), PrefixContext{:x}(DefaultContext()), PointwiseLikelihoodContext(), + ConditionContext((x=1.0,)), + ConditionContext((x=1.0,), ParentContext(ConditionContext((y=2.0,)))), + ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))), + ConditionContext((x=[1.0, missing],)), ] contexts = vcat(child_contexts, parent_contexts) @@ -104,6 +163,69 @@ Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c end end + @testset "contextual_isassumption" begin + @testset "$context" for context in contexts + # Any `context` should return `true` by default. + @test contextual_isassumption(context, VarName{gensym(:x)}()) + + if any(Base.Fix2(isa, ConditionContext), context) + # We have a `ConditionContext` among us. + # Let's first extract the conditioned variables. + conditioned_values = DynamicPPL.conditioned(context) + + for (sym, val) in pairs(conditioned_values) + vn = VarName{sym}() + + # We need to drop the prefix of `var` since in `contextual_isassumption` + # it will be threaded through the `PrefixContext` before it reaches + # `ConditionContext` with the conditioned variable. + vn_without_prefix = remove_prefix(vn) + + # Let's check elementwise. + for vn_child in varnames(vn_without_prefix, val) + if DynamicPPL._getindex(val, vn_child.indexing) === missing + @test contextual_isassumption(context, vn_child) + else + @test !contextual_isassumption(context, vn_child) + end + end + end + end + end + end + + @testset "getvalue_nested & hasvalue_nested" begin + @testset "$context" for context in contexts + fake_vn = VarName{gensym(:x)}() + @test !hasvalue_nested(context, fake_vn) + @test_throws ErrorException getvalue_nested(context, fake_vn) + + if any(Base.Fix2(isa, ConditionContext), context) + # `ConditionContext` specific. + + # Let's first extract the conditioned variables. + conditioned_values = DynamicPPL.conditioned(context) + + for (sym, val) in pairs(conditioned_values) + vn = VarName{sym}() + + # We need to drop the prefix of `var` since in `contextual_isassumption` + # it will be threaded through the `PrefixContext` before it reaches + # `ConditionContext` with the conditioned variable. + vn_without_prefix = remove_prefix(vn) + + for vn_child in varnames(vn_without_prefix, val) + # `vn_child` should be in `context`. + @test hasvalue_nested(context, vn_child) + # Value should be the same as extracted above. + @test getvalue_nested(context, vn_child) === + DynamicPPL._getindex(val, vn_child.indexing) + end + end + end + end + end + @testset "PrefixContext" begin ctx = @inferred PrefixContext{:f}( PrefixContext{:e}( diff --git a/test/runtests.jl b/test/runtests.jl index d83be0eea..bb2ae579c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,7 +57,15 @@ include("test_util.jl") DocMeta.setdocmeta!( DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true ) - doctest(DynamicPPL; manual=false) + doctestfilters = [ + # Older versions will show "0 element Array" instead of "Type[]". + r"(Any\[\]|0-element Array{.+,[0-9]+})", + # Older versions will show "Array{...,1}" instead of "Vector{...}". + r"(Array{.+,\s?1}|Vector{.+})", + # Older versions will show "Array{...,2}" instead of "Matrix{...}". + r"(Array{.+,\s?2}|Matrix{.+})", + ] + doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters) end end diff --git a/test/turing/Project.toml b/test/turing/Project.toml index 8edbb9389..95212c061 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -DynamicPPL = "0.13" +DynamicPPL = "0.14" Turing = "0.17" julia = "1.3"