diff --git a/Project.toml b/Project.toml index 5678c050a..ecf47b50f 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.13.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..bfcc49e6b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,6 +6,7 @@ using Distributions using Bijectors using AbstractMCMC: AbstractMCMC +using BangBang: BangBang using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules @@ -67,6 +68,7 @@ export AbstractVarInfo, vectorize, # Model Model, + ContextualModel, getmissings, getargnames, generated_quantities, @@ -81,6 +83,7 @@ export AbstractVarInfo, PriorContext, MiniBatchContext, PrefixContext, + ConditionContext, assume, dot_assume, observe, @@ -99,6 +102,8 @@ export AbstractVarInfo, logprior, logjoint, pointwise_loglikelihoods, + condition, + decondition, # Convenience macros @addlogprob!, @submodel @@ -129,5 +134,6 @@ include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") +include("contextual_model.jl") end # module diff --git a/src/compiler.jl b/src/compiler.jl index c70bbff1e..aa54ff9ba 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -20,19 +20,60 @@ function isassumption(expr::Union{Symbol,Expr}) return quote let $vn = $(varname(expr)) - # This branch should compile nicely in all cases except for partial missing data - # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` - if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) - true + if $(DynamicPPL.contextual_isassumption)(__context__, $vn) + # 
Considered an assumption by `__context__` which means either: + # 1. We hit the default implementation, e.g. using `DefaultContext`, + # which in turn means that we haven't considered if it's one of + # the model arguments, hence we need to check this. + # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, + # i.e. we're trying to condition one of the latent variables. + # In this case, the below will return `true` since the first branch + # will be hit. + # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, + # i.e. we're trying to override the value. This is currently NOT supported. + # TODO: Support by adding context to model, and use `model.args` + # as the default conditioning. Then we no longer need to check `inargnames` + # since it will all be handled by `contextual_isassumption`. + if !($(DynamicPPL.inargnames)($vn, __model__)) + true + else + $expr === missing + end else - # Evaluate the LHS - $(maybe_view(expr)) === missing + false end end end end +""" + contextual_isassumption(context, vn) + +Return `true` if `vn` is considered an assumption by `context`. + +The default implementation for `AbstractContext` always returns `true`. +""" +contextual_isassumption(context::AbstractContext, vn) = true +function contextual_isassumption(context::ConditionContext, vn) + # We might have nested contexts, e.g. `ConditionContext{.., <:PrefixContext{..., <:ConditionContext}}`. + + # We have either of the following cases: + # 1. `context` considers `vn` as an observation, i.e. it has `vn` as a key, + # which means we have a value to replace with and we don't need to recurse. + # 2. One of the descendant contexts considers it as an observation, i.e. + # `contextual_isassumption` evaluates to `false`. + # The below then evaluates to `!(false || true) === false`. + # 3. Neither `context` nor any of its descendants considers it an observation, + # in which case the below evaluates to `!(false || false) === true`.
+ return !(haskey(context, vn) || !contextual_isassumption(context.context, vn)) +end +function contextual_isassumption(context::PrefixContext, vn) + return contextual_isassumption(context.context, prefix(context, vn)) +end +function contextual_isassumption(context::MiniBatchContext, vn) + return contextual_isassumption(context.context, vn) +end + # failsafe: a literal is never an assumption isassumption(expr) = :(false) @@ -336,6 +377,11 @@ function generate_tilde(left, right) __varinfo__, ) else + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. + if !$(DynamicPPL.inargnames)($vn, __model__) + $left = $(DynamicPPL.getvalue)(__context__, $vn) + end + $(DynamicPPL.tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), @@ -380,6 +426,11 @@ function generate_dot_tilde(left, right) __varinfo__, ) else + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. + if !$(DynamicPPL.inargnames)($vn, __model__) + $left .= $(DynamicPPL.getvalue)(__context__, $vn) + end + $(DynamicPPL.dot_tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), diff --git a/src/context_implementations.jl b/src/context_implementations.jl index cd7a92535..6336cfd06 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,8 +14,17 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -_getindex(x, inds::Tuple) = _getindex(view(x, first(inds)...), Base.tail(inds)) +_getindex(x, inds::Tuple) = _getindex(Base.maybeview(x, first(inds)...), Base.tail(inds)) _getindex(x, inds::Tuple{}) = x +_getvalue(x, vn::VarName{sym}) where {sym} = _getindex(getproperty(x, sym), vn.indexing) +function _getvalue(x, vns::AbstractVector{<:VarName{sym}}) where {sym} + val = getproperty(x, sym) + + # This should work with both cartesian and linear indexing. 
+ return map(vns) do vn + _getindex(val, vn.indexing) + end +end # assume """ @@ -118,6 +127,15 @@ function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) end +# `ConditionContext` +function tilde_assume(context::ConditionContext, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) +end + +function tilde_assume(rng, context::ConditionContext, sampler, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) +end + """ tilde_assume!(context, right, vn, inds, vi) @@ -189,6 +207,14 @@ function tilde_observe(context::PrefixContext, right, left, vi) return tilde_observe(context.context, right, left, vi) end +# `ConditionContext` +function tilde_observe(context::ConditionContext, right, left, vname, vi) + return tilde_observe(context.context, right, left, vname, vi) +end +function tilde_observe(context::ConditionContext, right, left, vi) + return tilde_observe(context.context, right, left, vi) +end + """ tilde_observe!(context, right, left, vname, vinds, vi) @@ -402,6 +428,17 @@ function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, ) end +# `ConditionContext` +function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) +end + +function dot_tilde_assume( + rng, context::ConditionContext, sampler, right, left, vn, inds, vi +) + return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) +end + """ dot_tilde_assume!(context, right, left, vn, inds, vi) @@ -609,6 +646,11 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) return dot_tilde_observe(context.context, right, left, vi) end +# `ConditionContext` +function dot_tilde_observe(context::ConditionContext, right, left, vi) + return dot_tilde_observe(context.context, right, left, vi) +end + """
dot_tilde_observe!(context, right, left, vname, vinds, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 05ad8df0d..926dd0fe0 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -106,3 +106,73 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end + +struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext + values::Values + context::Ctx + + function ConditionContext{Values}( + values::Values, context::AbstractContext + ) where {names,Values<:NamedTuple{names}} + return new{names,typeof(values),typeof(context)}(values, context) + end +end + +function ConditionContext(context::ConditionContext, child_context::AbstractContext) + return ConditionContext(context.values, child_context) +end +function ConditionContext(values::NamedTuple) + return ConditionContext(values, DefaultContext()) +end +function ConditionContext(values::NamedTuple, context::AbstractContext) + return ConditionContext{typeof(values)}(values, context) +end + +# Try to avoid nested `ConditionContext`. +function ConditionContext( + values::NamedTuple{Names}, context::ConditionContext +) where {Names} + # Note that this potentially overrides values from `context`, thus giving + # precedence to the outmost `ConditionContext`. 
+ return ConditionContext(merge(context.values, values), context.context) +end + +function Base.show(io::IO, context::ConditionContext) + return print(io, "ConditionContext($(context.values), $(context.context))") +end + +function getvalue(context::ConditionContext, vn) + return if haskey(context, vn) + _getvalue(context.values, vn) + else + getvalue(context.context, vn) + end +end +getvalue(context::AbstractContext, vn) = getvalue(context.context, vn) +getvalue(context::PrefixContext, vn) = getvalue(context.context, prefix(context, vn)) + +function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} + # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. + return sym in vars +end + +function Base.haskey( + context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}} +) where {vars,sym} + # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. + return sym in vars +end + +# TODO: Can we maybe do this in a better way? +# When no second argument is given, we remove _all_ conditioned variables. +# TODO: Should we remove this and just return `context.context`? +# That will work better if `Model` becomes like `ContextualModel`. +decondition(context::ConditionContext) = ConditionContext(NamedTuple(), context.context) +function decondition(context::ConditionContext, sym) + return ConditionContext(BangBang.delete!!(context.values, sym), context.context) +end +function decondition(context::ConditionContext, sym, syms...) + return decondition( + ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... 
+ ) +end diff --git a/src/contextual_model.jl b/src/contextual_model.jl new file mode 100644 index 000000000..793491720 --- /dev/null +++ b/src/contextual_model.jl @@ -0,0 +1,29 @@ +struct ContextualModel{Ctx<:AbstractContext,M<:Model} <: AbstractModel + context::Ctx + model::M +end + +function contextualize(model::AbstractModel, context::AbstractContext) + return ContextualModel(context, model) +end + +# TODO: What do we do for other contexts? Could handle this in general if we had a +# notion of wrapper-, primitive-context, etc. +function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) + # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as + # `ConditionContext` child. + return _evaluate( + cmodel.model, varinfo, ConditionContext(cmodel.context.values, context) + ) +end + +condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) +condition(model::AbstractModel; values...) = condition(model, (; values...)) +function condition(cmodel::ContextualModel{<:ConditionContext}, values) + return contextualize(cmodel.model, ConditionContext(values, cmodel.context)) +end + +decondition(model::AbstractModel, args...) = model +function decondition(cmodel::ContextualModel{<:ConditionContext}, syms...) 
+ return contextualize(cmodel.model, decondition(cmodel.context, syms...)) +end diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 2901432d1..807dd0b8e 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -13,6 +13,14 @@ function PointwiseLikelihoodContext( ) end +function contextual_isassumption(context::PointwiseLikelihoodContext, vn) + return contextual_isassumption(context.context, vn) +end + +function contextual_isobservation(context::PointwiseLikelihoodContext, vn) + return contextual_isobservation(context.context, vn) +end + function Base.push!( context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, vn::VarName, diff --git a/src/model.jl b/src/model.jl index 448ee1111..7b5d8918a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,3 +1,5 @@ +abstract type AbstractModel <: AbstractProbabilisticProgram end + """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} name::Symbol @@ -32,8 +34,7 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: - AbstractProbabilisticProgram +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel name::Symbol f::F args::NamedTuple{argnames,Targs} @@ -82,7 +83,7 @@ Sample from the `model` using the `sampler` with random number generator `rng` a The method resets the log joint probability of `varinfo` and increases the evaluation number of `sampler`. 
""" -function (model::Model)( +function (model::AbstractModel)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), @@ -91,8 +92,8 @@ function (model::Model)( return model(varinfo, SamplingContext(rng, sampler, context)) end -(model::Model)(context::AbstractContext) = model(VarInfo(), context) -function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) +(model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context) +function (model::AbstractModel)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 return evaluate_threadunsafe(model, varinfo, context) else @@ -100,17 +101,17 @@ function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) end end -function (model::Model)(args...) +function (model::AbstractModel)(args...) return model(Random.GLOBAL_RNG, args...) end # without VarInfo -function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) +function (model::AbstractModel)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) return model(rng, VarInfo(), sampler, args...) end # without VarInfo and without AbstractSampler -function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) +function (model::AbstractModel)(rng::Random.AbstractRNG, context::AbstractContext) return model(rng, VarInfo(), SampleFromPrior(), context) end diff --git a/src/varinfo.jl b/src/varinfo.jl index 64c122dc2..ba6e157e9 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -124,7 +124,7 @@ end function VarInfo( rng::Random.AbstractRNG, - model::Model, + model::AbstractModel, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) @@ -132,10 +132,10 @@ function VarInfo( model(rng, varinfo, sampler, context) return TypedVarInfo(varinfo) end -VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) +VarInfo(model::AbstractModel, args...) 
= VarInfo(Random.GLOBAL_RNG, model, args...) # without AbstractSampler -function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) +function VarInfo(rng::Random.AbstractRNG, model::AbstractModel, context::AbstractContext) return VarInfo(rng, model, SampleFromPrior(), context) end