From cd1c46de8bef1242f22ea9303ff1faca0bb972c1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:08:01 +0100 Subject: [PATCH 1/8] added ConditionContext and ContextualModel --- Project.toml | 1 + src/DynamicPPL.jl | 6 ++++ src/context_implementations.jl | 51 ++++++++++++++++++++++++++++++++++ src/contexts.jl | 42 ++++++++++++++++++++++++++++ src/model.jl | 14 ++++++---- src/varinfo.jl | 6 ++-- 6 files changed, 111 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 921dc054d..7cc3db31d 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.12.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..bfcc49e6b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,6 +6,7 @@ using Distributions using Bijectors using AbstractMCMC: AbstractMCMC +using BangBang: BangBang using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules @@ -67,6 +68,7 @@ export AbstractVarInfo, vectorize, # Model Model, + ContextualModel, getmissings, getargnames, generated_quantities, @@ -81,6 +83,7 @@ export AbstractVarInfo, PriorContext, MiniBatchContext, PrefixContext, + ConditionContext, assume, dot_assume, observe, @@ -99,6 +102,8 @@ export AbstractVarInfo, logprior, logjoint, pointwise_loglikelihoods, + condition, + decondition, # Convenience macros @addlogprob!, @submodel @@ -129,5 +134,6 @@ include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") +include("contextual_model.jl") end # module diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3d492f5b1..68165a64f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -133,6 +133,28 @@ function tilde_assume!(context, right, vn, inds, vi) return value end +function tilde_assume!(context::ConditionContext, right, vn, inds, vi) + value = if haskey(context, vn) + # Extract value. + if inds isa Tuple{} + getfield(context.values, getsym(vn)) + else + _getindex(getfield(context.values, getsym(vn)), inds) + end + + # Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + tilde_observe!(context.context, right, value, vn, inds, vi) + else + tilde_assume!(context.context, right, vn, inds, vi) + end + + return value +end + # observe """ tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) @@ -217,6 +239,10 @@ function tilde_observe!(context, right, left, vi) return left end +function tilde_observe!(context::ConditionContext, right, left, vi) + return tilde_observe!(context.context, right, left, vi) +end + function assume(rng, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end @@ -419,6 +445,28 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) return value end +function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) + value = if vn in context + # Extract value. + if inds isa Tuple{} + getfield(context.values, sym) + else + _getindex(getfield(context.values, sym), inds) + end + + # Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + dot_tilde_observe!(context.context, right, left, vn, inds, vi) + else + dot_tilde_assume!(context.context, right, left, vn, inds, vi) + end + + return value +end + # `dot_assume` function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi @@ -637,6 +685,9 @@ function dot_tilde_observe!(context, right, left, vi) acclogp!(vi, logp) return left end +function dot_tilde_observe!(context::ConditionContext, right, left, vi) + return dot_tilde_observe!(context.context, right, left, vi) +end # Falls back to non-sampler definition. function dot_observe(::AbstractSampler, dist, value, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 05ad8df0d..946c22092 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -106,3 +106,45 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end + + +struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext + values::Values + context::Ctx +end + +function ConditionContext(values::NamedTuple{Vars}) where {Vars} + return ConditionContext(values, DefaultContext()) +end + +function ConditionContext(values::NamedTuple{Vars}, context) where {Vars} + return ConditionContext{Vars,typeof(values),typeof(context)}(values, context) +end + +# Try to avoid nested `ConditionContext`. +function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) where {Vars} + # Note that this potentially overrides values from `context`, thus giving + # precedence to the outmost `ConditionContext`. + return ConditionContext(merge(context.values, values), context.context) +end + +function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars, sym} + # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. + return sym in vars +end + +# TODO: Can we maybe do this in a better way? +decondition(context::ConditionContext) = context +function decondition(context::ConditionContext, sym, syms...) + return decondition( + ConditionContext(BangBang.delete!!(context.values, sym), context.context), + syms... + ) +end +decondition(context::ConditionContext) = context +function decondition(context::ConditionContext, ::Val{sym}, syms...) where {sym} + return decondition( + ConditionContext(BangBang.delete!!(context.values, Val{sym}()), context.context), + syms... + ) +end diff --git a/src/model.jl b/src/model.jl index 9ec047a44..3ce707e2e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,3 +1,5 @@ +abstract type AbstractModel <: AbstractProbabilisticProgram end + """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} name::Symbol @@ -33,7 +35,7 @@ Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x ``` """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: - AbstractProbabilisticProgram + AbstractModel name::Symbol f::F args::NamedTuple{argnames,Targs} @@ -82,7 +84,7 @@ Sample from the `model` using the `sampler` with random number generator `rng` a The method resets the log joint probability of `varinfo` and increases the evaluation number of `sampler`. """ -function (model::Model)( +function (model::AbstractModel)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), @@ -91,7 +93,7 @@ function (model::Model)( return model(varinfo, SamplingContext(rng, sampler, context)) end -(model::Model)(context::AbstractContext) = model(VarInfo(), context) +(model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context) function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 return evaluate_threadunsafe(model, varinfo, context) @@ -100,17 +102,17 @@ function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) end end -function (model::Model)(args...) +function (model::AbstractModel)(args...) return model(Random.GLOBAL_RNG, args...) end # without VarInfo -function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) +function (model::AbstractModel)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) return model(rng, VarInfo(), sampler, args...) end # without VarInfo and without AbstractSampler -function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) +function (model::AbstractModel)(rng::Random.AbstractRNG, context::AbstractContext) return model(rng, VarInfo(), SampleFromPrior(), context) end diff --git a/src/varinfo.jl b/src/varinfo.jl index fe3262dd5..be6551939 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -124,7 +124,7 @@ end function VarInfo( rng::Random.AbstractRNG, - model::Model, + model::AbstractModel, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) @@ -132,10 +132,10 @@ function VarInfo( model(rng, varinfo, sampler, context) return TypedVarInfo(varinfo) end -VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) +VarInfo(model::AbstractModel, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) # without AbstractSampler -function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) +function VarInfo(rng::Random.AbstractRNG, model::AbstractModel, context::AbstractContext) return VarInfo(rng, model, SampleFromPrior(), context) end From b1106ee65b60a8eebffada99b5c770a869730169 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:10:12 +0100 Subject: [PATCH 2/8] removed redundant definition --- src/contexts.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 946c22092..5abae4168 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -141,7 +141,6 @@ function decondition(context::ConditionContext, sym, syms...) syms... ) end -decondition(context::ConditionContext) = context function decondition(context::ConditionContext, ::Val{sym}, syms...) where {sym} return decondition( ConditionContext(BangBang.delete!!(context.values, Val{sym}()), context.context), From 0f00771e8f391311cd6956ac33e3e31c770a2409 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:36:28 +0100 Subject: [PATCH 3/8] return condition model by default --- src/compiler.jl | 16 ++++++++++------ src/contexts.jl | 14 +++++++------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7466bc2c0..00be43551 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -23,7 +23,8 @@ function isassumption(expr::Union{Symbol,Expr}) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) + $(DynamicPPL.inmissings)($vn, __model__) || + (__context__ isa $(DynamicPPL.ConditionContext) && !$(Base.haskey)(__context__, $vn)) true else # Evaluate the LHS @@ -427,11 +428,14 @@ function build_output(modelinfo, linenumbernode) modeldef[:body] = MacroTools.@q begin $(linenumbernode) $evaluator = $(MacroTools.combinedef(evaluatordef)) - return $(DynamicPPL.Model)( - $(QuoteNode(modeldef[:name])), - $evaluator, - $allargs_namedtuple, - $defaults_namedtuple, + return $(DynamicPPL.condition)( + $(DynamicPPL.Model)( + $(QuoteNode(modeldef[:name])), + $evaluator, + $allargs_namedtuple, + $defaults_namedtuple, + ), + $allargs_namedtuple ) end diff --git a/src/contexts.jl b/src/contexts.jl index 5abae4168..0a0e977b9 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -134,16 +134,16 @@ function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {v end # TODO: Can we maybe do this in a better way? -decondition(context::ConditionContext) = context +# When no second argument is given, we remove _all_ conditioned variables. +# TODO: Should we remove this and just return `context.context`? +# That will work better if `Model` becomes like `ContextualModel`. +decondition(context::ConditionContext) = ConditionContext(NamedTuple(), context.context) +function decondition(context::ConditionContext, sym) + return ConditionContext(BangBang.delete!!(context.values, sym), context.context) +end function decondition(context::ConditionContext, sym, syms...) return decondition( ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... ) end -function decondition(context::ConditionContext, ::Val{sym}, syms...) where {sym} - return decondition( - ConditionContext(BangBang.delete!!(context.values, Val{sym}()), context.context), - syms... - ) -end From c754291411f3af1bef2468838493a750eaefa427 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:47:50 +0100 Subject: [PATCH 4/8] formatting --- src/compiler.jl | 9 ++++++--- src/contexts.jl | 6 ++---- src/model.jl | 3 +-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 00be43551..191152f4e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -23,8 +23,11 @@ function isassumption(expr::Union{Symbol,Expr}) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) || - (__context__ isa $(DynamicPPL.ConditionContext) && !$(Base.haskey)(__context__, $vn)) + $(DynamicPPL.inmissings)($vn, __model__) || + ( + __context__ isa $(DynamicPPL.ConditionContext) && + !$(Base.haskey)(__context__, $vn) + ) true else # Evaluate the LHS @@ -435,7 +438,7 @@ function build_output(modelinfo, linenumbernode) $allargs_namedtuple, $defaults_namedtuple, ), - $allargs_namedtuple + $allargs_namedtuple, ) end diff --git a/src/contexts.jl b/src/contexts.jl index 0a0e977b9..9831dde4a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -107,7 +107,6 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} end end - struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx @@ -128,7 +127,7 @@ function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) w return ConditionContext(merge(context.values, values), context.context) end -function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars, sym} +function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars end @@ -143,7 +142,6 @@ function decondition(context::ConditionContext, sym) end function decondition(context::ConditionContext, sym, syms...) return decondition( - ConditionContext(BangBang.delete!!(context.values, sym), context.context), - syms... + ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... ) end diff --git a/src/model.jl b/src/model.jl index 3ce707e2e..39a9f4fd3 100644 --- a/src/model.jl +++ b/src/model.jl @@ -34,8 +34,7 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: - AbstractModel +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel name::Symbol f::F args::NamedTuple{argnames,Targs} From 2d3f94cb6c137f6ce22683392cf4fcf0c579a47d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 22:15:23 +0100 Subject: [PATCH 5/8] forgot to include contextual model in previous commit --- src/contextual_model.jl | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 src/contextual_model.jl diff --git a/src/contextual_model.jl b/src/contextual_model.jl new file mode 100644 index 000000000..382c45034 --- /dev/null +++ b/src/contextual_model.jl @@ -0,0 +1,29 @@ +struct ContextualModel{Ctx<:AbstractContext,M<:Model} <: AbstractModel + context::Ctx + model::M +end + +function contextualize(model::AbstractModel, context::AbstractContext) + return ContextualModel(context, model) +end + +# TODO: What do we do for other contexts? Could handle this in general if we had a +# notion of wrapper-, primitive-context, etc. +function (cmodel::ContextualModel{<:ConditionContext})( + varinfo::AbstractVarInfo, context::AbstractContext +) + # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as + # `ConditionContext` child. + return cmodel.model(varinfo, ConditionContext(cmodel.context.values, context)) +end + +condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) +condition(model::AbstractModel; values...) = condition(model, (; values...)) +function condition(cmodel::ContextualModel{<:ConditionContext}, values) + return contextualize(cmodel.model, ConditionContext(values, cmodel.context)) +end + +decondition(model::AbstractModel, args...) = model +function decondition(cmodel::ContextualModel{<:ConditionContext}, syms...) + return contextualize(cmodel.model, decondition(cmodel.context, syms...)) +end From 27cd0d0915a7a89e979667f0f4ee66c511841c6f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 22:15:56 +0100 Subject: [PATCH 6/8] make Model contain ConditionContext rather than a special ContextualModel --- src/DynamicPPL.jl | 3 +- src/compiler.jl | 12 ++--- src/context_implementations.jl | 14 +++--- src/contexts.jl | 28 +++++++++-- src/contextual_model.jl | 29 ------------ src/model.jl | 86 +++++++++++++++++++++------------- 6 files changed, 91 insertions(+), 81 deletions(-) delete mode 100644 src/contextual_model.jl diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index bfcc49e6b..2f933e4d3 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -121,11 +121,11 @@ abstract type AbstractContext end include("utils.jl") include("selector.jl") +include("contexts.jl") include("model.jl") include("sampler.jl") include("varname.jl") include("distribution_wrappers.jl") -include("contexts.jl") include("varinfo.jl") include("threadsafe.jl") include("context_implementations.jl") @@ -134,6 +134,5 @@ include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") -include("contextual_model.jl") end # module diff --git a/src/compiler.jl b/src/compiler.jl index 191152f4e..2cb796263 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -23,7 +23,6 @@ function isassumption(expr::Union{Symbol,Expr}) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) || ( __context__ isa $(DynamicPPL.ConditionContext) && !$(Base.haskey)(__context__, $vn) @@ -431,14 +430,11 @@ function build_output(modelinfo, linenumbernode) modeldef[:body] = MacroTools.@q begin $(linenumbernode) $evaluator = $(MacroTools.combinedef(evaluatordef)) - return $(DynamicPPL.condition)( - $(DynamicPPL.Model)( - $(QuoteNode(modeldef[:name])), - $evaluator, - $allargs_namedtuple, - $defaults_namedtuple, - ), + return $(DynamicPPL.Model)( + $(QuoteNode(modeldef[:name])), + $evaluator, $allargs_namedtuple, + $defaults_namedtuple, ) end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 68165a64f..052b4ab9c 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -134,9 +134,9 @@ function tilde_assume!(context, right, vn, inds, vi) end function tilde_assume!(context::ConditionContext, right, vn, inds, vi) - value = if haskey(context, vn) + if haskey(context, vn) # Extract value. - if inds isa Tuple{} + value = if inds isa Tuple{} getfield(context.values, getsym(vn)) else _getindex(getfield(context.values, getsym(vn)), inds) @@ -149,7 +149,7 @@ function tilde_assume!(context::ConditionContext, right, vn, inds, vi) tilde_observe!(context.context, right, value, vn, inds, vi) else - tilde_assume!(context.context, right, vn, inds, vi) + value = tilde_assume!(context.context, right, vn, inds, vi) end return value @@ -446,9 +446,9 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) end function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) - value = if vn in context + if vn in context # Extract value. - if inds isa Tuple{} + value = if inds isa Tuple{} getfield(context.values, sym) else _getindex(getfield(context.values, sym), inds) @@ -459,9 +459,9 @@ function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) vi[vn] = value end - dot_tilde_observe!(context.context, right, left, vn, inds, vi) + dot_tilde_observe!(context.context, right, value, vn, inds, vi) else - dot_tilde_assume!(context.context, right, left, vn, inds, vi) + value = dot_tilde_assume!(context.context, right, left, vn, inds, vi) end return value diff --git a/src/contexts.jl b/src/contexts.jl index 9831dde4a..136967f0f 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -110,14 +110,36 @@ end struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx + + function ConditionContext{Values}(values::Values, context::AbstractContext) where {names, Values<:NamedTuple{names}} + return new{names, typeof(values), typeof(context)}(values, context) + end +end + +@generated function drop_missings(nt::NamedTuple{names,values}) where {names,values} + names_expr = Expr(:tuple) + values_expr = Expr(:tuple) + + for (n, v) in zip(names, values.parameters) + if !(v <: Missing) + push!(names_expr.args, QuoteNode(n)) + push!(values_expr.args, :(nt.$n)) + end + end + + return :(NamedTuple{$names_expr}($values_expr)) end -function ConditionContext(values::NamedTuple{Vars}) where {Vars} +function ConditionContext(context::ConditionContext, child_context::AbstractContext) + return ConditionContext(context.values, child_context) +end +function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end -function ConditionContext(values::NamedTuple{Vars}, context) where {Vars} - return ConditionContext{Vars,typeof(values),typeof(context)}(values, context) +function ConditionContext(values::NamedTuple, context::AbstractContext) + values_wo_missing = drop_missings(values) + return ConditionContext{typeof(values_wo_missing)}(values_wo_missing, context) end # Try to avoid nested `ConditionContext`. diff --git a/src/contextual_model.jl b/src/contextual_model.jl deleted file mode 100644 index 382c45034..000000000 --- a/src/contextual_model.jl +++ /dev/null @@ -1,29 +0,0 @@ -struct ContextualModel{Ctx<:AbstractContext,M<:Model} <: AbstractModel - context::Ctx - model::M -end - -function contextualize(model::AbstractModel, context::AbstractContext) - return ContextualModel(context, model) -end - -# TODO: What do we do for other contexts? Could handle this in general if we had a -# notion of wrapper-, primitive-context, etc. -function (cmodel::ContextualModel{<:ConditionContext})( - varinfo::AbstractVarInfo, context::AbstractContext -) - # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as - # `ConditionContext` child. - return cmodel.model(varinfo, ConditionContext(cmodel.context.values, context)) -end - -condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) -condition(model::AbstractModel; values...) = condition(model, (; values...)) -function condition(cmodel::ContextualModel{<:ConditionContext}, values) - return contextualize(cmodel.model, ConditionContext(values, cmodel.context)) -end - -decondition(model::AbstractModel, args...) = model -function decondition(cmodel::ContextualModel{<:ConditionContext}, syms...) - return contextualize(cmodel.model, decondition(cmodel.context, syms...)) -end diff --git a/src/model.jl b/src/model.jl index 39a9f4fd3..221bdab5c 100644 --- a/src/model.jl +++ b/src/model.jl @@ -34,44 +34,36 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel +struct Model{ + F, + argnames, + defaultnames, + Targs, + Tdefaults, + conditionnames, + Ctx<:ConditionContext{conditionnames}, +} <: AbstractModel name::Symbol f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} + context::Ctx - """ - Model{missings}(name::Symbol, f, args::NamedTuple, defaults::NamedTuple) - - Create a model of name `name` with evaluation function `f` and missing arguments - overwritten by `missings`. - """ - function Model{missings}( + function Model( name::Symbol, f::F, args::NamedTuple{argnames,Targs}, - defaults::NamedTuple{defaultnames,Tdefaults}, - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults}( - name, f, args, defaults + defaults::NamedTuple{defaultnames,Tdefaults}=NamedTuple(), + context::ConditionContext{conditionnames}=ConditionContext(args, DefaultContext()) + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,conditionnames} + return new{F,argnames,defaultnames,Targs,Tdefaults,conditionnames,typeof(context)}( + name, f, args, defaults, context ) end end -""" - Model(name::Symbol, f, args::NamedTuple[, defaults::NamedTuple = ()]) - -Create a model of name `name` with evaluation function `f` and missing arguments deduced -from `args`. - -Default arguments `defaults` are used internally when constructing instances of the same -model with different arguments. -""" -@generated function Model( - name::Symbol, f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple=NamedTuple() -) where {F,argnames,Targs} - missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing) - return :(Model{$missings}(name, f, args, defaults)) +function Model(m::Model, context::ConditionContext) + return Model(m.name, m.f, m.args, m.defaults, context) end """ @@ -94,10 +86,12 @@ end (model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context) function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) + condition_context = ConditionContext(model.context.values, context) + if Threads.nthreads() == 1 - return evaluate_threadunsafe(model, varinfo, context) + return evaluate_threadunsafe(model, varinfo, condition_context) else - return evaluate_threadsafe(model, varinfo, context) + return evaluate_threadsafe(model, varinfo, condition_context) end end @@ -155,12 +149,40 @@ end Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ @generated function _evaluate( - model::Model{_F,argnames}, varinfo, context -) where {_F,argnames} - unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] - return :(model.f(model, varinfo, context, $(unwrap_args...))) + model::Model{_F,argnames,<:Any,<:Any,<:Any,conditionnames}, varinfo, context +) where {_F,argnames,conditionnames} + unwrap_args = [] + for var in argnames + # If `var` is not to be found in the `ConditionContext`, we fall back + # to the original args. + expr = if var in conditionnames + :($matchingvalue(context, varinfo, model.context.values.$var)) + else + :($matchingvalue(context, varinfo, model.args.$var)) + end + push!(unwrap_args, expr) + end + + return :(model.f(model, varinfo, ConditionContext(model.context, context), $(unwrap_args...))) end +""" + condition(model::Model, values::NamedTuple) + condition(model::Model; values...) + +Condition `model` on the specifide `values`, i.e. make `model` treat `values` as observations. +""" +condition(model::Model, values) = Model(model, ConditionContext(values, model.context)) +condition(model::Model; values...) = condition(model, (; values...)) + +""" + decondition(model::Model) + decondition(model::Model, symbols...) + +Decondition `symbols` in `model`, i.e. make `model` treat them as random variables. +""" +decondition(model::Model, symbols...) = Model(model, decondition(model.context, symbols...)) + """ getargnames(model::Model) From 5105d7e06dee1ae127f485fb9d0ecdfcfc211e5c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 22:42:41 +0100 Subject: [PATCH 7/8] removed redundant AbstractModel --- src/model.jl | 14 ++++++-------- src/varinfo.jl | 6 +++--- 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/src/model.jl b/src/model.jl index 221bdab5c..e936e396c 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,5 +1,3 @@ -abstract type AbstractModel <: AbstractProbabilisticProgram end - """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} name::Symbol @@ -42,7 +40,7 @@ struct Model{ Tdefaults, conditionnames, Ctx<:ConditionContext{conditionnames}, -} <: AbstractModel +} <: AbstractProbabilisticProgram name::Symbol f::F args::NamedTuple{argnames,Targs} @@ -75,7 +73,7 @@ Sample from the `model` using the `sampler` with random number generator `rng` a The method resets the log joint probability of `varinfo` and increases the evaluation number of `sampler`. """ -function (model::AbstractModel)( +function (model::Model)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), @@ -84,7 +82,7 @@ function (model::AbstractModel)( return model(varinfo, SamplingContext(rng, sampler, context)) end -(model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context) +(model::Model)(context::AbstractContext) = model(VarInfo(), context) function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) condition_context = ConditionContext(model.context.values, context) @@ -95,17 +93,17 @@ function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) end end -function (model::AbstractModel)(args...) +function (model::Model)(args...) return model(Random.GLOBAL_RNG, args...) end # without VarInfo -function (model::AbstractModel)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) +function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) return model(rng, VarInfo(), sampler, args...) end # without VarInfo and without AbstractSampler -function (model::AbstractModel)(rng::Random.AbstractRNG, context::AbstractContext) +function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) return model(rng, VarInfo(), SampleFromPrior(), context) end diff --git a/src/varinfo.jl b/src/varinfo.jl index d3fe4d21a..a226506f4 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -124,7 +124,7 @@ end function VarInfo( rng::Random.AbstractRNG, - model::AbstractModel, + model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) @@ -132,10 +132,10 @@ function VarInfo( model(rng, varinfo, sampler, context) return TypedVarInfo(varinfo) end -VarInfo(model::AbstractModel, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) +VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) # without AbstractSampler -function VarInfo(rng::Random.AbstractRNG, model::AbstractModel, context::AbstractContext) +function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) return VarInfo(rng, model, SampleFromPrior(), context) end From 44f45c94f4c53c38412c625626266ffc31a1c037 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 23:07:33 +0100 Subject: [PATCH 8/8] formatting --- src/compiler.jl | 9 ++++----- src/contexts.jl | 6 ++++-- src/model.jl | 6 ++++-- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 2cb796263..3d405abd4 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -22,11 +22,10 @@ function isassumption(expr::Union{Symbol,Expr}) let $vn = $(varname(expr)) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` - if !$(DynamicPPL.inargnames)($vn, __model__) || - ( - __context__ isa $(DynamicPPL.ConditionContext) && - !$(Base.haskey)(__context__, $vn) - ) + if !$(DynamicPPL.inargnames)($vn, __model__) || ( + __context__ isa $(DynamicPPL.ConditionContext) && + !$(Base.haskey)(__context__, $vn) + ) true else # Evaluate the LHS diff --git a/src/contexts.jl b/src/contexts.jl index 136967f0f..e06d4c6eb 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -111,8 +111,10 @@ struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx - function ConditionContext{Values}(values::Values, context::AbstractContext) where {names, Values<:NamedTuple{names}} - return new{names, typeof(values), typeof(context)}(values, context) + function ConditionContext{Values}( + values::Values, context::AbstractContext + ) where {names,Values<:NamedTuple{names}} + return new{names,typeof(values),typeof(context)}(values, context) end end diff --git a/src/model.jl b/src/model.jl index e936e396c..3dfe13bbf 100644 --- a/src/model.jl +++ b/src/model.jl @@ -52,7 +52,7 @@ struct Model{ f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}=NamedTuple(), - context::ConditionContext{conditionnames}=ConditionContext(args, DefaultContext()) + context::ConditionContext{conditionnames}=ConditionContext(args, DefaultContext()), ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,conditionnames} return new{F,argnames,defaultnames,Targs,Tdefaults,conditionnames,typeof(context)}( name, f, args, defaults, context @@ -161,7 +161,9 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf push!(unwrap_args, expr) end - return :(model.f(model, varinfo, ConditionContext(model.context, context), $(unwrap_args...))) + return :(model.f( + model, varinfo, ConditionContext(model.context, context), $(unwrap_args...) + )) end """