Skip to content
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.12.3"
[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Expand Down
7 changes: 6 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ using Distributions
using Bijectors

using AbstractMCMC: AbstractMCMC
using BangBang: BangBang
using ChainRulesCore: ChainRulesCore
using MacroTools: MacroTools
using ZygoteRules: ZygoteRules
Expand Down Expand Up @@ -67,6 +68,7 @@ export AbstractVarInfo,
vectorize,
# Model
Model,
ContextualModel,
getmissings,
getargnames,
generated_quantities,
Expand All @@ -81,6 +83,7 @@ export AbstractVarInfo,
PriorContext,
MiniBatchContext,
PrefixContext,
ConditionContext,
assume,
dot_assume,
observe,
Expand All @@ -99,6 +102,8 @@ export AbstractVarInfo,
logprior,
logjoint,
pointwise_loglikelihoods,
condition,
decondition,
# Convenience macros
@addlogprob!,
@submodel
Expand All @@ -116,11 +121,11 @@ abstract type AbstractContext end

include("utils.jl")
include("selector.jl")
include("contexts.jl")
include("model.jl")
include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varinfo.jl")
include("threadsafe.jl")
include("context_implementations.jl")
Expand Down
6 changes: 4 additions & 2 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ function isassumption(expr::Union{Symbol,Expr})
let $vn = $(varname(expr))
# This branch should compile nicely in all cases except for partial missing data
# For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}`
if !$(DynamicPPL.inargnames)($vn, __model__) ||
$(DynamicPPL.inmissings)($vn, __model__)
if !$(DynamicPPL.inargnames)($vn, __model__) || (
__context__ isa $(DynamicPPL.ConditionContext) &&
!$(Base.haskey)(__context__, $vn)
)
true
else
# Evaluate the LHS
Expand Down
51 changes: 51 additions & 0 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,28 @@ function tilde_assume!(context, right, vn, inds, vi)
return value
end

function tilde_assume!(context::ConditionContext, right, vn, inds, vi)
if haskey(context, vn)
# Extract value.
value = if inds isa Tuple{}
getfield(context.values, getsym(vn))
else
_getindex(getfield(context.values, getsym(vn)), inds)
end

# Should we even do this?
if haskey(vi, vn)
vi[vn] = value
end

tilde_observe!(context.context, right, value, vn, inds, vi)
else
value = tilde_assume!(context.context, right, vn, inds, vi)
end

return value
end

# observe
"""
tilde_observe(context::SamplingContext, right, left, vname, vinds, vi)
Expand Down Expand Up @@ -217,6 +239,10 @@ function tilde_observe!(context, right, left, vi)
return left
end

function tilde_observe!(context::ConditionContext, right, left, vi)
return tilde_observe!(context.context, right, left, vi)
end

function assume(rng, spl::Sampler, dist)
return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))")
end
Expand Down Expand Up @@ -419,6 +445,28 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi)
return value
end

function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi)
if vn in context
# Extract value.
value = if inds isa Tuple{}
getfield(context.values, sym)
else
_getindex(getfield(context.values, sym), inds)
end

# Should we even do this?
if haskey(vi, vn)
vi[vn] = value
end

dot_tilde_observe!(context.context, right, value, vn, inds, vi)
else
value = dot_tilde_assume!(context.context, right, left, vn, inds, vi)
end

return value
end

# `dot_assume`
function dot_assume(
dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi
Expand Down Expand Up @@ -637,6 +685,9 @@ function dot_tilde_observe!(context, right, left, vi)
acclogp!(vi, logp)
return left
end
function dot_tilde_observe!(context::ConditionContext, right, left, vi)
return dot_tilde_observe!(context.context, right, left, vi)
end

# Falls back to non-sampler definition.
function dot_observe(::AbstractSampler, dist, value, vi)
Expand Down
63 changes: 63 additions & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,66 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing)
end
end

struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext
values::Values
context::Ctx

function ConditionContext{Values}(
values::Values, context::AbstractContext
) where {names,Values<:NamedTuple{names}}
return new{names,typeof(values),typeof(context)}(values, context)
end
end

@generated function drop_missings(nt::NamedTuple{names,values}) where {names,values}
names_expr = Expr(:tuple)
values_expr = Expr(:tuple)

for (n, v) in zip(names, values.parameters)
if !(v <: Missing)
push!(names_expr.args, QuoteNode(n))
push!(values_expr.args, :(nt.$n))
end
end

return :(NamedTuple{$names_expr}($values_expr))
end

function ConditionContext(context::ConditionContext, child_context::AbstractContext)
return ConditionContext(context.values, child_context)
end
function ConditionContext(values::NamedTuple)
return ConditionContext(values, DefaultContext())
end

function ConditionContext(values::NamedTuple, context::AbstractContext)
values_wo_missing = drop_missings(values)
return ConditionContext{typeof(values_wo_missing)}(values_wo_missing, context)
end

# Try to avoid nested `ConditionContext`.
function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) where {Vars}
# Note that this potentially overrides values from `context`, thus giving
# precedence to the outmost `ConditionContext`.
return ConditionContext(merge(context.values, values), context.context)
end

function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym}
# TODO: Add possibility of indexed variables, e.g. `x[1]`, etc.
return sym in vars
end

# TODO: Can we maybe do this in a better way?
# When no second argument is given, we remove _all_ conditioned variables.
# TODO: Should we remove this and just return `context.context`?
# That will work better if `Model` becomes like `ContextualModel`.
decondition(context::ConditionContext) = ConditionContext(NamedTuple(), context.context)
function decondition(context::ConditionContext, sym)
return ConditionContext(BangBang.delete!!(context.values, sym), context.context)
end
function decondition(context::ConditionContext, sym, syms...)
return decondition(
ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms...
)
end
89 changes: 56 additions & 33 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,45 +32,36 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition
Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,))
```
"""
struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <:
AbstractProbabilisticProgram
struct Model{
F,
argnames,
defaultnames,
Targs,
Tdefaults,
conditionnames,
Ctx<:ConditionContext{conditionnames},
} <: AbstractProbabilisticProgram
name::Symbol
f::F
args::NamedTuple{argnames,Targs}
defaults::NamedTuple{defaultnames,Tdefaults}
context::Ctx

@doc """
Model{missings}(name::Symbol, f, args::NamedTuple, defaults::NamedTuple)

Create a model of name `name` with evaluation function `f` and missing arguments
overwritten by `missings`.
"""
function Model{missings}(
function Model(
name::Symbol,
f::F,
args::NamedTuple{argnames,Targs},
defaults::NamedTuple{defaultnames,Tdefaults},
) where {missings,F,argnames,Targs,defaultnames,Tdefaults}
return new{F,argnames,defaultnames,missings,Targs,Tdefaults}(
name, f, args, defaults
defaults::NamedTuple{defaultnames,Tdefaults}=NamedTuple(),
context::ConditionContext{conditionnames}=ConditionContext(args, DefaultContext()),
) where {missings,F,argnames,Targs,defaultnames,Tdefaults,conditionnames}
return new{F,argnames,defaultnames,Targs,Tdefaults,conditionnames,typeof(context)}(
name, f, args, defaults, context
)
end
end

"""
Model(name::Symbol, f, args::NamedTuple[, defaults::NamedTuple = ()])

Create a model of name `name` with evaluation function `f` and missing arguments deduced
from `args`.

Default arguments `defaults` are used internally when constructing instances of the same
model with different arguments.
"""
@generated function Model(
name::Symbol, f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple=NamedTuple()
) where {F,argnames,Targs}
missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing)
return :(Model{$missings}(name, f, args, defaults))
function Model(m::Model, context::ConditionContext)
return Model(m.name, m.f, m.args, m.defaults, context)
end

"""
Expand All @@ -93,10 +84,12 @@ end

(model::Model)(context::AbstractContext) = model(VarInfo(), context)
function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext)
condition_context = ConditionContext(model.context.values, context)

if Threads.nthreads() == 1
return evaluate_threadunsafe(model, varinfo, context)
return evaluate_threadunsafe(model, varinfo, condition_context)
else
return evaluate_threadsafe(model, varinfo, context)
return evaluate_threadsafe(model, varinfo, condition_context)
end
end

Expand Down Expand Up @@ -154,12 +147,42 @@ end
Evaluate the `model` with the arguments matching the given `context` and `varinfo` object.
"""
@generated function _evaluate(
model::Model{_F,argnames}, varinfo, context
) where {_F,argnames}
unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames]
return :(model.f(model, varinfo, context, $(unwrap_args...)))
model::Model{_F,argnames,<:Any,<:Any,<:Any,conditionnames}, varinfo, context
) where {_F,argnames,conditionnames}
unwrap_args = []
for var in argnames
# If `var` is not to be found in the `ConditionContext`, we fall back
# to the original args.
expr = if var in conditionnames
:($matchingvalue(context, varinfo, model.context.values.$var))
else
:($matchingvalue(context, varinfo, model.args.$var))
end
push!(unwrap_args, expr)
end

return :(model.f(
model, varinfo, ConditionContext(model.context, context), $(unwrap_args...)
))
end

"""
condition(model::Model, values::NamedTuple)
condition(model::Model; values...)

Condition `model` on the specifide `values`, i.e. make `model` treat `values` as observations.
"""
condition(model::Model, values) = Model(model, ConditionContext(values, model.context))
condition(model::Model; values...) = condition(model, (; values...))

"""
decondition(model::Model)
decondition(model::Model, symbols...)

Decondition `symbols` in `model`, i.e. make `model` treat them as random variables.
"""
decondition(model::Model, symbols...) = Model(model, decondition(model.context, symbols...))

"""
getargnames(model::Model)

Expand Down