From cd1c46de8bef1242f22ea9303ff1faca0bb972c1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:08:01 +0100 Subject: [PATCH 01/26] added ConditionContext and ContextualModel --- Project.toml | 1 + src/DynamicPPL.jl | 6 ++++ src/context_implementations.jl | 51 ++++++++++++++++++++++++++++++++++ src/contexts.jl | 42 ++++++++++++++++++++++++++++ src/model.jl | 14 ++++++---- src/varinfo.jl | 6 ++-- 6 files changed, 111 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 921dc054d..7cc3db31d 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.12.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..bfcc49e6b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,6 +6,7 @@ using Distributions using Bijectors using AbstractMCMC: AbstractMCMC +using BangBang: BangBang using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules @@ -67,6 +68,7 @@ export AbstractVarInfo, vectorize, # Model Model, + ContextualModel, getmissings, getargnames, generated_quantities, @@ -81,6 +83,7 @@ export AbstractVarInfo, PriorContext, MiniBatchContext, PrefixContext, + ConditionContext, assume, dot_assume, observe, @@ -99,6 +102,8 @@ export AbstractVarInfo, logprior, logjoint, pointwise_loglikelihoods, + condition, + decondition, # Convenience macros @addlogprob!, @submodel @@ -129,5 +134,6 @@ include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") +include("contextual_model.jl") end # module diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3d492f5b1..68165a64f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -133,6 +133,28 @@ function tilde_assume!(context, right, vn, inds, vi) return value end +function tilde_assume!(context::ConditionContext, right, vn, inds, vi) + value = if haskey(context, vn) + # Extract value. + if inds isa Tuple{} + getfield(context.values, getsym(vn)) + else + _getindex(getfield(context.values, getsym(vn)), inds) + end + + # Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + tilde_observe!(context.context, right, value, vn, inds, vi) + else + tilde_assume!(context.context, right, vn, inds, vi) + end + + return value +end + # observe """ tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) @@ -217,6 +239,10 @@ function tilde_observe!(context, right, left, vi) return left end +function tilde_observe!(context::ConditionContext, right, left, vi) + return tilde_observe!(context.context, right, left, vi) +end + function assume(rng, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end @@ -419,6 +445,28 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) return value end +function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) + value = if vn in context + # Extract value. + if inds isa Tuple{} + getfield(context.values, sym) + else + _getindex(getfield(context.values, sym), inds) + end + + # Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + dot_tilde_observe!(context.context, right, left, vn, inds, vi) + else + dot_tilde_assume!(context.context, right, left, vn, inds, vi) + end + + return value +end + # `dot_assume` function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi @@ -637,6 +685,9 @@ function dot_tilde_observe!(context, right, left, vi) acclogp!(vi, logp) return left end +function dot_tilde_observe!(context::ConditionContext, right, left, vi) + return dot_tilde_observe!(context.context, right, left, vi) +end # Falls back to non-sampler definition. function dot_observe(::AbstractSampler, dist, value, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 05ad8df0d..946c22092 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -106,3 +106,45 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end + + +struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext + values::Values + context::Ctx +end + +function ConditionContext(values::NamedTuple{Vars}) where {Vars} + return ConditionContext(values, DefaultContext()) +end + +function ConditionContext(values::NamedTuple{Vars}, context) where {Vars} + return ConditionContext{Vars,typeof(values),typeof(context)}(values, context) +end + +# Try to avoid nested `ConditionContext`. +function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) where {Vars} + # Note that this potentially overrides values from `context`, thus giving + # precedence to the outmost `ConditionContext`. + return ConditionContext(merge(context.values, values), context.context) +end + +function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars, sym} + # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. + return sym in vars +end + +# TODO: Can we maybe do this in a better way? +decondition(context::ConditionContext) = context +function decondition(context::ConditionContext, sym, syms...) + return decondition( + ConditionContext(BangBang.delete!!(context.values, sym), context.context), + syms... + ) +end +decondition(context::ConditionContext) = context +function decondition(context::ConditionContext, ::Val{sym}, syms...) where {sym} + return decondition( + ConditionContext(BangBang.delete!!(context.values, Val{sym}()), context.context), + syms... + ) +end diff --git a/src/model.jl b/src/model.jl index 9ec047a44..3ce707e2e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,3 +1,5 @@ +abstract type AbstractModel <: AbstractProbabilisticProgram end + """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} name::Symbol @@ -33,7 +35,7 @@ Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x ``` """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: - AbstractProbabilisticProgram + AbstractModel name::Symbol f::F args::NamedTuple{argnames,Targs} @@ -82,7 +84,7 @@ Sample from the `model` using the `sampler` with random number generator `rng` a The method resets the log joint probability of `varinfo` and increases the evaluation number of `sampler`. """ -function (model::Model)( +function (model::AbstractModel)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), @@ -91,7 +93,7 @@ function (model::Model)( return model(varinfo, SamplingContext(rng, sampler, context)) end -(model::Model)(context::AbstractContext) = model(VarInfo(), context) +(model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context) function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 return evaluate_threadunsafe(model, varinfo, context) @@ -100,17 +102,17 @@ function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) end end -function (model::Model)(args...) +function (model::AbstractModel)(args...) return model(Random.GLOBAL_RNG, args...) end # without VarInfo -function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) +function (model::AbstractModel)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) return model(rng, VarInfo(), sampler, args...) end # without VarInfo and without AbstractSampler -function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) +function (model::AbstractModel)(rng::Random.AbstractRNG, context::AbstractContext) return model(rng, VarInfo(), SampleFromPrior(), context) end diff --git a/src/varinfo.jl b/src/varinfo.jl index fe3262dd5..be6551939 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -124,7 +124,7 @@ end function VarInfo( rng::Random.AbstractRNG, - model::Model, + model::AbstractModel, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) @@ -132,10 +132,10 @@ function VarInfo( model(rng, varinfo, sampler, context) return TypedVarInfo(varinfo) end -VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) +VarInfo(model::AbstractModel, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) # without AbstractSampler -function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) +function VarInfo(rng::Random.AbstractRNG, model::AbstractModel, context::AbstractContext) return VarInfo(rng, model, SampleFromPrior(), context) end From b1106ee65b60a8eebffada99b5c770a869730169 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:10:12 +0100 Subject: [PATCH 02/26] removed redundant definition --- src/contexts.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 946c22092..5abae4168 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -141,7 +141,6 @@ function decondition(context::ConditionContext, sym, syms...) syms... ) end -decondition(context::ConditionContext) = context function decondition(context::ConditionContext, ::Val{sym}, syms...) where {sym} return decondition( ConditionContext(BangBang.delete!!(context.values, Val{sym}()), context.context), From 0f00771e8f391311cd6956ac33e3e31c770a2409 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:36:28 +0100 Subject: [PATCH 03/26] return condition model by default --- src/compiler.jl | 16 ++++++++++------ src/contexts.jl | 14 +++++++------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7466bc2c0..00be43551 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -23,7 +23,8 @@ function isassumption(expr::Union{Symbol,Expr}) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) + $(DynamicPPL.inmissings)($vn, __model__) || + (__context__ isa $(DynamicPPL.ConditionContext) && !$(Base.haskey)(__context__, $vn)) true else # Evaluate the LHS @@ -427,11 +428,14 @@ function build_output(modelinfo, linenumbernode) modeldef[:body] = MacroTools.@q begin $(linenumbernode) $evaluator = $(MacroTools.combinedef(evaluatordef)) - return $(DynamicPPL.Model)( - $(QuoteNode(modeldef[:name])), - $evaluator, - $allargs_namedtuple, - $defaults_namedtuple, + return $(DynamicPPL.condition)( + $(DynamicPPL.Model)( + $(QuoteNode(modeldef[:name])), + $evaluator, + $allargs_namedtuple, + $defaults_namedtuple, + ), + $allargs_namedtuple ) end diff --git a/src/contexts.jl b/src/contexts.jl index 5abae4168..0a0e977b9 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -134,16 +134,16 @@ function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {v end # TODO: Can we maybe do this in a better way? -decondition(context::ConditionContext) = context +# When no second argument is given, we remove _all_ conditioned variables. +# TODO: Should we remove this and just return `context.context`? +# That will work better if `Model` becomes like `ContextualModel`. +decondition(context::ConditionContext) = ConditionContext(NamedTuple(), context.context) +function decondition(context::ConditionContext, sym) + return ConditionContext(BangBang.delete!!(context.values, sym), context.context) +end function decondition(context::ConditionContext, sym, syms...) return decondition( ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... ) end -function decondition(context::ConditionContext, ::Val{sym}, syms...) where {sym} - return decondition( - ConditionContext(BangBang.delete!!(context.values, Val{sym}()), context.context), - syms... - ) -end From c754291411f3af1bef2468838493a750eaefa427 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:47:50 +0100 Subject: [PATCH 04/26] formatting --- src/compiler.jl | 9 ++++++--- src/contexts.jl | 6 ++---- src/model.jl | 3 +-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 00be43551..191152f4e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -23,8 +23,11 @@ function isassumption(expr::Union{Symbol,Expr}) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) || - (__context__ isa $(DynamicPPL.ConditionContext) && !$(Base.haskey)(__context__, $vn)) + $(DynamicPPL.inmissings)($vn, __model__) || + ( + __context__ isa $(DynamicPPL.ConditionContext) && + !$(Base.haskey)(__context__, $vn) + ) true else # Evaluate the LHS @@ -435,7 +438,7 @@ function build_output(modelinfo, linenumbernode) $allargs_namedtuple, $defaults_namedtuple, ), - $allargs_namedtuple + $allargs_namedtuple, ) end diff --git a/src/contexts.jl b/src/contexts.jl index 0a0e977b9..9831dde4a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -107,7 +107,6 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} end end - struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx @@ -128,7 +127,7 @@ function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) w return ConditionContext(merge(context.values, values), context.context) end -function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars, sym} +function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars end @@ -143,7 +142,6 @@ function decondition(context::ConditionContext, sym) end function decondition(context::ConditionContext, sym, syms...) return decondition( - ConditionContext(BangBang.delete!!(context.values, sym), context.context), - syms... + ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... ) end diff --git a/src/model.jl b/src/model.jl index 3ce707e2e..39a9f4fd3 100644 --- a/src/model.jl +++ b/src/model.jl @@ -34,8 +34,7 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: - AbstractModel +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel name::Symbol f::F args::NamedTuple{argnames,Targs} From 2d3f94cb6c137f6ce22683392cf4fcf0c579a47d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 22:15:23 +0100 Subject: [PATCH 05/26] forgot to include contextual model in previous commit --- src/contextual_model.jl | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 src/contextual_model.jl diff --git a/src/contextual_model.jl b/src/contextual_model.jl new file mode 100644 index 000000000..382c45034 --- /dev/null +++ b/src/contextual_model.jl @@ -0,0 +1,29 @@ +struct ContextualModel{Ctx<:AbstractContext,M<:Model} <: AbstractModel + context::Ctx + model::M +end + +function contextualize(model::AbstractModel, context::AbstractContext) + return ContextualModel(context, model) +end + +# TODO: What do we do for other contexts? Could handle this in general if we had a +# notion of wrapper-, primitive-context, etc. +function (cmodel::ContextualModel{<:ConditionContext})( + varinfo::AbstractVarInfo, context::AbstractContext +) + # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as + # `ConditionContext` child. + return cmodel.model(varinfo, ConditionContext(cmodel.context.values, context)) +end + +condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) +condition(model::AbstractModel; values...) = condition(model, (; values...)) +function condition(cmodel::ContextualModel{<:ConditionContext}, values) + return contextualize(cmodel.model, ConditionContext(values, cmodel.context)) +end + +decondition(model::AbstractModel, args...) = model +function decondition(cmodel::ContextualModel{<:ConditionContext}, syms...) + return contextualize(cmodel.model, decondition(cmodel.context, syms...)) +end From 3e5a79f7e1b14ba610f88377ef9b9f177777becb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 22:45:45 +0100 Subject: [PATCH 06/26] fixed typos --- src/context_implementations.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 68165a64f..052b4ab9c 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -134,9 +134,9 @@ function tilde_assume!(context, right, vn, inds, vi) end function tilde_assume!(context::ConditionContext, right, vn, inds, vi) - value = if haskey(context, vn) + if haskey(context, vn) # Extract value. - if inds isa Tuple{} + value = if inds isa Tuple{} getfield(context.values, getsym(vn)) else _getindex(getfield(context.values, getsym(vn)), inds) @@ -149,7 +149,7 @@ function tilde_assume!(context::ConditionContext, right, vn, inds, vi) tilde_observe!(context.context, right, value, vn, inds, vi) else - tilde_assume!(context.context, right, vn, inds, vi) + value = tilde_assume!(context.context, right, vn, inds, vi) end return value @@ -446,9 +446,9 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) end function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) - value = if vn in context + if vn in context # Extract value. - if inds isa Tuple{} + value = if inds isa Tuple{} getfield(context.values, sym) else _getindex(getfield(context.values, sym), inds) @@ -459,9 +459,9 @@ function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) vi[vn] = value end - dot_tilde_observe!(context.context, right, left, vn, inds, vi) + dot_tilde_observe!(context.context, right, value, vn, inds, vi) else - dot_tilde_assume!(context.context, right, left, vn, inds, vi) + value = dot_tilde_assume!(context.context, right, left, vn, inds, vi) end return value From d4e42387888a4d7c9b2f9e62541a8de70fbb6f0f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 23:06:59 +0100 Subject: [PATCH 07/26] added some niceties to ConditionContext --- src/contexts.jl | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 9831dde4a..e06d4c6eb 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -110,14 +110,38 @@ end struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx + + function ConditionContext{Values}( + values::Values, context::AbstractContext + ) where {names,Values<:NamedTuple{names}} + return new{names,typeof(values),typeof(context)}(values, context) + end +end + +@generated function drop_missings(nt::NamedTuple{names,values}) where {names,values} + names_expr = Expr(:tuple) + values_expr = Expr(:tuple) + + for (n, v) in zip(names, values.parameters) + if !(v <: Missing) + push!(names_expr.args, QuoteNode(n)) + push!(values_expr.args, :(nt.$n)) + end + end + + return :(NamedTuple{$names_expr}($values_expr)) end -function ConditionContext(values::NamedTuple{Vars}) where {Vars} +function ConditionContext(context::ConditionContext, child_context::AbstractContext) + return ConditionContext(context.values, child_context) +end +function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end -function ConditionContext(values::NamedTuple{Vars}, context) where {Vars} - return ConditionContext{Vars,typeof(values),typeof(context)}(values, context) +function ConditionContext(values::NamedTuple, context::AbstractContext) + values_wo_missing = drop_missings(values) + return ConditionContext{typeof(values_wo_missing)}(values_wo_missing, context) end # Try to avoid nested `ConditionContext`. From 85a47eb098498df027ebaf2efe124517507933cb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 16 Jul 2021 01:49:52 +0100 Subject: [PATCH 08/26] added support for vectors of VarName --- src/context_implementations.jl | 2 +- src/contexts.jl | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 052b4ab9c..813d40875 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -446,7 +446,7 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) end function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) - if vn in context + if haskey(context, vn) # Extract value. value = if inds isa Tuple{} getfield(context.values, sym) diff --git a/src/contexts.jl b/src/contexts.jl index e06d4c6eb..500d3d60b 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -156,6 +156,11 @@ function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {v return sym in vars end +function Base.haskey(context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}}) where {vars,sym} + # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. + return sym in vars +end + # TODO: Can we maybe do this in a better way? # When no second argument is given, we remove _all_ conditioned variables. # TODO: Should we remove this and just return `context.context`? From c7dae8daa1a66712328161be7a8b4fe94575deae Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 16 Jul 2021 01:52:17 +0100 Subject: [PATCH 09/26] Update src/contexts.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/contexts.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 500d3d60b..d0cd81287 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -156,7 +156,9 @@ function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {v return sym in vars end -function Base.haskey(context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}}) where {vars,sym} +function Base.haskey( + context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}} +) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars end From 75680b3d5715eb29239ca39ff7401d3ce966f194 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 05:12:00 +0100 Subject: [PATCH 10/26] upper-bound Distributions.jl in tests --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 146e92c55..37c509610 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -22,7 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" AbstractMCMC = "2.1, 3.0" AbstractPPL = "0.1.3" Bijectors = "0.9.5" -Distributions = "0.24, 0.25" +Distributions = "< 0.25.11" DistributionsAD = "0.6.3" Documenter = "0.26.1, 0.27" ForwardDiff = "0.10.12" From f0ae744e625adbe0776021a0991d0aeb14fd8790 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 05:33:42 +0100 Subject: [PATCH 11/26] make the isassumption check using context extensible and nicer --- src/compiler.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index d0485dba9..9eca8abae 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -24,10 +24,7 @@ function isassumption(expr::Union{Symbol,Expr}) # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || $(DynamicPPL.inmissings)($vn, __model__) || - ( - __context__ isa $(DynamicPPL.ConditionContext) && - !$(Base.haskey)(__context__, $vn) - ) + $(DynamicPPL.contextual_isassumption)(__context__, $vn) true else # Evaluate the LHS @@ -37,6 +34,9 @@ function isassumption(expr::Union{Symbol,Expr}) end end +contextual_isassumption(context::AbstractContext, vn) = false +contextual_isassumption(context::ConditionContext, vn::VarName) = !(haskey(context, vn)) + # failsafe: a literal is never an assumption isassumption(expr) = :(false) From b990bb05a3baec3f67c714fa11d529d974a97839 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 06:14:33 +0100 Subject: [PATCH 12/26] renamed type-parameter for ConditionContext --- src/contexts.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index d0cd81287..8bfc25a9b 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -107,7 +107,7 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} end end -struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext +struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx @@ -145,7 +145,7 @@ function ConditionContext(values::NamedTuple, context::AbstractContext) end # Try to avoid nested `ConditionContext`. -function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) where {Vars} +function ConditionContext(values::NamedTuple{Names}, context::ConditionContext) where {Names} # Note that this potentially overrides values from `context`, thus giving # precedence to the outmost `ConditionContext`. return ConditionContext(merge(context.values, values), context.context) From 8994cd78834cfa9408927dca7be95532a5a554c7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 06:27:08 +0100 Subject: [PATCH 13/26] Update src/contexts.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/contexts.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 8bfc25a9b..96bd43cd7 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -145,7 +145,9 @@ function ConditionContext(values::NamedTuple, context::AbstractContext) end # Try to avoid nested `ConditionContext`. -function ConditionContext(values::NamedTuple{Names}, context::ConditionContext) where {Names} +function ConditionContext( + values::NamedTuple{Names}, context::ConditionContext +) where {Names} # Note that this potentially overrides values from `context`, thus giving # precedence to the outmost `ConditionContext`. return ConditionContext(merge(context.values, values), context.context) From 94da453a489102f2a326e777993e2c45001d966a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 16:17:50 +0100 Subject: [PATCH 14/26] introduced convenient _getvalue method --- src/context_implementations.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 813d40875..9f9ca3004 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,8 +14,17 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) +_getindex(x, inds::Tuple) = _getindex(Base.maybeview(x, first(inds)...), Base.tail(inds)) _getindex(x, inds::Tuple{}) = x +_getvalue(x, vn::VarName{sym}) where {sym} = _getindex(getproperty(x, sym), vn.indexing) +function _getvalue(x, vns::AbstractVector{<:VarName{sym}}) where {sym} + val = getproperty(x, sym) + + # This should work with both cartesian and linear indexing. + return map(vns) do vn + _getindex(val, vn) + end +end # assume """ From 4e74cf8a6ab99de553b23ca898ce451a8d190aac Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 16:18:04 +0100 Subject: [PATCH 15/26] overload tilde_assume rather than tilde_assume! and others for ConditionContext --- src/context_implementations.jl | 117 +++++++++++++++++++-------------- 1 file changed, 69 insertions(+), 48 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9f9ca3004..54f6d5965 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -127,6 +127,36 @@ function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) end +# `ConditionContext` +function tilde_assume(context::ConditionContext, right, vn, inds, vi) + if !haskey(context, vn) + # Not conditioned on => defer to child context. + return tilde_assume(context.context, right, vn, inds, vi) + end + + # Extract value. + value = _getvalue(context.values, vn) + + # Update the value in `vi`. + # TODO: Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + logp = tilde_observe(context.context, right, value, vn, inds, vi) + return value, logp +end + +function tilde_assume(rng, context::ConditionContext, sampler, right, vn, inds, vi) + if haskey(context, vn) + # Defer to child context. + return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) + end + + # If we're conditioning, then we just fall back to non-rng impl. + return tilde_assume(context, right, vn, inds, vi) +end + """ tilde_assume!(context, right, vn, inds, vi) @@ -142,28 +172,6 @@ function tilde_assume!(context, right, vn, inds, vi) return value end -function tilde_assume!(context::ConditionContext, right, vn, inds, vi) - if haskey(context, vn) - # Extract value. - value = if inds isa Tuple{} - getfield(context.values, getsym(vn)) - else - _getindex(getfield(context.values, getsym(vn)), inds) - end - - # Should we even do this? - if haskey(vi, vn) - vi[vn] = value - end - - tilde_observe!(context.context, right, value, vn, inds, vi) - else - value = tilde_assume!(context.context, right, vn, inds, vi) - end - - return value -end - # observe """ tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) @@ -220,6 +228,14 @@ function tilde_observe(context::PrefixContext, right, left, vi) return tilde_observe(context.context, right, left, vi) end +# `ConditionContext` +function tilde_observe(context::ConditionContext, right, left, vname, vi) + return tilde_observe(context.context, right, left, vname, vi) +end +function tilde_observe(context::ConditionContext, right, left, vi) + return tilde_observe(context.context, right, left, vi) +end + """ tilde_observe!(context, right, left, vname, vinds, vi) @@ -248,10 +264,6 @@ function tilde_observe!(context, right, left, vi) return left end -function tilde_observe!(context::ConditionContext, right, left, vi) - return tilde_observe!(context.context, right, left, vi) -end - function assume(rng, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end @@ -440,6 +452,37 @@ function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, ) end +# `ConditionContext` +function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi) + if !haskey(context, vn) + # Not conditioned on => defer to child context. + return dot_tilde_assume(context.context, right, left, vn, inds, vi) + end + + # Extract value. + # FIXME: Handle the case where `vn` is actually `AbstractArray{<:VarName}`. + value = _getvalue(context.values, vn) + + # Update the value in `vi`. + # TODO: Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + logp = dot_tilde_observe(context.context, right, left, value, vn, inds, vi) + return value, logp +end + +function dot_tilde_assume(rng, context::ConditionContext, sampler, right, left, vn, inds, vi) + if !haskey(context, vn) + # Defer to child context. + return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) + end + + # If we're conditioning, then we just fall back to non-rng impl. + return dot_tilde_assume(context, right, left, vn, inds, vi) +end + """ dot_tilde_assume!(context, right, left, vn, inds, vi) @@ -454,28 +497,6 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) return value end -function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) - if haskey(context, vn) - # Extract value. - value = if inds isa Tuple{} - getfield(context.values, sym) - else - _getindex(getfield(context.values, sym), inds) - end - - # Should we even do this? - if haskey(vi, vn) - vi[vn] = value - end - - dot_tilde_observe!(context.context, right, value, vn, inds, vi) - else - value = dot_tilde_assume!(context.context, right, left, vn, inds, vi) - end - - return value -end - # `dot_assume` function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi From f9cdfa9910226a2ad790654f4ddeb390e531e95e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 16:21:42 +0100 Subject: [PATCH 16/26] added contextual_isassumption for PrefixContext --- src/compiler.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 9eca8abae..6bcd68c7a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -36,6 +36,9 @@ end contextual_isassumption(context::AbstractContext, vn) = false contextual_isassumption(context::ConditionContext, vn::VarName) = !(haskey(context, vn)) +function contextual_isassumption(context::PrefixContext, vn::VarName) + return contextual_isassumption(context.context, prefix(context, vn)) +end # failsafe: a literal is never an assumption isassumption(expr) = :(false) From 835a41ead81a2387f1113fb148d42a7851f2cd1d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 16:24:28 +0100 Subject: [PATCH 17/26] implemented contextual_isassumption for all contexts --- src/compiler.jl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 6bcd68c7a..3ddbb6340 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -35,10 +35,19 @@ function isassumption(expr::Union{Symbol,Expr}) end contextual_isassumption(context::AbstractContext, vn) = false -contextual_isassumption(context::ConditionContext, vn::VarName) = !(haskey(context, vn)) -function contextual_isassumption(context::PrefixContext, vn::VarName) +function contextual_isassumption(context::ConditionContext, vn) + # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`. + return !(haskey(context, vn)) || contextual_isassumption(context, vn) +end +function contextual_isassumption(context::PrefixContext, vn) return contextual_isassumption(context.context, prefix(context, vn)) end +function contextual_isassumption(context::MiniBatchContext, vn) + return contextual_isassumption(context.context, vn) +end +function contextual_isassumption(context::PointwiseLikelihoodContext, vn) + return contextual_isassumption(context.context, vn) +end # failsafe: a literal is never an assumption isassumption(expr) = :(false) From 5d110d5cb4bb885d44ed8b8d33574f41aae22a4b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 22:32:54 +0100 Subject: [PATCH 18/26] improved the way ConditionContext works signficantly --- src/compiler.jl | 72 ++++++++++++++++++++++++---------- src/context_implementations.jl | 51 ++---------------------- src/contexts.jl | 6 +++ src/loglikelihoods.jl | 8 ++++ 4 files changed, 70 insertions(+), 67 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 3ddbb6340..d640ef1d7 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -20,24 +20,52 @@ function isassumption(expr::Union{Symbol,Expr}) return quote let $vn = $(varname(expr)) - # This branch should compile nicely in all cases except for partial missing data - # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` - if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) || - $(DynamicPPL.contextual_isassumption)(__context__, $vn) - true + if $(DynamicPPL.contextual_isassumption)(__context__, $vn) + # Considered an assumption by `__context__` which means either: + # 1. We hit the default implementation, e.g. using `DefaultContext`, + # which in turn means that we haven't considered if it's one of + # the model arguments, hence we need to check this. + # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, + # i.e. we're trying to condition one of the latent variables. + # In this case, the below will return `true` since the first branch + # will be hit. + # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, + # i.e. we're trying to override the value. This is currently NOT supported. + # TODO: Support by adding context to model, and use `model.args` + # as the default conditioning. Then we no longer need to check `inargnames` + # since it will all be handled by `contextual_isassumption`. + if !($(DynamicPPL.inargnames)($vn, __model__)) + true + else + $expr === missing + end else - # Evaluate the LHS - $(maybe_view(expr)) === missing + false end end end end -contextual_isassumption(context::AbstractContext, vn) = false +""" + contextual_isassumption(context, vn) + +Return `true` if `vn` is considered an assumption by `context`. + +The default implementation for `AbstractContext` always returns `true`. +""" +contextual_isassumption(context::AbstractContext, vn) = true function contextual_isassumption(context::ConditionContext, vn) # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`. - return !(haskey(context, vn)) || contextual_isassumption(context, vn) + + # We have either of the following cases: + # 1. `context` considers `vn` as an observation, i.e. it has `vn` as a key, + # which means we have a value to replace with and we don't need to recurse. + # 2. One of the decendant contexts consider it as an observation, i.e. + # `contextual_isassumption` evaluates to `false`. + # The below then evaluates to `!(false || true) === false`. + # 3. Neither `context` nor any of it's decendants considers it an observation, + # in which case the below evaluates to `!(false || false) === true`. + return !(haskey(context, vn) || !contextual_isassumption(context.context, vn)) end function contextual_isassumption(context::PrefixContext, vn) return contextual_isassumption(context.context, prefix(context, vn)) @@ -45,9 +73,6 @@ end function contextual_isassumption(context::MiniBatchContext, vn) return contextual_isassumption(context.context, vn) end -function contextual_isassumption(context::PointwiseLikelihoodContext, vn) - return contextual_isassumption(context.context, vn) -end # failsafe: a literal is never an assumption isassumption(expr) = :(false) @@ -348,6 +373,11 @@ function generate_tilde(left, right) __varinfo__, ) else + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. + if !$(DynamicPPL.inargnames)($vn, __model__) + $left = $(DynamicPPL.getvalue)(__context__, $vn) + end + $(DynamicPPL.tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), @@ -392,6 +422,11 @@ function generate_dot_tilde(left, right) __varinfo__, ) else + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. + if !$(DynamicPPL.inargnames)($vn, __model__) + $left .= $(DynamicPPL.getvalue)(__context__, $vn) + end + $(DynamicPPL.dot_tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), @@ -453,14 +488,11 @@ function build_output(modelinfo, linenumbernode) modeldef[:body] = MacroTools.@q begin $(linenumbernode) $evaluator = $(MacroTools.combinedef(evaluatordef)) - return $(DynamicPPL.condition)( - $(DynamicPPL.Model)( - $(QuoteNode(modeldef[:name])), - $evaluator, - $allargs_namedtuple, - $defaults_namedtuple, - ), + return $(DynamicPPL.Model)( + $(QuoteNode(modeldef[:name])), + $evaluator, $allargs_namedtuple, + $defaults_namedtuple, ) end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 54f6d5965..e632a4c95 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -129,32 +129,11 @@ end # `ConditionContext` function tilde_assume(context::ConditionContext, right, vn, inds, vi) - if !haskey(context, vn) - # Not conditioned on => defer to child context. - return tilde_assume(context.context, right, vn, inds, vi) - end - - # Extract value. - value = _getvalue(context.values, vn) - - # Update the value in `vi`. - # TODO: Should we even do this? - if haskey(vi, vn) - vi[vn] = value - end - - logp = tilde_observe(context.context, right, value, vn, inds, vi) - return value, logp + return tilde_assume(context.context, right, vn, inds, vi) end function tilde_assume(rng, context::ConditionContext, sampler, right, vn, inds, vi) - if haskey(context, vn) - # Defer to child context. - return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) - end - - # If we're conditioning, then we just fall back to non-rng impl. - return tilde_assume(context, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) end """ @@ -454,33 +433,11 @@ end # `ConditionContext` function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi) - if !haskey(context, vn) - # Not conditioned on => defer to child context. - return dot_tilde_assume(context.context, right, left, vn, inds, vi) - end - - # Extract value. - # FIXME: Handle the case where `vn` is actually `AbstractArray{<:VarName}`. - value = _getvalue(context.values, vn) - - # Update the value in `vi`. - # TODO: Should we even do this? - if haskey(vi, vn) - vi[vn] = value - end - - logp = dot_tilde_observe(context.context, right, left, value, vn, inds, vi) - return value, logp + return dot_tilde_assume(context.context, right, left, vn, inds, vi) end function dot_tilde_assume(rng, context::ConditionContext, sampler, right, left, vn, inds, vi) - if !haskey(context, vn) - # Defer to child context. - return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) - end - - # If we're conditioning, then we just fall back to non-rng impl. - return dot_tilde_assume(context, right, left, vn, inds, vi) + return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) end """ diff --git a/src/contexts.jl b/src/contexts.jl index 96bd43cd7..f58242b5c 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -153,6 +153,12 @@ function ConditionContext( return ConditionContext(merge(context.values, values), context.context) end +function getvalue(context::ConditionContext, vn) + haskey(context, vn) ? _getvalue(context.values, vn) : getvalue(context.context, vn) +end +getvalue(context::AbstractContext, vn) = getvalue(context.context, vn) +getvalue(context::PrefixContext, vn) = getvalue(context.context, prefix(context, vn)) + function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 6c66e4ec4..2203eda13 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -13,6 +13,14 @@ function PointwiseLikelihoodContext( ) end +function contextual_isassumption(context::PointwiseLikelihoodContext, vn) + return contextual_isassumption(context.context, vn) +end + +function contextual_isobservation(context::PointwiseLikelihoodContext, vn) + return contextual_isobservation(context.context, vn) +end + function Base.push!( context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, vn::VarName, From 560ca83ecaf57eaa2b2cdb3d68b8b2f39e5854bb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 22:35:34 +0100 Subject: [PATCH 19/26] forgot impl of dot_tilde_observe for ConditionContext --- src/context_implementations.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e632a4c95..e464b69de 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -646,6 +646,11 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) return dot_tilde_observe(context.context, right, left, vi) end +# `ConditionContext` +function dot_tilde_observe(context::ConditionContext, right, left, vi) + return dot_tilde_observe(context.context, right, left, vi) +end + """ dot_tilde_observe!(context, right, left, vname, vinds, vi) @@ -672,9 +677,6 @@ function dot_tilde_observe!(context, right, left, vi) acclogp!(vi, logp) return left end -function dot_tilde_observe!(context::ConditionContext, right, left, vi) - return dot_tilde_observe!(context.context, right, left, vi) -end # Falls back to non-sampler definition. function dot_observe(::AbstractSampler, dist, value, vi) From 5635c3b530a2b3085b4787c6e491950e2ade722e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 22:39:24 +0100 Subject: [PATCH 20/26] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 4 +++- src/contexts.jl | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e464b69de..a0b3c8e77 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -436,7 +436,9 @@ function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi) return dot_tilde_assume(context.context, right, left, vn, inds, vi) end -function dot_tilde_assume(rng, context::ConditionContext, sampler, right, left, vn, inds, vi) +function dot_tilde_assume( + rng, context::ConditionContext, sampler, right, left, vn, inds, vi +) return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) end diff --git a/src/contexts.jl b/src/contexts.jl index f58242b5c..f48a5187c 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -154,7 +154,11 @@ function ConditionContext( end function getvalue(context::ConditionContext, vn) - haskey(context, vn) ? _getvalue(context.values, vn) : getvalue(context.context, vn) + return if haskey(context, vn) + _getvalue(context.values, vn) + else + getvalue(context.context, vn) + end end getvalue(context::AbstractContext, vn) = getvalue(context.context, vn) getvalue(context::PrefixContext, vn) = getvalue(context.context, prefix(context, vn)) From e78dc6546c08d960d5ae3e6a95ca3d0e5f4f056f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 22 Jul 2021 02:38:40 +0100 Subject: [PATCH 21/26] overload _evaluate rather than the model-call directly --- src/contextual_model.jl | 6 ++---- src/model.jl | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/contextual_model.jl b/src/contextual_model.jl index 382c45034..4b843ed0e 100644 --- a/src/contextual_model.jl +++ b/src/contextual_model.jl @@ -9,12 +9,10 @@ end # TODO: What do we do for other contexts? Could handle this in general if we had a # notion of wrapper-, primitive-context, etc. -function (cmodel::ContextualModel{<:ConditionContext})( - varinfo::AbstractVarInfo, context::AbstractContext -) +function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as # `ConditionContext` child. - return cmodel.model(varinfo, ConditionContext(cmodel.context.values, context)) + return _evaluate(cmodel.model, varinfo, ConditionContext(cmodel.context.values, context)) end condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) diff --git a/src/model.jl b/src/model.jl index e5b9beb24..7b5d8918a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -93,7 +93,7 @@ function (model::AbstractModel)( end (model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context) -function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) +function (model::AbstractModel)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 return evaluate_threadunsafe(model, varinfo, context) else From e1a7d38d33b91cfa0acb1b36e31d6fda0e50579f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 06:52:02 +0100 Subject: [PATCH 22/26] formatting --- src/contextual_model.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contextual_model.jl b/src/contextual_model.jl index 4b843ed0e..793491720 100644 --- a/src/contextual_model.jl +++ b/src/contextual_model.jl @@ -12,7 +12,9 @@ end function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as # `ConditionContext` child. - return _evaluate(cmodel.model, varinfo, ConditionContext(cmodel.context.values, context)) + return _evaluate( + cmodel.model, varinfo, ConditionContext(cmodel.context.values, context) + ) end condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) From 4e566f78d4deff430b2ebe065ff8909b2f7a3b64 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:57:54 +0100 Subject: [PATCH 23/26] remove the drop_missing as it is no longer needed --- src/contexts.jl | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index f48a5187c..1c54cf86a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -118,30 +118,14 @@ struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext end end -@generated function drop_missings(nt::NamedTuple{names,values}) where {names,values} - names_expr = Expr(:tuple) - values_expr = Expr(:tuple) - - for (n, v) in zip(names, values.parameters) - if !(v <: Missing) - push!(names_expr.args, QuoteNode(n)) - push!(values_expr.args, :(nt.$n)) - end - end - - return :(NamedTuple{$names_expr}($values_expr)) -end - function ConditionContext(context::ConditionContext, child_context::AbstractContext) return ConditionContext(context.values, child_context) end function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end - function ConditionContext(values::NamedTuple, context::AbstractContext) - values_wo_missing = drop_missings(values) - return ConditionContext{typeof(values_wo_missing)}(values_wo_missing, context) + return ConditionContext{typeof(values)}(values, context) end # Try to avoid nested `ConditionContext`. From 61960832b75ff490862eaca5da6e186602fc3c93 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:11:51 +0100 Subject: [PATCH 24/26] made show a bit nicer for ConditionContext --- src/contexts.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/contexts.jl b/src/contexts.jl index 1c54cf86a..4a6834b32 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -137,6 +137,10 @@ function ConditionContext( return ConditionContext(merge(context.values, values), context.context) end +function Base.show(io::IO, context::ConditionContext) + println(io, "ConditionContext($(context.values), $(context.context))") +end + function getvalue(context::ConditionContext, vn) return if haskey(context, vn) _getvalue(context.values, vn) From 65048fc4c57e1ec884cddbcdcd880269cfdc33a3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:12:32 +0100 Subject: [PATCH 25/26] use print instead of println in show --- src/contexts.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 4a6834b32..fa0bf585a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -138,7 +138,7 @@ function ConditionContext( end function Base.show(io::IO, context::ConditionContext) - println(io, "ConditionContext($(context.values), $(context.context))") + print(io, "ConditionContext($(context.values), $(context.context))") end function getvalue(context::ConditionContext, vn) From 649af299a7f68c0def25026bbf06850aedbbfdbc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:14:15 +0100 Subject: [PATCH 26/26] formatting --- src/contexts.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index fa0bf585a..926dd0fe0 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -138,7 +138,7 @@ function ConditionContext( end function Base.show(io::IO, context::ConditionContext) - print(io, "ConditionContext($(context.values), $(context.context))") + return print(io, "ConditionContext($(context.values), $(context.context))") end function getvalue(context::ConditionContext, vn)