diff --git a/Project.toml b/Project.toml index e2c1dd29e..b7d5ab39e 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.12.3" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..2f933e4d3 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,6 +6,7 @@ using Distributions using Bijectors using AbstractMCMC: AbstractMCMC +using BangBang: BangBang using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules @@ -67,6 +68,7 @@ export AbstractVarInfo, vectorize, # Model Model, + ContextualModel, getmissings, getargnames, generated_quantities, @@ -81,6 +83,7 @@ export AbstractVarInfo, PriorContext, MiniBatchContext, PrefixContext, + ConditionContext, assume, dot_assume, observe, @@ -99,6 +102,8 @@ export AbstractVarInfo, logprior, logjoint, pointwise_loglikelihoods, + condition, + decondition, # Convenience macros @addlogprob!, @submodel @@ -116,11 +121,11 @@ abstract type AbstractContext end include("utils.jl") include("selector.jl") +include("contexts.jl") include("model.jl") include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") -include("contexts.jl") include("varinfo.jl") include("threadsafe.jl") include("context_implementations.jl") diff --git a/src/compiler.jl b/src/compiler.jl index 91fe78e2b..95e5f832d 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -22,8 +22,10 @@ function isassumption(expr::Union{Symbol,Expr}) let $vn = $(varname(expr)) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` - if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) + if !$(DynamicPPL.inargnames)($vn, __model__) || ( + __context__ isa $(DynamicPPL.ConditionContext) && + !$(Base.haskey)(__context__, $vn) + ) true else # Evaluate the LHS diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3d492f5b1..052b4ab9c 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -133,6 +133,28 @@ function tilde_assume!(context, right, vn, inds, vi) return value end +function tilde_assume!(context::ConditionContext, right, vn, inds, vi) + if haskey(context, vn) + # Extract value. + value = if inds isa Tuple{} + getfield(context.values, getsym(vn)) + else + _getindex(getfield(context.values, getsym(vn)), inds) + end + + # Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + tilde_observe!(context.context, right, value, vn, inds, vi) + else + value = tilde_assume!(context.context, right, vn, inds, vi) + end + + return value +end + # observe """ tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) @@ -217,6 +239,10 @@ function tilde_observe!(context, right, left, vi) return left end +function tilde_observe!(context::ConditionContext, right, left, vi) + return tilde_observe!(context.context, right, left, vi) +end + function assume(rng, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end @@ -419,6 +445,28 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) return value end +function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) + if vn in context + # Extract value. + value = if inds isa Tuple{} + getfield(context.values, sym) + else + _getindex(getfield(context.values, sym), inds) + end + + # Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + dot_tilde_observe!(context.context, right, value, vn, inds, vi) + else + value = dot_tilde_assume!(context.context, right, left, vn, inds, vi) + end + + return value +end + # `dot_assume` function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi @@ -637,6 +685,9 @@ function dot_tilde_observe!(context, right, left, vi) acclogp!(vi, logp) return left end +function dot_tilde_observe!(context::ConditionContext, right, left, vi) + return dot_tilde_observe!(context.context, right, left, vi) +end # Falls back to non-sampler definition. function dot_observe(::AbstractSampler, dist, value, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 05ad8df0d..e06d4c6eb 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -106,3 +106,66 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end + +struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext + values::Values + context::Ctx + + function ConditionContext{Values}( + values::Values, context::AbstractContext + ) where {names,Values<:NamedTuple{names}} + return new{names,typeof(values),typeof(context)}(values, context) + end +end + +@generated function drop_missings(nt::NamedTuple{names,values}) where {names,values} + names_expr = Expr(:tuple) + values_expr = Expr(:tuple) + + for (n, v) in zip(names, values.parameters) + if !(v <: Missing) + push!(names_expr.args, QuoteNode(n)) + push!(values_expr.args, :(nt.$n)) + end + end + + return :(NamedTuple{$names_expr}($values_expr)) +end + +function ConditionContext(context::ConditionContext, child_context::AbstractContext) + return ConditionContext(context.values, child_context) +end +function ConditionContext(values::NamedTuple) + return ConditionContext(values, DefaultContext()) +end + +function ConditionContext(values::NamedTuple, context::AbstractContext) + values_wo_missing = drop_missings(values) + return ConditionContext{typeof(values_wo_missing)}(values_wo_missing, context) +end + +# Try to avoid nested `ConditionContext`. +function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) where {Vars} + # Note that this potentially overrides values from `context`, thus giving + # precedence to the outmost `ConditionContext`. + return ConditionContext(merge(context.values, values), context.context) +end + +function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} + # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. + return sym in vars +end + +# TODO: Can we maybe do this in a better way? +# When no second argument is given, we remove _all_ conditioned variables. +# TODO: Should we remove this and just return `context.context`? +# That will work better if `Model` becomes like `ContextualModel`. +decondition(context::ConditionContext) = ConditionContext(NamedTuple(), context.context) +function decondition(context::ConditionContext, sym) + return ConditionContext(BangBang.delete!!(context.values, sym), context.context) +end +function decondition(context::ConditionContext, sym, syms...) + return decondition( + ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... + ) +end diff --git a/src/model.jl b/src/model.jl index 448ee1111..3dfe13bbf 100644 --- a/src/model.jl +++ b/src/model.jl @@ -32,45 +32,36 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: - AbstractProbabilisticProgram +struct Model{ + F, + argnames, + defaultnames, + Targs, + Tdefaults, + conditionnames, + Ctx<:ConditionContext{conditionnames}, +} <: AbstractProbabilisticProgram name::Symbol f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} + context::Ctx - @doc """ - Model{missings}(name::Symbol, f, args::NamedTuple, defaults::NamedTuple) - - Create a model of name `name` with evaluation function `f` and missing arguments - overwritten by `missings`. - """ - function Model{missings}( + function Model( name::Symbol, f::F, args::NamedTuple{argnames,Targs}, - defaults::NamedTuple{defaultnames,Tdefaults}, - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults}( - name, f, args, defaults + defaults::NamedTuple{defaultnames,Tdefaults}=NamedTuple(), + context::ConditionContext{conditionnames}=ConditionContext(args, DefaultContext()), + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,conditionnames} + return new{F,argnames,defaultnames,Targs,Tdefaults,conditionnames,typeof(context)}( + name, f, args, defaults, context ) end end -""" - Model(name::Symbol, f, args::NamedTuple[, defaults::NamedTuple = ()]) - -Create a model of name `name` with evaluation function `f` and missing arguments deduced -from `args`. - -Default arguments `defaults` are used internally when constructing instances of the same -model with different arguments. -""" -@generated function Model( - name::Symbol, f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple=NamedTuple() -) where {F,argnames,Targs} - missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing) - return :(Model{$missings}(name, f, args, defaults)) +function Model(m::Model, context::ConditionContext) + return Model(m.name, m.f, m.args, m.defaults, context) end """ @@ -93,10 +84,12 @@ end (model::Model)(context::AbstractContext) = model(VarInfo(), context) function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) + condition_context = ConditionContext(model.context.values, context) + if Threads.nthreads() == 1 - return evaluate_threadunsafe(model, varinfo, context) + return evaluate_threadunsafe(model, varinfo, condition_context) else - return evaluate_threadsafe(model, varinfo, context) + return evaluate_threadsafe(model, varinfo, condition_context) end end @@ -154,12 +147,42 @@ end Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ @generated function _evaluate( - model::Model{_F,argnames}, varinfo, context -) where {_F,argnames} - unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] - return :(model.f(model, varinfo, context, $(unwrap_args...))) + model::Model{_F,argnames,<:Any,<:Any,<:Any,conditionnames}, varinfo, context +) where {_F,argnames,conditionnames} + unwrap_args = [] + for var in argnames + # If `var` is not to be found in the `ConditionContext`, we fall back + # to the original args. + expr = if var in conditionnames + :($matchingvalue(context, varinfo, model.context.values.$var)) + else + :($matchingvalue(context, varinfo, model.args.$var)) + end + push!(unwrap_args, expr) + end + + return :(model.f( + model, varinfo, ConditionContext(model.context, context), $(unwrap_args...) + )) end +""" + condition(model::Model, values::NamedTuple) + condition(model::Model; values...) + +Condition `model` on the specifide `values`, i.e. make `model` treat `values` as observations. +""" +condition(model::Model, values) = Model(model, ConditionContext(values, model.context)) +condition(model::Model; values...) = condition(model, (; values...)) + +""" + decondition(model::Model) + decondition(model::Model, symbols...) + +Decondition `symbols` in `model`, i.e. make `model` treat them as random variables. +""" +decondition(model::Model, symbols...) = Model(model, decondition(model.context, symbols...)) + """ getargnames(model::Model)