From cd1c46de8bef1242f22ea9303ff1faca0bb972c1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:08:01 +0100 Subject: [PATCH 01/61] added ConditionContext and ContextualModel --- Project.toml | 1 + src/DynamicPPL.jl | 6 ++++ src/context_implementations.jl | 51 ++++++++++++++++++++++++++++++++++ src/contexts.jl | 42 ++++++++++++++++++++++++++++ src/model.jl | 14 ++++++---- src/varinfo.jl | 6 ++-- 6 files changed, 111 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 921dc054d..7cc3db31d 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.12.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..bfcc49e6b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,6 +6,7 @@ using Distributions using Bijectors using AbstractMCMC: AbstractMCMC +using BangBang: BangBang using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules @@ -67,6 +68,7 @@ export AbstractVarInfo, vectorize, # Model Model, + ContextualModel, getmissings, getargnames, generated_quantities, @@ -81,6 +83,7 @@ export AbstractVarInfo, PriorContext, MiniBatchContext, PrefixContext, + ConditionContext, assume, dot_assume, observe, @@ -99,6 +102,8 @@ export AbstractVarInfo, logprior, logjoint, pointwise_loglikelihoods, + condition, + decondition, # Convenience macros @addlogprob!, @submodel @@ -129,5 +134,6 @@ include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") +include("contextual_model.jl") end # module diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3d492f5b1..68165a64f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -133,6 +133,28 @@ function tilde_assume!(context, right, vn, inds, vi) return value end +function tilde_assume!(context::ConditionContext, right, vn, inds, vi) + value = if haskey(context, vn) + # Extract value. + if inds isa Tuple{} + getfield(context.values, getsym(vn)) + else + _getindex(getfield(context.values, getsym(vn)), inds) + end + + # Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + tilde_observe!(context.context, right, value, vn, inds, vi) + else + tilde_assume!(context.context, right, vn, inds, vi) + end + + return value +end + # observe """ tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) @@ -217,6 +239,10 @@ function tilde_observe!(context, right, left, vi) return left end +function tilde_observe!(context::ConditionContext, right, left, vi) + return tilde_observe!(context.context, right, left, vi) +end + function assume(rng, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end @@ -419,6 +445,28 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) return value end +function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) + value = if vn in context + # Extract value. + if inds isa Tuple{} + getfield(context.values, sym) + else + _getindex(getfield(context.values, sym), inds) + end + + # Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + dot_tilde_observe!(context.context, right, left, vn, inds, vi) + else + dot_tilde_assume!(context.context, right, left, vn, inds, vi) + end + + return value +end + # `dot_assume` function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi @@ -637,6 +685,9 @@ function dot_tilde_observe!(context, right, left, vi) acclogp!(vi, logp) return left end +function dot_tilde_observe!(context::ConditionContext, right, left, vi) + return dot_tilde_observe!(context.context, right, left, vi) +end # Falls back to non-sampler definition. function dot_observe(::AbstractSampler, dist, value, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 05ad8df0d..946c22092 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -106,3 +106,45 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end + + +struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext + values::Values + context::Ctx +end + +function ConditionContext(values::NamedTuple{Vars}) where {Vars} + return ConditionContext(values, DefaultContext()) +end + +function ConditionContext(values::NamedTuple{Vars}, context) where {Vars} + return ConditionContext{Vars,typeof(values),typeof(context)}(values, context) +end + +# Try to avoid nested `ConditionContext`. +function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) where {Vars} + # Note that this potentially overrides values from `context`, thus giving + # precedence to the outmost `ConditionContext`. + return ConditionContext(merge(context.values, values), context.context) +end + +function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars, sym} + # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. + return sym in vars +end + +# TODO: Can we maybe do this in a better way? +decondition(context::ConditionContext) = context +function decondition(context::ConditionContext, sym, syms...) + return decondition( + ConditionContext(BangBang.delete!!(context.values, sym), context.context), + syms... + ) +end +decondition(context::ConditionContext) = context +function decondition(context::ConditionContext, ::Val{sym}, syms...) where {sym} + return decondition( + ConditionContext(BangBang.delete!!(context.values, Val{sym}()), context.context), + syms... + ) +end diff --git a/src/model.jl b/src/model.jl index 9ec047a44..3ce707e2e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,3 +1,5 @@ +abstract type AbstractModel <: AbstractProbabilisticProgram end + """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} name::Symbol @@ -33,7 +35,7 @@ Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x ``` """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: - AbstractProbabilisticProgram + AbstractModel name::Symbol f::F args::NamedTuple{argnames,Targs} @@ -82,7 +84,7 @@ Sample from the `model` using the `sampler` with random number generator `rng` a The method resets the log joint probability of `varinfo` and increases the evaluation number of `sampler`. """ -function (model::Model)( +function (model::AbstractModel)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), @@ -91,7 +93,7 @@ function (model::Model)( return model(varinfo, SamplingContext(rng, sampler, context)) end -(model::Model)(context::AbstractContext) = model(VarInfo(), context) +(model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context) function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 return evaluate_threadunsafe(model, varinfo, context) @@ -100,17 +102,17 @@ function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) end end -function (model::Model)(args...) +function (model::AbstractModel)(args...) return model(Random.GLOBAL_RNG, args...) end # without VarInfo -function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) +function (model::AbstractModel)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) return model(rng, VarInfo(), sampler, args...) end # without VarInfo and without AbstractSampler -function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) +function (model::AbstractModel)(rng::Random.AbstractRNG, context::AbstractContext) return model(rng, VarInfo(), SampleFromPrior(), context) end diff --git a/src/varinfo.jl b/src/varinfo.jl index fe3262dd5..be6551939 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -124,7 +124,7 @@ end function VarInfo( rng::Random.AbstractRNG, - model::Model, + model::AbstractModel, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) @@ -132,10 +132,10 @@ function VarInfo( model(rng, varinfo, sampler, context) return TypedVarInfo(varinfo) end -VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) +VarInfo(model::AbstractModel, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) # without AbstractSampler -function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) +function VarInfo(rng::Random.AbstractRNG, model::AbstractModel, context::AbstractContext) return VarInfo(rng, model, SampleFromPrior(), context) end From b1106ee65b60a8eebffada99b5c770a869730169 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:10:12 +0100 Subject: [PATCH 02/61] removed redundant definition --- src/contexts.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 946c22092..5abae4168 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -141,7 +141,6 @@ function decondition(context::ConditionContext, sym, syms...) syms... ) end -decondition(context::ConditionContext) = context function decondition(context::ConditionContext, ::Val{sym}, syms...) where {sym} return decondition( ConditionContext(BangBang.delete!!(context.values, Val{sym}()), context.context), From 0f00771e8f391311cd6956ac33e3e31c770a2409 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:36:28 +0100 Subject: [PATCH 03/61] return condition model by default --- src/compiler.jl | 16 ++++++++++------ src/contexts.jl | 14 +++++++------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7466bc2c0..00be43551 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -23,7 +23,8 @@ function isassumption(expr::Union{Symbol,Expr}) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) + $(DynamicPPL.inmissings)($vn, __model__) || + (__context__ isa $(DynamicPPL.ConditionContext) && !$(Base.haskey)(__context__, $vn)) true else # Evaluate the LHS @@ -427,11 +428,14 @@ function build_output(modelinfo, linenumbernode) modeldef[:body] = MacroTools.@q begin $(linenumbernode) $evaluator = $(MacroTools.combinedef(evaluatordef)) - return $(DynamicPPL.Model)( - $(QuoteNode(modeldef[:name])), - $evaluator, - $allargs_namedtuple, - $defaults_namedtuple, + return $(DynamicPPL.condition)( + $(DynamicPPL.Model)( + $(QuoteNode(modeldef[:name])), + $evaluator, + $allargs_namedtuple, + $defaults_namedtuple, + ), + $allargs_namedtuple ) end diff --git a/src/contexts.jl b/src/contexts.jl index 5abae4168..0a0e977b9 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -134,16 +134,16 @@ function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {v end # TODO: Can we maybe do this in a better way? -decondition(context::ConditionContext) = context +# When no second argument is given, we remove _all_ conditioned variables. +# TODO: Should we remove this and just return `context.context`? +# That will work better if `Model` becomes like `ContextualModel`. +decondition(context::ConditionContext) = ConditionContext(NamedTuple(), context.context) +function decondition(context::ConditionContext, sym) + return ConditionContext(BangBang.delete!!(context.values, sym), context.context) +end function decondition(context::ConditionContext, sym, syms...) return decondition( ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... ) end -function decondition(context::ConditionContext, ::Val{sym}, syms...) where {sym} - return decondition( - ConditionContext(BangBang.delete!!(context.values, Val{sym}()), context.context), - syms... - ) -end From c754291411f3af1bef2468838493a750eaefa427 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:47:50 +0100 Subject: [PATCH 04/61] formatting --- src/compiler.jl | 9 ++++++--- src/contexts.jl | 6 ++---- src/model.jl | 3 +-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 00be43551..191152f4e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -23,8 +23,11 @@ function isassumption(expr::Union{Symbol,Expr}) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) || - (__context__ isa $(DynamicPPL.ConditionContext) && !$(Base.haskey)(__context__, $vn)) + $(DynamicPPL.inmissings)($vn, __model__) || + ( + __context__ isa $(DynamicPPL.ConditionContext) && + !$(Base.haskey)(__context__, $vn) + ) true else # Evaluate the LHS @@ -435,7 +438,7 @@ function build_output(modelinfo, linenumbernode) $allargs_namedtuple, $defaults_namedtuple, ), - $allargs_namedtuple + $allargs_namedtuple, ) end diff --git a/src/contexts.jl b/src/contexts.jl index 0a0e977b9..9831dde4a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -107,7 +107,6 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} end end - struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx @@ -128,7 +127,7 @@ function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) w return ConditionContext(merge(context.values, values), context.context) end -function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars, sym} +function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars end @@ -143,7 +142,6 @@ function decondition(context::ConditionContext, sym) end function decondition(context::ConditionContext, sym, syms...) return decondition( - ConditionContext(BangBang.delete!!(context.values, sym), context.context), - syms... + ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... ) end diff --git a/src/model.jl b/src/model.jl index 3ce707e2e..39a9f4fd3 100644 --- a/src/model.jl +++ b/src/model.jl @@ -34,8 +34,7 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: - AbstractModel +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel name::Symbol f::F args::NamedTuple{argnames,Targs} From 2d3f94cb6c137f6ce22683392cf4fcf0c579a47d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 22:15:23 +0100 Subject: [PATCH 05/61] forgot to include contextual model in previous commit --- src/contextual_model.jl | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 src/contextual_model.jl diff --git a/src/contextual_model.jl b/src/contextual_model.jl new file mode 100644 index 000000000..382c45034 --- /dev/null +++ b/src/contextual_model.jl @@ -0,0 +1,29 @@ +struct ContextualModel{Ctx<:AbstractContext,M<:Model} <: AbstractModel + context::Ctx + model::M +end + +function contextualize(model::AbstractModel, context::AbstractContext) + return ContextualModel(context, model) +end + +# TODO: What do we do for other contexts? Could handle this in general if we had a +# notion of wrapper-, primitive-context, etc. +function (cmodel::ContextualModel{<:ConditionContext})( + varinfo::AbstractVarInfo, context::AbstractContext +) + # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as + # `ConditionContext` child. + return cmodel.model(varinfo, ConditionContext(cmodel.context.values, context)) +end + +condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) +condition(model::AbstractModel; values...) = condition(model, (; values...)) +function condition(cmodel::ContextualModel{<:ConditionContext}, values) + return contextualize(cmodel.model, ConditionContext(values, cmodel.context)) +end + +decondition(model::AbstractModel, args...) = model +function decondition(cmodel::ContextualModel{<:ConditionContext}, syms...) + return contextualize(cmodel.model, decondition(cmodel.context, syms...)) +end From 3e5a79f7e1b14ba610f88377ef9b9f177777becb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 22:45:45 +0100 Subject: [PATCH 06/61] fixed typos --- src/context_implementations.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 68165a64f..052b4ab9c 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -134,9 +134,9 @@ function tilde_assume!(context, right, vn, inds, vi) end function tilde_assume!(context::ConditionContext, right, vn, inds, vi) - value = if haskey(context, vn) + if haskey(context, vn) # Extract value. - if inds isa Tuple{} + value = if inds isa Tuple{} getfield(context.values, getsym(vn)) else _getindex(getfield(context.values, getsym(vn)), inds) @@ -149,7 +149,7 @@ function tilde_assume!(context::ConditionContext, right, vn, inds, vi) tilde_observe!(context.context, right, value, vn, inds, vi) else - tilde_assume!(context.context, right, vn, inds, vi) + value = tilde_assume!(context.context, right, vn, inds, vi) end return value @@ -446,9 +446,9 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) end function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) - value = if vn in context + if vn in context # Extract value. - if inds isa Tuple{} + value = if inds isa Tuple{} getfield(context.values, sym) else _getindex(getfield(context.values, sym), inds) @@ -459,9 +459,9 @@ function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) vi[vn] = value end - dot_tilde_observe!(context.context, right, left, vn, inds, vi) + dot_tilde_observe!(context.context, right, value, vn, inds, vi) else - dot_tilde_assume!(context.context, right, left, vn, inds, vi) + value = dot_tilde_assume!(context.context, right, left, vn, inds, vi) end return value From d4e42387888a4d7c9b2f9e62541a8de70fbb6f0f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 23:06:59 +0100 Subject: [PATCH 07/61] added some niceties to ConditionContext --- src/contexts.jl | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 9831dde4a..e06d4c6eb 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -110,14 +110,38 @@ end struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx + + function ConditionContext{Values}( + values::Values, context::AbstractContext + ) where {names,Values<:NamedTuple{names}} + return new{names,typeof(values),typeof(context)}(values, context) + end +end + +@generated function drop_missings(nt::NamedTuple{names,values}) where {names,values} + names_expr = Expr(:tuple) + values_expr = Expr(:tuple) + + for (n, v) in zip(names, values.parameters) + if !(v <: Missing) + push!(names_expr.args, QuoteNode(n)) + push!(values_expr.args, :(nt.$n)) + end + end + + return :(NamedTuple{$names_expr}($values_expr)) end -function ConditionContext(values::NamedTuple{Vars}) where {Vars} +function ConditionContext(context::ConditionContext, child_context::AbstractContext) + return ConditionContext(context.values, child_context) +end +function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end -function ConditionContext(values::NamedTuple{Vars}, context) where {Vars} - return ConditionContext{Vars,typeof(values),typeof(context)}(values, context) +function ConditionContext(values::NamedTuple, context::AbstractContext) + values_wo_missing = drop_missings(values) + return ConditionContext{typeof(values_wo_missing)}(values_wo_missing, context) end # Try to avoid nested `ConditionContext`. From 85a47eb098498df027ebaf2efe124517507933cb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 16 Jul 2021 01:49:52 +0100 Subject: [PATCH 08/61] added support for vectors of VarName --- src/context_implementations.jl | 2 +- src/contexts.jl | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 052b4ab9c..813d40875 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -446,7 +446,7 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) end function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) - if vn in context + if haskey(context, vn) # Extract value. value = if inds isa Tuple{} getfield(context.values, sym) diff --git a/src/contexts.jl b/src/contexts.jl index e06d4c6eb..500d3d60b 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -156,6 +156,11 @@ function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {v return sym in vars end +function Base.haskey(context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}}) where {vars,sym} + # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. + return sym in vars +end + # TODO: Can we maybe do this in a better way? # When no second argument is given, we remove _all_ conditioned variables. # TODO: Should we remove this and just return `context.context`? From c7dae8daa1a66712328161be7a8b4fe94575deae Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 16 Jul 2021 01:52:17 +0100 Subject: [PATCH 09/61] Update src/contexts.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/contexts.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 500d3d60b..d0cd81287 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -156,7 +156,9 @@ function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {v return sym in vars end -function Base.haskey(context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}}) where {vars,sym} +function Base.haskey( + context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}} +) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars end From 75680b3d5715eb29239ca39ff7401d3ce966f194 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 05:12:00 +0100 Subject: [PATCH 10/61] upper-bound Distributions.jl in tests --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 146e92c55..37c509610 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -22,7 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" AbstractMCMC = "2.1, 3.0" AbstractPPL = "0.1.3" Bijectors = "0.9.5" -Distributions = "0.24, 0.25" +Distributions = "< 0.25.11" DistributionsAD = "0.6.3" Documenter = "0.26.1, 0.27" ForwardDiff = "0.10.12" From f0ae744e625adbe0776021a0991d0aeb14fd8790 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 05:33:42 +0100 Subject: [PATCH 11/61] make the isassumption check using context extensible and nicer --- src/compiler.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index d0485dba9..9eca8abae 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -24,10 +24,7 @@ function isassumption(expr::Union{Symbol,Expr}) # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || $(DynamicPPL.inmissings)($vn, __model__) || - ( - __context__ isa $(DynamicPPL.ConditionContext) && - !$(Base.haskey)(__context__, $vn) - ) + $(DynamicPPL.contextual_isassumption)(__context__, $vn) true else # Evaluate the LHS @@ -37,6 +34,9 @@ function isassumption(expr::Union{Symbol,Expr}) end end +contextual_isassumption(context::AbstractContext, vn) = false +contextual_isassumption(context::ConditionContext, vn::VarName) = !(haskey(context, vn)) + # failsafe: a literal is never an assumption isassumption(expr) = :(false) From b990bb05a3baec3f67c714fa11d529d974a97839 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 06:14:33 +0100 Subject: [PATCH 12/61] renamed type-parameter for ConditionContext --- src/contexts.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index d0cd81287..8bfc25a9b 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -107,7 +107,7 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} end end -struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext +struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx @@ -145,7 +145,7 @@ function ConditionContext(values::NamedTuple, context::AbstractContext) end # Try to avoid nested `ConditionContext`. -function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) where {Vars} +function ConditionContext(values::NamedTuple{Names}, context::ConditionContext) where {Names} # Note that this potentially overrides values from `context`, thus giving # precedence to the outmost `ConditionContext`. return ConditionContext(merge(context.values, values), context.context) From 8994cd78834cfa9408927dca7be95532a5a554c7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 06:27:08 +0100 Subject: [PATCH 13/61] Update src/contexts.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/contexts.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 8bfc25a9b..96bd43cd7 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -145,7 +145,9 @@ function ConditionContext(values::NamedTuple, context::AbstractContext) end # Try to avoid nested `ConditionContext`. -function ConditionContext(values::NamedTuple{Names}, context::ConditionContext) where {Names} +function ConditionContext( + values::NamedTuple{Names}, context::ConditionContext +) where {Names} # Note that this potentially overrides values from `context`, thus giving # precedence to the outmost `ConditionContext`. return ConditionContext(merge(context.values, values), context.context) From 94da453a489102f2a326e777993e2c45001d966a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 16:17:50 +0100 Subject: [PATCH 14/61] introduced convenient _getvalue method --- src/context_implementations.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 813d40875..9f9ca3004 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,8 +14,17 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) +_getindex(x, inds::Tuple) = _getindex(Base.maybeview(x, first(inds)...), Base.tail(inds)) _getindex(x, inds::Tuple{}) = x +_getvalue(x, vn::VarName{sym}) where {sym} = _getindex(getproperty(x, sym), vn.indexing) +function _getvalue(x, vns::AbstractVector{<:VarName{sym}}) where {sym} + val = getproperty(x, sym) + + # This should work with both cartesian and linear indexing. + return map(vns) do vn + _getindex(val, vn) + end +end # assume """ From 4e74cf8a6ab99de553b23ca898ce451a8d190aac Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 16:18:04 +0100 Subject: [PATCH 15/61] overload tilde_assume rather than tilde_assume! and others for ConditionContext --- src/context_implementations.jl | 117 +++++++++++++++++++-------------- 1 file changed, 69 insertions(+), 48 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9f9ca3004..54f6d5965 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -127,6 +127,36 @@ function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) end +# `ConditionContext` +function tilde_assume(context::ConditionContext, right, vn, inds, vi) + if !haskey(context, vn) + # Not conditioned on => defer to child context. + return tilde_assume(context.context, right, vn, inds, vi) + end + + # Extract value. + value = _getvalue(context.values, vn) + + # Update the value in `vi`. + # TODO: Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + logp = tilde_observe(context.context, right, value, vn, inds, vi) + return value, logp +end + +function tilde_assume(rng, context::ConditionContext, sampler, right, vn, inds, vi) + if haskey(context, vn) + # Defer to child context. + return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) + end + + # If we're conditioning, then we just fall back to non-rng impl. + return tilde_assume(context, right, vn, inds, vi) +end + """ tilde_assume!(context, right, vn, inds, vi) @@ -142,28 +172,6 @@ function tilde_assume!(context, right, vn, inds, vi) return value end -function tilde_assume!(context::ConditionContext, right, vn, inds, vi) - if haskey(context, vn) - # Extract value. - value = if inds isa Tuple{} - getfield(context.values, getsym(vn)) - else - _getindex(getfield(context.values, getsym(vn)), inds) - end - - # Should we even do this? - if haskey(vi, vn) - vi[vn] = value - end - - tilde_observe!(context.context, right, value, vn, inds, vi) - else - value = tilde_assume!(context.context, right, vn, inds, vi) - end - - return value -end - # observe """ tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) @@ -220,6 +228,14 @@ function tilde_observe(context::PrefixContext, right, left, vi) return tilde_observe(context.context, right, left, vi) end +# `ConditionContext` +function tilde_observe(context::ConditionContext, right, left, vname, vi) + return tilde_observe(context.context, right, left, vname, vi) +end +function tilde_observe(context::ConditionContext, right, left, vi) + return tilde_observe(context.context, right, left, vi) +end + """ tilde_observe!(context, right, left, vname, vinds, vi) @@ -248,10 +264,6 @@ function tilde_observe!(context, right, left, vi) return left end -function tilde_observe!(context::ConditionContext, right, left, vi) - return tilde_observe!(context.context, right, left, vi) -end - function assume(rng, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end @@ -440,6 +452,37 @@ function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, ) end +# `ConditionContext` +function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi) + if !haskey(context, vn) + # Not conditioned on => defer to child context. + return dot_tilde_assume(context.context, right, left, vn, inds, vi) + end + + # Extract value. + # FIXME: Handle the case where `vn` is actually `AbstractArray{<:VarName}`. + value = _getvalue(context.values, vn) + + # Update the value in `vi`. + # TODO: Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + logp = dot_tilde_observe(context.context, right, left, value, vn, inds, vi) + return value, logp +end + +function dot_tilde_assume(rng, context::ConditionContext, sampler, right, left, vn, inds, vi) + if !haskey(context, vn) + # Defer to child context. + return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) + end + + # If we're conditioning, then we just fall back to non-rng impl. + return dot_tilde_assume(context, right, left, vn, inds, vi) +end + """ dot_tilde_assume!(context, right, left, vn, inds, vi) @@ -454,28 +497,6 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) return value end -function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) - if haskey(context, vn) - # Extract value. - value = if inds isa Tuple{} - getfield(context.values, sym) - else - _getindex(getfield(context.values, sym), inds) - end - - # Should we even do this? - if haskey(vi, vn) - vi[vn] = value - end - - dot_tilde_observe!(context.context, right, value, vn, inds, vi) - else - value = dot_tilde_assume!(context.context, right, left, vn, inds, vi) - end - - return value -end - # `dot_assume` function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi From f9cdfa9910226a2ad790654f4ddeb390e531e95e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 16:21:42 +0100 Subject: [PATCH 16/61] added contextual_isassumption for PrefixContext --- src/compiler.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 9eca8abae..6bcd68c7a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -36,6 +36,9 @@ end contextual_isassumption(context::AbstractContext, vn) = false contextual_isassumption(context::ConditionContext, vn::VarName) = !(haskey(context, vn)) +function contextual_isassumption(context::PrefixContext, vn::VarName) + return contextual_isassumption(context.context, prefix(context, vn)) +end # failsafe: a literal is never an assumption isassumption(expr) = :(false) From 835a41ead81a2387f1113fb148d42a7851f2cd1d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 16:24:28 +0100 Subject: [PATCH 17/61] implemented contextual_isassumption for all contexts --- src/compiler.jl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 6bcd68c7a..3ddbb6340 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -35,10 +35,19 @@ function isassumption(expr::Union{Symbol,Expr}) end contextual_isassumption(context::AbstractContext, vn) = false -contextual_isassumption(context::ConditionContext, vn::VarName) = !(haskey(context, vn)) -function contextual_isassumption(context::PrefixContext, vn::VarName) +function contextual_isassumption(context::ConditionContext, vn) + # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`. + return !(haskey(context, vn)) || contextual_isassumption(context, vn) +end +function contextual_isassumption(context::PrefixContext, vn) return contextual_isassumption(context.context, prefix(context, vn)) end +function contextual_isassumption(context::MiniBatchContext, vn) + return contextual_isassumption(context.context, vn) +end +function contextual_isassumption(context::PointwiseLikelihoodContext, vn) + return contextual_isassumption(context.context, vn) +end # failsafe: a literal is never an assumption isassumption(expr) = :(false) From 5d110d5cb4bb885d44ed8b8d33574f41aae22a4b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 22:32:54 +0100 Subject: [PATCH 18/61] improved the way ConditionContext works signficantly --- src/compiler.jl | 72 ++++++++++++++++++++++++---------- src/context_implementations.jl | 51 ++---------------------- src/contexts.jl | 6 +++ src/loglikelihoods.jl | 8 ++++ 4 files changed, 70 insertions(+), 67 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 3ddbb6340..d640ef1d7 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -20,24 +20,52 @@ function isassumption(expr::Union{Symbol,Expr}) return quote let $vn = $(varname(expr)) - # This branch should compile nicely in all cases except for partial missing data - # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` - if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) || - $(DynamicPPL.contextual_isassumption)(__context__, $vn) - true + if $(DynamicPPL.contextual_isassumption)(__context__, $vn) + # Considered an assumption by `__context__` which means either: + # 1. We hit the default implementation, e.g. using `DefaultContext`, + # which in turn means that we haven't considered if it's one of + # the model arguments, hence we need to check this. + # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, + # i.e. we're trying to condition one of the latent variables. + # In this case, the below will return `true` since the first branch + # will be hit. + # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, + # i.e. we're trying to override the value. This is currently NOT supported. + # TODO: Support by adding context to model, and use `model.args` + # as the default conditioning. Then we no longer need to check `inargnames` + # since it will all be handled by `contextual_isassumption`. + if !($(DynamicPPL.inargnames)($vn, __model__)) + true + else + $expr === missing + end else - # Evaluate the LHS - $(maybe_view(expr)) === missing + false end end end end -contextual_isassumption(context::AbstractContext, vn) = false +""" + contextual_isassumption(context, vn) + +Return `true` if `vn` is considered an assumption by `context`. + +The default implementation for `AbstractContext` always returns `true`. +""" +contextual_isassumption(context::AbstractContext, vn) = true function contextual_isassumption(context::ConditionContext, vn) # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`. - return !(haskey(context, vn)) || contextual_isassumption(context, vn) + + # We have either of the following cases: + # 1. `context` considers `vn` as an observation, i.e. it has `vn` as a key, + # which means we have a value to replace with and we don't need to recurse. + # 2. One of the decendant contexts consider it as an observation, i.e. + # `contextual_isassumption` evaluates to `false`. + # The below then evaluates to `!(false || true) === false`. + # 3. Neither `context` nor any of it's decendants considers it an observation, + # in which case the below evaluates to `!(false || false) === true`. + return !(haskey(context, vn) || !contextual_isassumption(context.context, vn)) end function contextual_isassumption(context::PrefixContext, vn) return contextual_isassumption(context.context, prefix(context, vn)) @@ -45,9 +73,6 @@ end function contextual_isassumption(context::MiniBatchContext, vn) return contextual_isassumption(context.context, vn) end -function contextual_isassumption(context::PointwiseLikelihoodContext, vn) - return contextual_isassumption(context.context, vn) -end # failsafe: a literal is never an assumption isassumption(expr) = :(false) @@ -348,6 +373,11 @@ function generate_tilde(left, right) __varinfo__, ) else + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. + if !$(DynamicPPL.inargnames)($vn, __model__) + $left = $(DynamicPPL.getvalue)(__context__, $vn) + end + $(DynamicPPL.tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), @@ -392,6 +422,11 @@ function generate_dot_tilde(left, right) __varinfo__, ) else + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. + if !$(DynamicPPL.inargnames)($vn, __model__) + $left .= $(DynamicPPL.getvalue)(__context__, $vn) + end + $(DynamicPPL.dot_tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), @@ -453,14 +488,11 @@ function build_output(modelinfo, linenumbernode) modeldef[:body] = MacroTools.@q begin $(linenumbernode) $evaluator = $(MacroTools.combinedef(evaluatordef)) - return $(DynamicPPL.condition)( - $(DynamicPPL.Model)( - $(QuoteNode(modeldef[:name])), - $evaluator, - $allargs_namedtuple, - $defaults_namedtuple, - ), + return $(DynamicPPL.Model)( + $(QuoteNode(modeldef[:name])), + $evaluator, $allargs_namedtuple, + $defaults_namedtuple, ) end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 54f6d5965..e632a4c95 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -129,32 +129,11 @@ end # `ConditionContext` function tilde_assume(context::ConditionContext, right, vn, inds, vi) - if !haskey(context, vn) - # Not conditioned on => defer to child context. - return tilde_assume(context.context, right, vn, inds, vi) - end - - # Extract value. - value = _getvalue(context.values, vn) - - # Update the value in `vi`. - # TODO: Should we even do this? - if haskey(vi, vn) - vi[vn] = value - end - - logp = tilde_observe(context.context, right, value, vn, inds, vi) - return value, logp + return tilde_assume(context.context, right, vn, inds, vi) end function tilde_assume(rng, context::ConditionContext, sampler, right, vn, inds, vi) - if haskey(context, vn) - # Defer to child context. - return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) - end - - # If we're conditioning, then we just fall back to non-rng impl. - return tilde_assume(context, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) end """ @@ -454,33 +433,11 @@ end # `ConditionContext` function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi) - if !haskey(context, vn) - # Not conditioned on => defer to child context. - return dot_tilde_assume(context.context, right, left, vn, inds, vi) - end - - # Extract value. - # FIXME: Handle the case where `vn` is actually `AbstractArray{<:VarName}`. - value = _getvalue(context.values, vn) - - # Update the value in `vi`. - # TODO: Should we even do this? - if haskey(vi, vn) - vi[vn] = value - end - - logp = dot_tilde_observe(context.context, right, left, value, vn, inds, vi) - return value, logp + return dot_tilde_assume(context.context, right, left, vn, inds, vi) end function dot_tilde_assume(rng, context::ConditionContext, sampler, right, left, vn, inds, vi) - if !haskey(context, vn) - # Defer to child context. - return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) - end - - # If we're conditioning, then we just fall back to non-rng impl. - return dot_tilde_assume(context, right, left, vn, inds, vi) + return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) end """ diff --git a/src/contexts.jl b/src/contexts.jl index 96bd43cd7..f58242b5c 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -153,6 +153,12 @@ function ConditionContext( return ConditionContext(merge(context.values, values), context.context) end +function getvalue(context::ConditionContext, vn) + haskey(context, vn) ? _getvalue(context.values, vn) : getvalue(context.context, vn) +end +getvalue(context::AbstractContext, vn) = getvalue(context.context, vn) +getvalue(context::PrefixContext, vn) = getvalue(context.context, prefix(context, vn)) + function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 6c66e4ec4..2203eda13 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -13,6 +13,14 @@ function PointwiseLikelihoodContext( ) end +function contextual_isassumption(context::PointwiseLikelihoodContext, vn) + return contextual_isassumption(context.context, vn) +end + +function contextual_isobservation(context::PointwiseLikelihoodContext, vn) + return contextual_isobservation(context.context, vn) +end + function Base.push!( context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, vn::VarName, From 560ca83ecaf57eaa2b2cdb3d68b8b2f39e5854bb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 22:35:34 +0100 Subject: [PATCH 19/61] forgot impl of dot_tilde_observe for ConditionContext --- src/context_implementations.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e632a4c95..e464b69de 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -646,6 +646,11 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) return dot_tilde_observe(context.context, right, left, vi) end +# `ConditionContext` +function dot_tilde_observe(context::ConditionContext, right, left, vi) + return dot_tilde_observe(context.context, right, left, vi) +end + """ dot_tilde_observe!(context, right, left, vname, vinds, vi) @@ -672,9 +677,6 @@ function dot_tilde_observe!(context, right, left, vi) acclogp!(vi, logp) return left end -function dot_tilde_observe!(context::ConditionContext, right, left, vi) - return dot_tilde_observe!(context.context, right, left, vi) -end # Falls back to non-sampler definition. function dot_observe(::AbstractSampler, dist, value, vi) From 5635c3b530a2b3085b4787c6e491950e2ade722e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 22:39:24 +0100 Subject: [PATCH 20/61] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 4 +++- src/contexts.jl | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e464b69de..a0b3c8e77 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -436,7 +436,9 @@ function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi) return dot_tilde_assume(context.context, right, left, vn, inds, vi) end -function dot_tilde_assume(rng, context::ConditionContext, sampler, right, left, vn, inds, vi) +function dot_tilde_assume( + rng, context::ConditionContext, sampler, right, left, vn, inds, vi +) return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) end diff --git a/src/contexts.jl b/src/contexts.jl index f58242b5c..f48a5187c 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -154,7 +154,11 @@ function ConditionContext( end function getvalue(context::ConditionContext, vn) - haskey(context, vn) ? _getvalue(context.values, vn) : getvalue(context.context, vn) + return if haskey(context, vn) + _getvalue(context.values, vn) + else + getvalue(context.context, vn) + end end getvalue(context::AbstractContext, vn) = getvalue(context.context, vn) getvalue(context::PrefixContext, vn) = getvalue(context.context, prefix(context, vn)) From e78dc6546c08d960d5ae3e6a95ca3d0e5f4f056f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 22 Jul 2021 02:38:40 +0100 Subject: [PATCH 21/61] overload _evaluate rather than the model-call directly --- src/contextual_model.jl | 6 ++---- src/model.jl | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/contextual_model.jl b/src/contextual_model.jl index 382c45034..4b843ed0e 100644 --- a/src/contextual_model.jl +++ b/src/contextual_model.jl @@ -9,12 +9,10 @@ end # TODO: What do we do for other contexts? Could handle this in general if we had a # notion of wrapper-, primitive-context, etc. -function (cmodel::ContextualModel{<:ConditionContext})( - varinfo::AbstractVarInfo, context::AbstractContext -) +function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as # `ConditionContext` child. - return cmodel.model(varinfo, ConditionContext(cmodel.context.values, context)) + return _evaluate(cmodel.model, varinfo, ConditionContext(cmodel.context.values, context)) end condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) diff --git a/src/model.jl b/src/model.jl index e5b9beb24..7b5d8918a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -93,7 +93,7 @@ function (model::AbstractModel)( end (model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context) -function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) +function (model::AbstractModel)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 return evaluate_threadunsafe(model, varinfo, context) else From 1d3b11e9caa2d2e2184940b42eadf6a176b7ccb1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 06:51:20 +0100 Subject: [PATCH 22/61] drop now unneceesary impls for tilds for ConditionContext --- src/context_implementations.jl | 33 --------------------------------- src/contexts.jl | 3 +++ 2 files changed, 3 insertions(+), 33 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 8e2b02851..ce3f89739 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -127,15 +127,6 @@ function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) end -# `ConditionContext` -function tilde_assume(context::ConditionContext, right, vn, inds, vi) - return tilde_assume(context.context, right, vn, inds, vi) -end - -function tilde_assume(rng, context::ConditionContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) -end - """ tilde_assume!(context, right, vn, inds, vi) @@ -204,14 +195,6 @@ function tilde_observe(context::PrefixContext, right, left, vname, vi) return tilde_observe(context.context, right, left, prefix(context, vname), vi) end -# `ConditionContext` -function tilde_observe(context::ConditionContext, right, left, vname, vi) - return tilde_observe(context.context, right, left, vname, vi) -end -function tilde_observe(context::ConditionContext, right, left, vi) - return tilde_observe(context.context, right, left, vi) -end - """ tilde_observe!(context, right, left, vname, vinds, vi) @@ -428,17 +411,6 @@ function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, ) end -# `ConditionContext` -function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, left, vn, inds, vi) -end - -function dot_tilde_assume( - rng, context::ConditionContext, sampler, right, left, vn, inds, vi -) - return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) -end - """ dot_tilde_assume!(context, right, left, vn, inds, vi) @@ -640,11 +612,6 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) return dot_tilde_observe(context.context, right, left, vi) end -# `ConditionContext` -function dot_tilde_observe(context::ConditionContext, right, left, vi) - return dot_tilde_observe(context.context, right, left, vi) -end - """ dot_tilde_observe!(context, right, left, vname, vinds, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 98790a785..947f31538 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -212,3 +212,6 @@ function decondition(context::ConditionContext, sym, syms...) ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... ) end + +NodeTrait(context::ConditionContext) = IsParent() +childcontext(context::ConditionContext) = context.context From e1a7d38d33b91cfa0acb1b36e31d6fda0e50579f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 06:52:02 +0100 Subject: [PATCH 23/61] formatting --- src/contextual_model.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contextual_model.jl b/src/contextual_model.jl index 4b843ed0e..793491720 100644 --- a/src/contextual_model.jl +++ b/src/contextual_model.jl @@ -12,7 +12,9 @@ end function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as # `ConditionContext` child. - return _evaluate(cmodel.model, varinfo, ConditionContext(cmodel.context.values, context)) + return _evaluate( + cmodel.model, varinfo, ConditionContext(cmodel.context.values, context) + ) end condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) From b42c34f346c69c4b5ab615ef0306214e74e1a29b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:03:35 +0100 Subject: [PATCH 24/61] address issues using traits --- src/compiler.jl | 15 +++++++++------ src/contexts.jl | 34 ++++++++++++++++++++++++++-------- 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index d640ef1d7..eeced2701 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -53,7 +53,13 @@ Return `true` if `vn` is considered an assumption by `context`. The default implementation for `AbstractContext` always returns `true`. """ -contextual_isassumption(context::AbstractContext, vn) = true +contextual_isassumption(::IsLeaf, context, vn) = true +function contextual_isassumption(::IsParent, context, vn) + return contextual_isassumption(childcontext(context), vn) +end +function contextual_isassumption(context::AbstractContext, vn) + return contextual_isassumption(NodeTrait(context), context, vn) +end function contextual_isassumption(context::ConditionContext, vn) # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`. @@ -65,13 +71,10 @@ function contextual_isassumption(context::ConditionContext, vn) # The below then evaluates to `!(false || true) === false`. # 3. Neither `context` nor any of it's decendants considers it an observation, # in which case the below evaluates to `!(false || false) === true`. - return !(haskey(context, vn) || !contextual_isassumption(context.context, vn)) + return !(haskey(context, vn) || !contextual_isassumption(childcontext(context), vn)) end function contextual_isassumption(context::PrefixContext, vn) - return contextual_isassumption(context.context, prefix(context, vn)) -end -function contextual_isassumption(context::MiniBatchContext, vn) - return contextual_isassumption(context.context, vn) + return contextual_isassumption(childcontext(context), prefix(context, vn)) end # failsafe: a literal is never an assumption diff --git a/src/contexts.jl b/src/contexts.jl index 947f31538..38cbe213e 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -174,19 +174,34 @@ function ConditionContext( ) where {Names} # Note that this potentially overrides values from `context`, thus giving # precedence to the outmost `ConditionContext`. - return ConditionContext(merge(context.values, values), context.context) + return ConditionContext(merge(context.values, values), childcontext(context)) end +""" + getvalue(context, vn) + +Return the value of the parameter corresponding to `vn` from `context`. +If `context` does not contain the value for `vn`, then `nothing` is returned, +e.g. [`DefaultContext`](@ref) will always return `nothing`. +""" +getvalue(::IsLeaf, context, vn) = nothing +getvalue(::IsParent, context, vn) = getvalue(childcontext(context), vn) +getvalue(context::AbstractContext, vn) = getvalue(NodeTrait(getvalue, context), context, vn) +getvalue(context::PrefixContext, vn) = getvalue(childcontext(context), prefix(context, vn)) function getvalue(context::ConditionContext, vn) return if haskey(context, vn) _getvalue(context.values, vn) else - getvalue(context.context, vn) + getvalue(childcontext(context), vn) end end -getvalue(context::AbstractContext, vn) = getvalue(context.context, vn) -getvalue(context::PrefixContext, vn) = getvalue(context.context, prefix(context, vn)) +# General implementations of `haskey`. +Base.haskey(::IsLeaf, context, vn) = false +Base.haskey(::IsParent, context, vn) = Base.haskey(childcontext(context), vn) +Base.haskey(context::AbstractContext, vn) = Base.haskey(NodeTrait(context), context, vn) + +# Specific to `ConditionContext`. function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars @@ -201,15 +216,18 @@ end # TODO: Can we maybe do this in a better way? # When no second argument is given, we remove _all_ conditioned variables. -# TODO: Should we remove this and just return `context.context`? +# TODO: Should we remove this and just return `childcontext(context)`? # That will work better if `Model` becomes like `ContextualModel`. -decondition(context::ConditionContext) = ConditionContext(NamedTuple(), context.context) +function decondition(context::ConditionContext) + return ConditionContext(NamedTuple(), childcontext(context)) +end function decondition(context::ConditionContext, sym) - return ConditionContext(BangBang.delete!!(context.values, sym), context.context) + return ConditionContext(BangBang.delete!!(context.values, sym), childcontext(context)) end function decondition(context::ConditionContext, sym, syms...) return decondition( - ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... + ConditionContext(BangBang.delete!!(context.values, sym), childcontext(context)), + syms..., ) end From c7c60e6a435ef89fe7bebd39bfe76df08f0d8ea9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:23:59 +0100 Subject: [PATCH 25/61] added rewrap for contexts --- src/contexts.jl | 62 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 10 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 38cbe213e..2f74847e6 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -8,9 +8,44 @@ struct IsParent end NodeTrait(f, context) Specifies the role of `context` in the context-tree. + +The officially supported traits are: +- `IsLeaf`: `context` does not have any decendants. +- `IsParent`: `context` has a child context to which we often defer. + Expects the following methods to be implemented: + - [`childcontext`](@ref) + - [`rewrap`](@ref) """ NodeTrait(_, context) = NodeTrait(context) +""" + childcontext(context) + +Return the descendant context of `context`. +""" +childcontext + +""" + rewrap(parent::AbstractContext, child::AbstractContext) + +Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), +effectively updating the child context. + +# Examples +```jldoctest +julia> ctx = SamplingContext(); + +julia> DynamicPPL.childcontext(ctx) +DefaultContext() + +julia> ctx_prior = DynamicPPL.rewrap(ctx, PriorContext()); # only compute the logprior + +julia> DynamicPPL.childcontext(ctx_prior) +PriorContext{Nothing}(nothing) +``` +""" +rewrap + # Contexts """ SamplingContext(rng, sampler, context) @@ -25,8 +60,14 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte sampler::S context::C end +SamplingContext(sampler, context) = SamplingContext(Random.GLOBAL_RNG, sampler, context) +SamplingContext(context::AbstractContext) = SamplingContext(SampleFromPrior(), context) +SamplingContext(sampler::AbstractSampler) = SamplingContext(sampler, DefaultContext()) +SamplingContext() = SamplingContext(SampleFromPrior()) + NodeTrait(context::SamplingContext) = IsParent() childcontext(context::SamplingContext) = context.context +rewrap(parent::SamplingContext, child) = SamplingContext(parent.rng, parent.sampler, child) """ struct DefaultContext <: AbstractContext end @@ -87,6 +128,7 @@ function MiniBatchContext(context=DefaultContext(); batch_size, npoints) end NodeTrait(context::MiniBatchContext) = IsParent() childcontext(context::MiniBatchContext) = context.context +rewrap(parent::MiniBatchContext, child) = MiniBatchContext(child, parent.loglike_scalar) """ PrefixContext{Prefix}(context) @@ -108,6 +150,7 @@ end NodeTrait(context::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context +rewrap(parent::PrefixContext{Prefix}, child) where {Prefix} = PrefixContext{Prefix}(child) const PREFIX_SEPARATOR = Symbol(".") @@ -156,9 +199,7 @@ end return :(NamedTuple{$names_expr}($values_expr)) end -function ConditionContext(context::ConditionContext, child_context::AbstractContext) - return ConditionContext(context.values, child_context) -end +ConditionContext(; values...) = ConditionContext((; values...)) function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end @@ -177,6 +218,10 @@ function ConditionContext( return ConditionContext(merge(context.values, values), childcontext(context)) end +NodeTrait(context::ConditionContext) = IsParent() +childcontext(context::ConditionContext) = context.context +rewrap(parent::ConditionContext, child) = ConditionContext(parent.values, child) + """ getvalue(context, vn) @@ -214,10 +259,10 @@ function Base.haskey( return sym in vars end -# TODO: Can we maybe do this in a better way? -# When no second argument is given, we remove _all_ conditioned variables. -# TODO: Should we remove this and just return `childcontext(context)`? -# That will work better if `Model` becomes like `ContextualModel`. +# Recursively `decondition` the context. +decondition(::IsLeaf, context) = context +decondition(::IsParent, context) = rewrap(context, decondition(childcontext(context))) +decondition(context) = decondition(NodeTrait(context), context) function decondition(context::ConditionContext) return ConditionContext(NamedTuple(), childcontext(context)) end @@ -230,6 +275,3 @@ function decondition(context::ConditionContext, sym, syms...) syms..., ) end - -NodeTrait(context::ConditionContext) = IsParent() -childcontext(context::ConditionContext) = context.context From be67807948b72b0ef47a64d071835ed538384b0d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:31:47 +0100 Subject: [PATCH 26/61] do decondition properly --- src/contexts.jl | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 2f74847e6..648039694 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -260,18 +260,25 @@ function Base.haskey( end # Recursively `decondition` the context. -decondition(::IsLeaf, context) = context -decondition(::IsParent, context) = rewrap(context, decondition(childcontext(context))) -decondition(context) = decondition(NodeTrait(context), context) +decondition(::IsLeaf, context, args...) = context +function decondition(::IsParent, context, args...) + return rewrap(context, decondition(childcontext(context), args...)) +end +decondition(context, args...) = decondition(NodeTrait(context), context, args...) function decondition(context::ConditionContext) - return ConditionContext(NamedTuple(), childcontext(context)) + return ConditionContext(NamedTuple(), decondition(childcontext(context))) end function decondition(context::ConditionContext, sym) - return ConditionContext(BangBang.delete!!(context.values, sym), childcontext(context)) + return ConditionContext( + BangBang.delete!!(context.values, sym), childcontext(context, sym) + ) end function decondition(context::ConditionContext, sym, syms...) return decondition( - ConditionContext(BangBang.delete!!(context.values, sym), childcontext(context)), + ConditionContext( + BangBang.delete!!(context.values, sym), + decondition(childcontext(context), syms...), + ), syms..., ) end From 0468297876eaf773ef823adf665a510fec226a74 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:46:53 +0100 Subject: [PATCH 27/61] added some examples and decondition now removes ConditionContext --- src/contexts.jl | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 648039694..7857c1f92 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -199,7 +199,7 @@ end return :(NamedTuple{$names_expr}($values_expr)) end -ConditionContext(; values...) = ConditionContext((; values...)) +ConditionContext(context=DefaultContext(); values...) = ConditionContext((; values...), context) function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end @@ -259,14 +259,46 @@ function Base.haskey( return sym in vars end -# Recursively `decondition` the context. +""" + context([context::AbstractContext,] values::NamedTuple) + context([context::AbstractContext]; values...) + +Return `ConditionContext` with `values` and wrapping `context`. +""" +condition(context=DefaultContext(); values...) = ConditionContext(context; values...) + +""" + decondition(context::AbstractContext, syms...) + +Return `context` but with `syms` no longer conditioned on. + +Note that this recursively traverses contexts, deconditioning all along the way. + +# Examples +```jldoctest +julia> ctx = DefaultContext(); + +julia> decondition(ctx) === ctx # this is a no-op +true + +julia> ctx = ConditionContext(x = 1.0); + +julia> decondition(ctx) +DefaultContext() + +julia> ctx_nested = ConditionContext(SamplingContext(ConditionContext(y=2.0)), x=1.0); + +julia> decondition(ctx_nested) +SamplingContext{SampleFromPrior, DefaultContext, Random._GLOBAL_RNG}(Random._GLOBAL_RNG(), SampleFromPrior(), DefaultContext()) +``` +""" decondition(::IsLeaf, context, args...) = context function decondition(::IsParent, context, args...) return rewrap(context, decondition(childcontext(context), args...)) end decondition(context, args...) = decondition(NodeTrait(context), context, args...) function decondition(context::ConditionContext) - return ConditionContext(NamedTuple(), decondition(childcontext(context))) + return decondition(childcontext(context)) end function decondition(context::ConditionContext, sym) return ConditionContext( From 80e3d5fd48b3c46f4c11749ed380e51d08ed5fa0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:56:37 +0100 Subject: [PATCH 28/61] improved condition and decondition a bit further --- src/contexts.jl | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 7857c1f92..7bb437867 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -199,7 +199,9 @@ end return :(NamedTuple{$names_expr}($values_expr)) end -ConditionContext(context=DefaultContext(); values...) = ConditionContext((; values...), context) +function ConditionContext(context=DefaultContext(); values...) + return ConditionContext((; values...), context) +end function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end @@ -263,9 +265,18 @@ end context([context::AbstractContext,] values::NamedTuple) context([context::AbstractContext]; values...) -Return `ConditionContext` with `values` and wrapping `context`. +Return `ConditionContext` with `values` and `context` if `values` is non-empty, +otherwise return `context` which is [`DefaultContext`](@ref) by default. + +See also: [`decondition`](@ref) """ -condition(context=DefaultContext(); values...) = ConditionContext(context; values...) +condition() = decondition(ConditionContext()) +condition(values::NamedTuple) = condition(DefaultContext(), values) +condition(context::AbstractContext, values::NamedTuple{()}) = context +condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context) +function condition(context::AbstractContext=DefaultContext(); values...) + return ConditionContext(context; values...) +end """ decondition(context::AbstractContext, syms...) @@ -274,6 +285,8 @@ Return `context` but with `syms` no longer conditioned on. Note that this recursively traverses contexts, deconditioning all along the way. +See also: [`condition`](@ref) + # Examples ```jldoctest julia> ctx = DefaultContext(); @@ -301,15 +314,15 @@ function decondition(context::ConditionContext) return decondition(childcontext(context)) end function decondition(context::ConditionContext, sym) - return ConditionContext( - BangBang.delete!!(context.values, sym), childcontext(context, sym) + return condition( + decondition(childcontext(context), sym), BangBang.delete!!(context.values, sym) ) end function decondition(context::ConditionContext, sym, syms...) return decondition( - ConditionContext( - BangBang.delete!!(context.values, sym), + condition( decondition(childcontext(context), syms...), + BangBang.delete!!(context.values, sym), ), syms..., ) From 9419e76bbbe2ffb7f708c67ac030d9769273cef6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:56:53 +0100 Subject: [PATCH 29/61] use rewrap in _evaluate for ContextualModel --- src/contextual_model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contextual_model.jl b/src/contextual_model.jl index 4b843ed0e..883fc5379 100644 --- a/src/contextual_model.jl +++ b/src/contextual_model.jl @@ -12,7 +12,7 @@ end function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as # `ConditionContext` child. - return _evaluate(cmodel.model, varinfo, ConditionContext(cmodel.context.values, context)) + return _evaluate(cmodel.model, varinfo, rewrap(cmodel.context, context)) end condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) From 4e566f78d4deff430b2ebe065ff8909b2f7a3b64 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:57:54 +0100 Subject: [PATCH 30/61] remove the drop_missing as it is no longer needed --- src/contexts.jl | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index f48a5187c..1c54cf86a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -118,30 +118,14 @@ struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext end end -@generated function drop_missings(nt::NamedTuple{names,values}) where {names,values} - names_expr = Expr(:tuple) - values_expr = Expr(:tuple) - - for (n, v) in zip(names, values.parameters) - if !(v <: Missing) - push!(names_expr.args, QuoteNode(n)) - push!(values_expr.args, :(nt.$n)) - end - end - - return :(NamedTuple{$names_expr}($values_expr)) -end - function ConditionContext(context::ConditionContext, child_context::AbstractContext) return ConditionContext(context.values, child_context) end function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end - function ConditionContext(values::NamedTuple, context::AbstractContext) - values_wo_missing = drop_missings(values) - return ConditionContext{typeof(values_wo_missing)}(values_wo_missing, context) + return ConditionContext{typeof(values)}(values, context) end # Try to avoid nested `ConditionContext`. From b27228aec8a15310f9076cc9623ec0de59f43221 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:08:14 +0100 Subject: [PATCH 31/61] rename rewrap to setchildcontet --- src/contexts.jl | 18 +++++++++--------- src/contextual_model.jl | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 00d6598af..33da613a7 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -14,7 +14,7 @@ The officially supported traits are: - `IsParent`: `context` has a child context to which we often defer. Expects the following methods to be implemented: - [`childcontext`](@ref) - - [`rewrap`](@ref) + - [`setchildcontext`](@ref) """ NodeTrait(_, context) = NodeTrait(context) @@ -26,7 +26,7 @@ Return the descendant context of `context`. childcontext """ - rewrap(parent::AbstractContext, child::AbstractContext) + setchildcontext(parent::AbstractContext, child::AbstractContext) Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), effectively updating the child context. @@ -38,13 +38,13 @@ julia> ctx = SamplingContext(); julia> DynamicPPL.childcontext(ctx) DefaultContext() -julia> ctx_prior = DynamicPPL.rewrap(ctx, PriorContext()); # only compute the logprior +julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior julia> DynamicPPL.childcontext(ctx_prior) PriorContext{Nothing}(nothing) ``` """ -rewrap +setchildcontext # Contexts """ @@ -67,7 +67,7 @@ SamplingContext() = SamplingContext(SampleFromPrior()) NodeTrait(context::SamplingContext) = IsParent() childcontext(context::SamplingContext) = context.context -rewrap(parent::SamplingContext, child) = SamplingContext(parent.rng, parent.sampler, child) +setchildcontext(parent::SamplingContext, child) = SamplingContext(parent.rng, parent.sampler, child) """ struct DefaultContext <: AbstractContext end @@ -128,7 +128,7 @@ function MiniBatchContext(context=DefaultContext(); batch_size, npoints) end NodeTrait(context::MiniBatchContext) = IsParent() childcontext(context::MiniBatchContext) = context.context -rewrap(parent::MiniBatchContext, child) = MiniBatchContext(child, parent.loglike_scalar) +setchildcontext(parent::MiniBatchContext, child) = MiniBatchContext(child, parent.loglike_scalar) """ PrefixContext{Prefix}(context) @@ -150,7 +150,7 @@ end NodeTrait(context::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context -rewrap(parent::PrefixContext{Prefix}, child) where {Prefix} = PrefixContext{Prefix}(child) +setchildcontext(parent::PrefixContext{Prefix}, child) where {Prefix} = PrefixContext{Prefix}(child) const PREFIX_SEPARATOR = Symbol(".") @@ -206,7 +206,7 @@ end NodeTrait(context::ConditionContext) = IsParent() childcontext(context::ConditionContext) = context.context -rewrap(parent::ConditionContext, child) = ConditionContext(parent.values, child) +setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) """ getvalue(context, vn) @@ -291,7 +291,7 @@ SamplingContext{SampleFromPrior, DefaultContext, Random._GLOBAL_RNG}(Random._GLO """ decondition(::IsLeaf, context, args...) = context function decondition(::IsParent, context, args...) - return rewrap(context, decondition(childcontext(context), args...)) + return setchildcontext(context, decondition(childcontext(context), args...)) end decondition(context, args...) = decondition(NodeTrait(context), context, args...) function decondition(context::ConditionContext) diff --git a/src/contextual_model.jl b/src/contextual_model.jl index 883fc5379..3ab931b65 100644 --- a/src/contextual_model.jl +++ b/src/contextual_model.jl @@ -12,7 +12,7 @@ end function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as # `ConditionContext` child. - return _evaluate(cmodel.model, varinfo, rewrap(cmodel.context, context)) + return _evaluate(cmodel.model, varinfo, setchildcontext(cmodel.context, context)) end condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) From 4935d5c610be95bcb073f1776e3ff462160ef371 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:08:33 +0100 Subject: [PATCH 32/61] formatting --- src/contexts.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 33da613a7..36c6485ba 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -67,7 +67,9 @@ SamplingContext() = SamplingContext(SampleFromPrior()) NodeTrait(context::SamplingContext) = IsParent() childcontext(context::SamplingContext) = context.context -setchildcontext(parent::SamplingContext, child) = SamplingContext(parent.rng, parent.sampler, child) +function setchildcontext(parent::SamplingContext, child) + return SamplingContext(parent.rng, parent.sampler, child) +end """ struct DefaultContext <: AbstractContext end @@ -128,7 +130,9 @@ function MiniBatchContext(context=DefaultContext(); batch_size, npoints) end NodeTrait(context::MiniBatchContext) = IsParent() childcontext(context::MiniBatchContext) = context.context -setchildcontext(parent::MiniBatchContext, child) = MiniBatchContext(child, parent.loglike_scalar) +function setchildcontext(parent::MiniBatchContext, child) + return MiniBatchContext(child, parent.loglike_scalar) +end """ PrefixContext{Prefix}(context) @@ -150,7 +154,9 @@ end NodeTrait(context::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context -setchildcontext(parent::PrefixContext{Prefix}, child) where {Prefix} = PrefixContext{Prefix}(child) +function setchildcontext(parent::PrefixContext{Prefix}, child) where {Prefix} + return PrefixContext{Prefix}(child) +end const PREFIX_SEPARATOR = Symbol(".") From 61960832b75ff490862eaca5da6e186602fc3c93 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:11:51 +0100 Subject: [PATCH 33/61] made show a bit nicer for ConditionContext --- src/contexts.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/contexts.jl b/src/contexts.jl index 1c54cf86a..4a6834b32 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -137,6 +137,10 @@ function ConditionContext( return ConditionContext(merge(context.values, values), context.context) end +function Base.show(io::IO, context::ConditionContext) + println(io, "ConditionContext($(context.values), $(context.context))") +end + function getvalue(context::ConditionContext, vn) return if haskey(context, vn) _getvalue(context.values, vn) From 65048fc4c57e1ec884cddbcdcd880269cfdc33a3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:12:32 +0100 Subject: [PATCH 34/61] use print instead of println in show --- src/contexts.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 4a6834b32..fa0bf585a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -138,7 +138,7 @@ function ConditionContext( end function Base.show(io::IO, context::ConditionContext) - println(io, "ConditionContext($(context.values), $(context.context))") + print(io, "ConditionContext($(context.values), $(context.context))") end function getvalue(context::ConditionContext, vn) From 649af299a7f68c0def25026bbf06850aedbbfdbc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:14:15 +0100 Subject: [PATCH 35/61] formatting --- src/contexts.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index fa0bf585a..926dd0fe0 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -138,7 +138,7 @@ function ConditionContext( end function Base.show(io::IO, context::ConditionContext) - print(io, "ConditionContext($(context.values), $(context.context))") + return print(io, "ConditionContext($(context.values), $(context.context))") end function getvalue(context::ConditionContext, vn) From 4010ab8be8a55e0b4bc60a1194800e6d0bb38850 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 26 Jul 2021 00:02:26 +0100 Subject: [PATCH 36/61] make Model a contextual model --- src/contextual_model.jl | 4 +++- src/model.jl | 28 ++++++++++++++++++++-------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/contextual_model.jl b/src/contextual_model.jl index 3ab931b65..61d5fa6b6 100644 --- a/src/contextual_model.jl +++ b/src/contextual_model.jl @@ -4,7 +4,7 @@ struct ContextualModel{Ctx<:AbstractContext,M<:Model} <: AbstractModel end function contextualize(model::AbstractModel, context::AbstractContext) - return ContextualModel(context, model) + return Model(model.name, model.f, model.args, model.defaults, context) end # TODO: What do we do for other contexts? Could handle this in general if we had a @@ -15,6 +15,8 @@ function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context return _evaluate(cmodel.model, varinfo, setchildcontext(cmodel.context, context)) end + +Base.:|(model::AbstractModel, values) = condition(model, values) condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) condition(model::AbstractModel; values...) = condition(model, (; values...)) function condition(cmodel::ContextualModel{<:ConditionContext}, values) diff --git a/src/model.jl b/src/model.jl index 7b5d8918a..ef3f89807 100644 --- a/src/model.jl +++ b/src/model.jl @@ -34,11 +34,13 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: + AbstractModel name::Symbol f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} + context::Ctx @doc """ Model{missings}(name::Symbol, f, args::NamedTuple, defaults::NamedTuple) @@ -51,9 +53,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults}( - name, f, args, defaults + context::Ctx=DefaultContext(), + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}( + name, f, args, defaults, context ) end end @@ -68,10 +71,14 @@ Default arguments `defaults` are used internally when constructing instances of model with different arguments. """ @generated function Model( - name::Symbol, f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple=NamedTuple() + name::Symbol, + f::F, + args::NamedTuple{argnames,Targs}, + defaults::NamedTuple=NamedTuple(), + context::AbstractContext=DefaultContext(), ) where {F,argnames,Targs} missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing) - return :(Model{$missings}(name, f, args, defaults)) + return :(Model{$missings}(name, f, args, defaults, context)) end """ @@ -157,8 +164,13 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf @generated function _evaluate( model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} - unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] - return :(model.f(model, varinfo, context, $(unwrap_args...))) + unwrap_args = [ + :($matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames + ] + return quote + context_new = insertcontext(context, model.context) + model.f(model, varinfo, context_new, $(unwrap_args...)) + end end """ From 21c08e5852f05d3ee9ef1bdaa832efb66fb96cc9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Jul 2021 19:03:31 +0100 Subject: [PATCH 37/61] dont overload haskey and improved contextual_isassumption check --- src/compiler.jl | 22 +++++++++++----------- src/contexts.jl | 16 ++++++++-------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index fb77d0323..c69a83803 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -61,17 +61,17 @@ function contextual_isassumption(context::AbstractContext, vn) return contextual_isassumption(NodeTrait(context), context, vn) end function contextual_isassumption(context::ConditionContext, vn) - # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`. - - # We have either of the following cases: - # 1. `context` considers `vn` as an observation, i.e. it has `vn` as a key, - # which means we have a value to replace with and we don't need to recurse. - # 2. One of the decendant contexts consider it as an observation, i.e. - # `contextual_isassumption` evaluates to `false`. - # The below then evaluates to `!(false || true) === false`. - # 3. Neither `context` nor any of it's decendants considers it an observation, - # in which case the below evaluates to `!(false || false) === true`. - return !(haskey(context, vn) || !contextual_isassumption(childcontext(context), vn)) + if hasvalue(context, vn) + val = getvalue(context, vn) + # TODO: Do we even need the `>: Missing` to help the compiler? + if eltype(val) >: Missing && val === missing + return true + end + end + + # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}` + # so we defer to `childcontext` if we haven't concluded that anything yet. + return contextual_isassumption(childcontext(context), vn) end function contextual_isassumption(context::PrefixContext, vn) return contextual_isassumption(childcontext(context), prefix(context, vn)) diff --git a/src/contexts.jl b/src/contexts.jl index fb26b2757..1c63df0c5 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -294,7 +294,7 @@ getvalue(context::AbstractContext, vn) = getvalue(NodeTrait(getvalue, context), getvalue(context::PrefixContext, vn) = getvalue(childcontext(context), prefix(context, vn)) function getvalue(context::ConditionContext, vn) - return if haskey(context, vn) + return if hasvalue(context, vn) _getvalue(context.values, vn) else getvalue(childcontext(context), vn) @@ -302,17 +302,17 @@ function getvalue(context::ConditionContext, vn) end # General implementations of `haskey`. -Base.haskey(::IsLeaf, context, vn) = false -Base.haskey(::IsParent, context, vn) = Base.haskey(childcontext(context), vn) -Base.haskey(context::AbstractContext, vn) = Base.haskey(NodeTrait(context), context, vn) +hasvalue(::IsLeaf, context, vn) = false +hasvalue(::IsParent, context, vn) = hasvalue(childcontext(context), vn) +hasvalue(context::AbstractContext, vn) = hasvalue(NodeTrait(context), context, vn) # Specific to `ConditionContext`. -function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} +function hasvalue(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars end -function Base.haskey( +function hasvalue( context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}} ) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. @@ -332,8 +332,8 @@ condition() = decondition(ConditionContext()) condition(values::NamedTuple) = condition(DefaultContext(), values) condition(context::AbstractContext, values::NamedTuple{()}) = context condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context) -function condition(context::AbstractContext=DefaultContext(); values...) - return ConditionContext(context; values...) +function condition(context::AbstractContext; values...) + return condition(context, (; values...)) end """ From 35fad3691e4653ecc812dc8e6b9b865526e5aaa7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Jul 2021 19:07:40 +0100 Subject: [PATCH 38/61] remove the now unnecessary ContextualModel and AbstractModel --- src/contextual_model.jl | 29 ----------------------------- src/model.jl | 33 ++++++++++++++++++++++++--------- 2 files changed, 24 insertions(+), 38 deletions(-) delete mode 100644 src/contextual_model.jl diff --git a/src/contextual_model.jl b/src/contextual_model.jl deleted file mode 100644 index 61d5fa6b6..000000000 --- a/src/contextual_model.jl +++ /dev/null @@ -1,29 +0,0 @@ -struct ContextualModel{Ctx<:AbstractContext,M<:Model} <: AbstractModel - context::Ctx - model::M -end - -function contextualize(model::AbstractModel, context::AbstractContext) - return Model(model.name, model.f, model.args, model.defaults, context) -end - -# TODO: What do we do for other contexts? Could handle this in general if we had a -# notion of wrapper-, primitive-context, etc. -function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) - # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as - # `ConditionContext` child. - return _evaluate(cmodel.model, varinfo, setchildcontext(cmodel.context, context)) -end - - -Base.:|(model::AbstractModel, values) = condition(model, values) -condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) -condition(model::AbstractModel; values...) = condition(model, (; values...)) -function condition(cmodel::ContextualModel{<:ConditionContext}, values) - return contextualize(cmodel.model, ConditionContext(values, cmodel.context)) -end - -decondition(model::AbstractModel, args...) = model -function decondition(cmodel::ContextualModel{<:ConditionContext}, syms...) - return contextualize(cmodel.model, decondition(cmodel.context, syms...)) -end diff --git a/src/model.jl b/src/model.jl index ef3f89807..a725bff8b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,5 +1,3 @@ -abstract type AbstractModel <: AbstractProbabilisticProgram end - """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} name::Symbol @@ -35,7 +33,7 @@ Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x ``` """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: - AbstractModel + AbstractProbabilisticProgram name::Symbol f::F args::NamedTuple{argnames,Targs} @@ -81,6 +79,23 @@ model with different arguments. return :(Model{$missings}(name, f, args, defaults, context)) end +function contextualize(model::Model, context::AbstractContext) + return Model(model.name, model.f, model.args, model.defaults, context) +end + +Base.:|(model::Model, values) = condition(model, values) + +condition(model::Model, values) = contextualize(model, ConditionContext(values)) +condition(model::Model; values...) = condition(model, (; values...)) +function condition(model::Model, values) + return contextualize(model, condition(model.context, values)) +end + +function decondition(model::Model, syms...) + return contextualize(model, decondition(model.context, syms...)) +end + + """ (model::Model)([rng, varinfo, sampler, context]) @@ -90,7 +105,7 @@ Sample from the `model` using the `sampler` with random number generator `rng` a The method resets the log joint probability of `varinfo` and increases the evaluation number of `sampler`. """ -function (model::AbstractModel)( +function (model::Model)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), @@ -99,8 +114,8 @@ function (model::AbstractModel)( return model(varinfo, SamplingContext(rng, sampler, context)) end -(model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context) -function (model::AbstractModel)(varinfo::AbstractVarInfo, context::AbstractContext) +(model::Model)(context::AbstractContext) = model(VarInfo(), context) +function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 return evaluate_threadunsafe(model, varinfo, context) else @@ -108,17 +123,17 @@ function (model::AbstractModel)(varinfo::AbstractVarInfo, context::AbstractConte end end -function (model::AbstractModel)(args...) +function (model::Model)(args...) return model(Random.GLOBAL_RNG, args...) end # without VarInfo -function (model::AbstractModel)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) +function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) return model(rng, VarInfo(), sampler, args...) end # without VarInfo and without AbstractSampler -function (model::AbstractModel)(rng::Random.AbstractRNG, context::AbstractContext) +function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) return model(rng, VarInfo(), SampleFromPrior(), context) end From 9297e612592580a891994dd418ee394ab8d2e2bb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Jul 2021 19:14:14 +0100 Subject: [PATCH 39/61] removed some leftovers --- Project.toml | 1 - src/DynamicPPL.jl | 1 - src/varinfo.jl | 6 +++--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index ecf47b50f..5678c050a 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ version = "0.13.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" -BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index bfcc49e6b..aa8b8ca58 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,7 +6,6 @@ using Distributions using Bijectors using AbstractMCMC: AbstractMCMC -using BangBang: BangBang using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules diff --git a/src/varinfo.jl b/src/varinfo.jl index ba6e157e9..64c122dc2 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -124,7 +124,7 @@ end function VarInfo( rng::Random.AbstractRNG, - model::AbstractModel, + model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) @@ -132,10 +132,10 @@ function VarInfo( model(rng, varinfo, sampler, context) return TypedVarInfo(varinfo) end -VarInfo(model::AbstractModel, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) +VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) # without AbstractSampler -function VarInfo(rng::Random.AbstractRNG, model::AbstractModel, context::AbstractContext) +function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) return VarInfo(rng, model, SampleFromPrior(), context) end From 65f80944f20b08f6b821af5422c33386a9863f1f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Jul 2021 19:14:53 +0100 Subject: [PATCH 40/61] removed now gone include --- src/DynamicPPL.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index aa8b8ca58..e7732d8c5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -133,6 +133,5 @@ include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") -include("contextual_model.jl") end # module From b201ef3bb76ad20fb76dcc38c37b91ed486cb5c9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Jul 2021 19:15:51 +0100 Subject: [PATCH 41/61] removed leftovers --- Project.toml | 1 - src/DynamicPPL.jl | 2 -- 2 files changed, 3 deletions(-) diff --git a/Project.toml b/Project.toml index ecf47b50f..5678c050a 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ version = "0.13.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" -BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index bfcc49e6b..e7732d8c5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,7 +6,6 @@ using Distributions using Bijectors using AbstractMCMC: AbstractMCMC -using BangBang: BangBang using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules @@ -134,6 +133,5 @@ include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") -include("contextual_model.jl") end # module From 9ebdd0e9f2b5313a25565ae06d353a206363b3a8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 02:09:37 +0100 Subject: [PATCH 42/61] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/model.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index a725bff8b..92b773ab5 100644 --- a/src/model.jl +++ b/src/model.jl @@ -95,7 +95,6 @@ function decondition(model::Model, syms...) return contextualize(model, decondition(model.context, syms...)) end - """ (model::Model)([rng, varinfo, sampler, context]) From 274ad23758d14a2a20ed3f5c2836892f9a03ccdb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 02:41:45 +0100 Subject: [PATCH 43/61] fixed typo --- src/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index a725bff8b..8f9057459 100644 --- a/src/model.jl +++ b/src/model.jl @@ -183,7 +183,7 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf :($matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames ] return quote - context_new = insertcontext(context, model.context) + context_new = setleafcontext(context, model.context) model.f(model, varinfo, context_new, $(unwrap_args...)) end end From 8cc3193d6909be89ad772ac7eb19cee121d128d5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:37:01 +0100 Subject: [PATCH 44/61] improved condition and fixed isassumption --- src/DynamicPPL.jl | 1 + src/compiler.jl | 2 +- src/contexts.jl | 63 ++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 53 insertions(+), 13 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e7732d8c5..2f0ef2617 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -9,6 +9,7 @@ using AbstractMCMC: AbstractMCMC using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules +using BangBang: BangBang using Random: Random diff --git a/src/compiler.jl b/src/compiler.jl index c69a83803..6e10c4285 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -34,7 +34,7 @@ function isassumption(expr::Union{Symbol,Expr}) # TODO: Support by adding context to model, and use `model.args` # as the default conditioning. Then we no longer need to check `inargnames` # since it will all be handled by `contextual_isassumption`. - if !($(DynamicPPL.inargnames)($vn, __model__)) + if !($(DynamicPPL.inargnames)($vn, __model__)) || $(DynamicPPL.inmissings)($vn, __model__) true else $expr === missing diff --git a/src/contexts.jl b/src/contexts.jl index 1c63df0c5..d1b2145c2 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -66,9 +66,9 @@ original leaf context of `left`. # Examples ```jldoctest -julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext +julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext -julia> struct ParentContext{C} +julia> struct ParentContext{C} <: AbstractContext context::C end @@ -328,14 +328,13 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default. See also: [`decondition`](@ref) """ -condition() = decondition(ConditionContext()) +function condition(; values...) + return isempty(values) ? decondition(ConditionContext()) : condition(DefaultContext(), (; values...)) +end condition(values::NamedTuple) = condition(DefaultContext(), values) condition(context::AbstractContext, values::NamedTuple{()}) = context condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context) -function condition(context::AbstractContext; values...) - return condition(context, (; values...)) -end - +condition(context::AbstractContext; values...) = condition(context, (; values...)) """ decondition(context::AbstractContext, syms...) @@ -347,20 +346,60 @@ See also: [`condition`](@ref) # Examples ```jldoctest -julia> ctx = DefaultContext(); +julia> using DynamicPPL: AbstractContext, leafcontext, setleafcontext, childcontext, setchildcontext + +julia> struct ParentContext{C} <: AbstractContext + context::C + end + +julia> DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() + +julia> DynamicPPL.childcontext(context::ParentContext) = context.context + +julia> DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) + +julia> Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") + +julia> ctx = DefaultContext() +DefaultContext() julia> decondition(ctx) === ctx # this is a no-op true -julia> ctx = ConditionContext(x = 1.0); +julia> ctx = condition(x = 1.0) # default "constructor" for `ConditionContext` +ConditionContext((x = 1.0,), DefaultContext()) -julia> decondition(ctx) +julia> decondition(ctx) # `decondition` without arguments drops all conditioning DefaultContext() -julia> ctx_nested = ConditionContext(SamplingContext(ConditionContext(y=2.0)), x=1.0); +julia> # Nested conditioning is supported. + ctx_nested = condition(ParentContext(condition(y=2.0)), x=1.0) +ConditionContext((x = 1.0,), ParentContext(ConditionContext((y = 2.0,), DefaultContext()))) + +julia> # We can also specify which variables to drop. + decondition(ctx_nested, :x) +ParentContext(ConditionContext((y = 2.0,), DefaultContext())) + +julia> # No matter the nested level. + decondition(ctx_nested, :y) +ConditionContext((x = 1.0,), ParentContext(DefaultContext())) + +julia> # Or specify multiple at in one call. + decondition(ctx_nested, :x, :y) +ParentContext(DefaultContext()) julia> decondition(ctx_nested) -SamplingContext{SampleFromPrior, DefaultContext, Random._GLOBAL_RNG}(Random._GLOBAL_RNG(), SampleFromPrior(), DefaultContext()) +ParentContext(DefaultContext()) + +julia> # `Val` is also supported. + decondition(ctx_nested, Val(:x)) +ParentContext(ConditionContext((y = 2.0,), DefaultContext())) + +julia> decondition(ctx_nested, Val(:y)) +ConditionContext((x = 1.0,), ParentContext(DefaultContext())) + +julia> decondition(ctx_nested, Val(:x), Val(:y)) +ParentContext(DefaultContext()) ``` """ decondition(::IsLeaf, context, args...) = context From 26edb2c80cc8562fdff1dfefde403ae02395d010 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:40:08 +0100 Subject: [PATCH 45/61] added BangBang as dep --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 5678c050a..fd514f836 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.13.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -15,6 +16,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "2, 3.0" AbstractPPL = "0.2" +BangBang = "0.3" Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9" ChainRulesCore = "0.9.7, 0.10" Distributions = "0.23.8, 0.24, 0.25" From be61ef1224cb31c43998a7720c12c219504b0624 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:41:28 +0100 Subject: [PATCH 46/61] fixed the _evaluate --- src/model.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/model.jl b/src/model.jl index 45865ef69..600ba162b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -85,7 +85,6 @@ end Base.:|(model::Model, values) = condition(model, values) -condition(model::Model, values) = contextualize(model, ConditionContext(values)) condition(model::Model; values...) = condition(model, (; values...)) function condition(model::Model, values) return contextualize(model, condition(model.context, values)) @@ -181,8 +180,15 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf unwrap_args = [ :($matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames ] + # We want to give `context` precedence over `model.context` while also + # preserving the leaf context of `context`. We can do this by + # 1. Set the leaf context of `model.context` to `leafcontext(context)`. + # 2. Set leaf context of `context` to the context resulting from (1). + # The result is: + # `context` -> `childcontext(context)` -> ... -> `model.context` + # -> `childcontext(model.context)` -> ... -> `leafcontext(context)` return quote - context_new = setleafcontext(context, model.context) + context_new = setleafcontext(context, setleafcontext(model.context, leafcontext(context))) model.f(model, varinfo, context_new, $(unwrap_args...)) end end From e2f6fc5f8a2ddc9e6efb04b506c9fba50df28aca Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:46:36 +0100 Subject: [PATCH 47/61] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/compiler.jl | 3 ++- src/contexts.jl | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 6e10c4285..921062079 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -34,7 +34,8 @@ function isassumption(expr::Union{Symbol,Expr}) # TODO: Support by adding context to model, and use `model.args` # as the default conditioning. Then we no longer need to check `inargnames` # since it will all be handled by `contextual_isassumption`. - if !($(DynamicPPL.inargnames)($vn, __model__)) || $(DynamicPPL.inmissings)($vn, __model__) + if !($(DynamicPPL.inargnames)($vn, __model__)) || + $(DynamicPPL.inmissings)($vn, __model__) true else $expr === missing diff --git a/src/contexts.jl b/src/contexts.jl index de7a44207..b045a17e2 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -326,7 +326,11 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default. See also: [`decondition`](@ref) """ function condition(; values...) - return isempty(values) ? decondition(ConditionContext()) : condition(DefaultContext(), (; values...)) + return if isempty(values) + decondition(ConditionContext()) + else + condition(DefaultContext(), (; values...)) + end end condition(values::NamedTuple) = condition(DefaultContext(), values) condition(context::AbstractContext, values::NamedTuple{()}) = context From a361a5ea5caa9d50a3fbc989ed3b35bd3e1d592f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:47:09 +0100 Subject: [PATCH 48/61] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/model.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index 600ba162b..d55f9d5b2 100644 --- a/src/model.jl +++ b/src/model.jl @@ -188,7 +188,9 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf # `context` -> `childcontext(context)` -> ... -> `model.context` # -> `childcontext(model.context)` -> ... -> `leafcontext(context)` return quote - context_new = setleafcontext(context, setleafcontext(model.context, leafcontext(context))) + context_new = setleafcontext( + context, setleafcontext(model.context, leafcontext(context)) + ) model.f(model, varinfo, context_new, $(unwrap_args...)) end end From ce160d682c420127126ae7c53d9f2f6ca47e1373 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:53:18 +0100 Subject: [PATCH 49/61] fixed condition without arguments --- src/contexts.jl | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index b045a17e2..ca3c6b4c2 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -251,9 +251,6 @@ struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext end end -function ConditionContext(context::ConditionContext, child_context::AbstractContext) - return ConditionContext(context.values, child_context) -end function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end @@ -325,13 +322,7 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default. See also: [`decondition`](@ref) """ -function condition(; values...) - return if isempty(values) - decondition(ConditionContext()) - else - condition(DefaultContext(), (; values...)) - end -end +condition(; values...) = condition(DefaultContext(), (; values...)) condition(values::NamedTuple) = condition(DefaultContext(), values) condition(context::AbstractContext, values::NamedTuple{()}) = context condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context) From b492041baba4abcc04d5b29c6ca1fb1cdd3d5ae0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 07:55:56 +0100 Subject: [PATCH 50/61] had forgotten a return --- src/compiler.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 921062079..ae0c2fed0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -67,6 +67,8 @@ function contextual_isassumption(context::ConditionContext, vn) # TODO: Do we even need the `>: Missing` to help the compiler? if eltype(val) >: Missing && val === missing return true + else + return false end end From 7ae9e3e871034ad963800237f2d406b7d21f79ce Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 08:23:09 +0100 Subject: [PATCH 51/61] added methods for extracting the conditioned/observed variables and values --- src/DynamicPPL.jl | 3 +++ src/contexts.jl | 55 +++++++++++++++++++++++++++++++++++++++++++++++ src/model.jl | 4 ++++ 3 files changed, 62 insertions(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 2f0ef2617..f054517bb 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -104,6 +104,9 @@ export AbstractVarInfo, pointwise_loglikelihoods, condition, decondition, + contextualize, + observations, + conditioned, # Convenience macros @addlogprob!, @submodel diff --git a/src/contexts.jl b/src/contexts.jl index ca3c6b4c2..409f666d1 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -416,3 +416,58 @@ function decondition(context::ConditionContext, sym, syms...) syms..., ) end + +""" + conditioned(model::Model) + conditioned(context::AbstractContext) + +Return `NamedTuple` of values that are conditioned on under `model`/`context`. + +# Examples +```jldoctest +julia> @model function demo() + m ~ Normal() + x ~ Normal(m, 1) + end +demo (generic function with 1 methods) + +julia> m = demo(); + +julia> # Returns all the variables we have conditioned on + their values. + conditioned(condition(m, x=100.0, m=1.0)) +(x = 100.0, m = 1.0) + +julia> # Nested ones also work (note that `PrefixContext` does nothing to the result). + cm = condition(contextualize(m, PrefixContext{:a}(condition(m=1.0))), x=100.0); + +julia> conditioned(cm) +(x = 100.0, m = 1.0) + +julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, + # `a.m` is treated as a random variable. + keys(VarInfo(cm)) +1-element Vector{VarName{Symbol("a.m"), Tuple{}}}: + a.m + +julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation. + cm = condition(contextualize(m, PrefixContext{:a}(condition(var"a.m"=1.0))), x=100.0); + +julia> conditioned(cm) +(x = 100.0, a.m = 1.0) + +julia> keys(VarInfo(cm)) # <= no variables are sampled +Any[] +``` +""" +function conditioned(context::AbstractContext) + conditioned(NodeTrait(conditioned, context), context) +end +conditioned(::IsLeaf, context) = () +conditioned(::IsParent, context) = conditioned(childcontext(context)) +function conditioned(context::ConditionContext) + # Note the order of arguments to `merge`. The behavior of the rest of DPPL + # is that the outermost `context` takes precendence, hence when resolving + # the `conditioned` variables we need to ensure that `context.values` takes + # precedence over decendants of `context`. + return merge(context.values, conditioned(childcontext(context))) +end diff --git a/src/model.jl b/src/model.jl index d55f9d5b2..105878a61 100644 --- a/src/model.jl +++ b/src/model.jl @@ -94,6 +94,10 @@ function decondition(model::Model, syms...) return contextualize(model, decondition(model.context, syms...)) end +observations(model::Model) = conditioned(model) +conditioned(model::Model) = conditioned(model.context) + + """ (model::Model)([rng, varinfo, sampler, context]) From 3b31d35ff072cf093973fbc8cd906f154db4c9f6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 08:23:28 +0100 Subject: [PATCH 52/61] fixed a couple of tilde statements --- src/context_implementations.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9464098d7..83f57fe70 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -171,6 +171,8 @@ function tilde_observe(context::SamplingContext, right, left, vi) end # Leaf contexts +# TODO: Should we maybe not do `args...` here but instead be explicit? +# Could help avoid stealthy bugs. function tilde_observe(context::AbstractContext, args...) return tilde_observe(NodeTrait(tilde_observe, context), context, args...) end @@ -186,13 +188,13 @@ tilde_observe(::PriorContext, sampler, right, left, vi) = 0 function tilde_observe(context::MiniBatchContext, right, left, vi) return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end -function tilde_observe(context::MiniBatchContext, right, left, vname, vi) - return context.loglike_scalar * tilde_observe(context.context, right, left, vname, vi) +function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) + return context.loglike_scalar * tilde_observe(context.context, sampler, right, left, vname, vi) end # `PrefixContext` -function tilde_observe(context::PrefixContext, right, left, vname, vi) - return tilde_observe(context.context, right, left, prefix(context, vname), vi) +function tilde_observe(context::PrefixContext, right, left, vi) + return tilde_observe(context.context, right, left, vi) end """ From cefb443c1d3a9c5c389c198024862e8905b8c1e3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 08:25:07 +0100 Subject: [PATCH 53/61] added docstring to observations --- src/model.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/model.jl b/src/model.jl index 105878a61..3495b6a05 100644 --- a/src/model.jl +++ b/src/model.jl @@ -94,6 +94,11 @@ function decondition(model::Model, syms...) return contextualize(model, decondition(model.context, syms...)) end +""" + observations(model::Model) + +Alias for [`conditioned`](@ref). +""" observations(model::Model) = conditioned(model) conditioned(model::Model) = conditioned(model.context) From 6086734bba15452cd8cc17211565e438308d9997 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Mon, 2 Aug 2021 18:38:39 +0100 Subject: [PATCH 54/61] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 3 ++- src/contexts.jl | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 83f57fe70..54a198acc 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -189,7 +189,8 @@ function tilde_observe(context::MiniBatchContext, right, left, vi) return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * tilde_observe(context.context, sampler, right, left, vname, vi) + return context.loglike_scalar * + tilde_observe(context.context, sampler, right, left, vname, vi) end # `PrefixContext` diff --git a/src/contexts.jl b/src/contexts.jl index 409f666d1..ca072453b 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -460,7 +460,7 @@ Any[] ``` """ function conditioned(context::AbstractContext) - conditioned(NodeTrait(conditioned, context), context) + return conditioned(NodeTrait(conditioned, context), context) end conditioned(::IsLeaf, context) = () conditioned(::IsParent, context) = conditioned(childcontext(context)) From e57902001bf139fab6378dec70fe9c9fe6907868 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 2 Aug 2021 20:38:55 +0100 Subject: [PATCH 55/61] Apply suggestions from code review Co-authored-by: Hong Ge --- src/DynamicPPL.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f054517bb..7d7a14584 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -68,7 +68,6 @@ export AbstractVarInfo, vectorize, # Model Model, - ContextualModel, getmissings, getargnames, generated_quantities, From 1d3fa2b2a269af89fb3df59f6ac7be8b758468a5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 02:27:52 +0100 Subject: [PATCH 56/61] dont export contextualize, observations and conditioned --- src/DynamicPPL.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 7d7a14584..ac2734b47 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -103,9 +103,6 @@ export AbstractVarInfo, pointwise_loglikelihoods, condition, decondition, - contextualize, - observations, - conditioned, # Convenience macros @addlogprob!, @submodel From da34dff45a553ee7b8cfcc56589bef67a8034732 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 05:01:08 +0100 Subject: [PATCH 57/61] added getvalue_nested and hasvalue_nested to be more explicit --- src/compiler.jl | 4 +-- src/contexts.jl | 76 +++++++++++++++++++++++++++++++++---------------- 2 files changed, 53 insertions(+), 27 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7a1556419..e55caac42 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -400,7 +400,7 @@ function generate_tilde(left, right) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) - $left = $(DynamicPPL.getvalue)(__context__, $vn) + $left = $(DynamicPPL.getvalue_nested)(__context__, $vn) end $(DynamicPPL.tilde_observe!)( @@ -449,7 +449,7 @@ function generate_dot_tilde(left, right) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) - $left .= $(DynamicPPL.getvalue)(__context__, $vn) + $left .= $(DynamicPPL.getvalue_nested)(__context__, $vn) end $(DynamicPPL.dot_tilde_observe!)( diff --git a/src/contexts.jl b/src/contexts.jl index d0b1230b0..884529b36 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -288,43 +288,69 @@ childcontext(context::ConditionContext) = context.context setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) """ - getvalue(context, vn) + hasvalue(context, vn) -Return the value of the parameter corresponding to `vn` from `context`. -If `context` does not contain the value for `vn`, then `nothing` is returned, -e.g. [`DefaultContext`](@ref) will always return `nothing`. +Return `true` if `vn` is found in `context`. """ -getvalue(::IsLeaf, context, vn) = nothing -getvalue(::IsParent, context, vn) = getvalue(childcontext(context), vn) -getvalue(context::AbstractContext, vn) = getvalue(NodeTrait(getvalue, context), context, vn) -getvalue(context::PrefixContext, vn) = getvalue(childcontext(context), prefix(context, vn)) - -function getvalue(context::ConditionContext, vn) - return if hasvalue(context, vn) - _getvalue(context.values, vn) - else - getvalue(childcontext(context), vn) - end -end +hasvalue(context, vn) = false -# General implementations of `haskey`. -hasvalue(::IsLeaf, context, vn) = false -hasvalue(::IsParent, context, vn) = hasvalue(childcontext(context), vn) -hasvalue(context::AbstractContext, vn) = hasvalue(NodeTrait(context), context, vn) - -# Specific to `ConditionContext`. function hasvalue(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} - # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars end - function hasvalue( context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}} ) where {vars,sym} - # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars end +""" + getvalue(context, vn) + +Return value of `vn` in `context`. +""" +function getvalue(context::AbstractContext, vn) + return error("context $(context) does not contain value for $vn") +end +getvalue(context::ConditionContext, vn) = _getvalue(context.values, vn) + +""" + hasvalue_nested(context, vn) + +Return `true` if `vn` is found in `context` or any of its descendants. +""" +function hasvalue_nested(context::AbstractContext, vn) + return hasvalue_nested(NodeTrait(hasvalue_nested, context), context, vn) +end +hasvalue_nested(::IsLeaf, context, vn) = hasvalue(context, vn) +function hasvalue_nested(::IsParent, context, vn) + return hasvalue(context, vn) || hasvalue_nested(childcontext(context), vn) +end +function hasvalue_nested(context::PrefixContext, vn) + return hasvalue_nested(childcontext(context), prefix(context, vn)) +end + +""" + getvalue_nested(context, vn) + +Return the value of the parameter corresponding to `vn` from `context` or its descendants. +""" +function getvalue_nested(context::AbstractContext, vn) + return getvalue_nested(NodeTrait(getvalue_nested, context), context, vn) +end +function getvalue_nested(::IsLeaf, context, vn) + return error("context $(context) does not contain value for $vn") +end +function getvalue_nested(context::PrefixContext, vn) + return getvalue_nested(childcontext(context), prefix(context, vn)) +end +function getvalue_nested(::IsParent, context, vn) + return if hasvalue(context, vn) + getvalue(context, vn) + else + getvalue_nested(childcontext(context), vn) + end +end + """ context([context::AbstractContext,] values::NamedTuple) context([context::AbstractContext]; values...) From 5e15b25a1cb1c53cf64a53978f37ee7c94121b9e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 05:01:49 +0100 Subject: [PATCH 58/61] added a substantial amount of testing for ConditionContext --- src/model.jl | 1 - test/contexts.jl | 125 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 123 insertions(+), 3 deletions(-) diff --git a/src/model.jl b/src/model.jl index 3495b6a05..9f163910d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -102,7 +102,6 @@ Alias for [`conditioned`](@ref). observations(model::Model) = conditioned(model) conditioned(model::Model) = conditioned(model.context) - """ (model::Model)([rng, varinfo, sampler, context]) diff --git a/test/contexts.jl b/test/contexts.jl index 7793dfc74..5f9f511b9 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -8,8 +8,15 @@ using DynamicPPL: NodeTrait, IsLeaf, IsParent, - PointwiseLikelihoodContext - + PointwiseLikelihoodContext, + contextual_isassumption, + ConditionContext, + hasvalue, + getvalue, + hasvalue_nested, + getvalue_nested + +# Dummy context to test nested behaviors. struct ParentContext{C<:AbstractContext} <: AbstractContext context::C end @@ -19,6 +26,50 @@ DynamicPPL.childcontext(context::ParentContext) = context.context DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") +# TODO: Should we maybe put this in DPPL itself? +function Base.iterate(context::AbstractContext) + if NodeTrait(context) isa IsLeaf + return nothing + end + + return context, context +end +function Base.iterate(_::AbstractContext, context::AbstractContext) + return _iterate(NodeTrait(context), context) +end +_iterate(::IsLeaf, context) = nothing +function _iterate(::IsParent, context) + child = childcontext(context) + return child, child +end + +Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() +Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() + +""" + remove_prefix(vn::VarName) + +Return `vn` but now with the prefix removed. +""" +remove_prefix(vn::VarName) = VarName{Symbol(split(string(vn), ".")[end])}(vn.indexing) + +""" + varnames(vn::VarName, val) + +Return iterator over all varnames that are represented by `vn` on `val`, +e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. +""" +varnames(vn::VarName, val::Real) = [vn] +function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) + return (VarName(vn, (vn.indexing..., Tuple(I))) for I in CartesianIndices(val)) +end +function varnames(vn::VarName, val::AbstractArray) + return Iterators.flatten( + varnames(VarName(vn, (vn.indexing..., Tuple(I))), val[I]) for + I in CartesianIndices(val) + ) +end + @testset "contexts.jl" begin child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] @@ -28,6 +79,10 @@ Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c MiniBatchContext(DefaultContext(), 0.0), PrefixContext{:x}(DefaultContext()), PointwiseLikelihoodContext(), + ConditionContext((x=1.0,)), + ConditionContext((x=1.0,), ParentContext(ConditionContext((y=2.0,)))), + ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))), + ConditionContext((x=[1.0, missing],)), ] contexts = vcat(child_contexts, parent_contexts) @@ -104,6 +159,72 @@ Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c end end + @testset "contextual_isassumption" begin + @testset "$context" for context in contexts + # Any `context` should return `true` by default. + @test contextual_isassumption(context, VarName{gensym(:x)}()) + + if any(Base.Fix2(isa, ConditionContext), context) + # We have a `ConditionContext` among us. + # Let's first extract the conditioned variables. + conditioned_values = DynamicPPL.conditioned(context) + + for (sym, val) in pairs(conditioned_values) + vn = VarName{sym}() + + # We need to drop the prefix of `var` since in `contextual_isassumption` + # it will be threaded through the `PrefixContext` before it reaches + # `ConditionContext` with the conditioned variable. + vn_without_prefix = remove_prefix(vn) + + # Let's check elementwise. + for vn_child in varnames(vn_without_prefix, val) + if DynamicPPL._getindex(val, vn_child.indexing) === missing + @test contextual_isassumption(context, vn_child) + else + @test !contextual_isassumption(context, vn_child) + end + end + end + end + end + end + + @testset "getvalue_nested & hasvalue_nested" begin + @testset "$context" for context in contexts + fake_vn = VarName{gensym(:x)}() + @test !hasvalue_nested(context, fake_vn) + @test_throws ErrorException getvalue_nested(context, fake_vn) + + if any(Base.Fix2(isa, ConditionContext), context) + # `ConditionContext` specific. + + # Let's first extract the conditioned variables. + conditioned_values = DynamicPPL.conditioned(context) + + for (sym, val) in pairs(conditioned_values) + vn = VarName{sym}() + + # We need to drop the prefix of `var` since in `contextual_isassumption` + # it will be threaded through the `PrefixContext` before it reaches + # `ConditionContext` with the conditioned variable. + vn_without_prefix = remove_prefix(vn) + + for vn_child in varnames(vn_without_prefix, val) + # `vn_child` should be in `context`. + @test hasvalue_nested(context, vn_child) + if !hasvalue_nested(context, vn_child) + @info "" context vn_child + end + # Value should be the same as extracted above. + @test getvalue_nested(context, vn_child) === + DynamicPPL._getindex(val, vn_child.indexing) + end + end + end + end + end + @testset "PrefixContext" begin ctx = @inferred PrefixContext{:f}( PrefixContext{:e}( From 386e9852e4d5165082dd92d1801fb35df4e58470 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 05:02:34 +0100 Subject: [PATCH 59/61] removed some debugging code --- test/contexts.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/contexts.jl b/test/contexts.jl index 5f9f511b9..fa96a0ec4 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -213,9 +213,6 @@ end for vn_child in varnames(vn_without_prefix, val) # `vn_child` should be in `context`. @test hasvalue_nested(context, vn_child) - if !hasvalue_nested(context, vn_child) - @info "" context vn_child - end # Value should be the same as extracted above. @test getvalue_nested(context, vn_child) === DynamicPPL._getindex(val, vn_child.indexing) From 52c9f044fe9b94cfaee52865d6e6c3a2c29dda48 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 05:18:18 +0100 Subject: [PATCH 60/61] use maybe_view in isassumption check --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index e55caac42..7880b9f1c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -38,7 +38,7 @@ function isassumption(expr::Union{Symbol,Expr}) $(DynamicPPL.inmissings)($vn, __model__) true else - $expr === missing + $(maybe_view(expr)) === missing end else false From 65e7f7128bb9b0ce66888e714d85725448dde3ee Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 05:20:29 +0100 Subject: [PATCH 61/61] added more extensive docstring for getvalue_nested and hasvalue_nested --- src/contexts.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/contexts.jl b/src/contexts.jl index 884529b36..55b070c6d 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -317,6 +317,9 @@ getvalue(context::ConditionContext, vn) = _getvalue(context.values, vn) hasvalue_nested(context, vn) Return `true` if `vn` is found in `context` or any of its descendants. + +This is contrast to [`hasvalue`](@ref) which only checks for `vn` in `context`, +not recursively checking if `vn` is in any of its descendants. """ function hasvalue_nested(context::AbstractContext, vn) return hasvalue_nested(NodeTrait(hasvalue_nested, context), context, vn) @@ -333,6 +336,9 @@ end getvalue_nested(context, vn) Return the value of the parameter corresponding to `vn` from `context` or its descendants. + +This is contrast to [`getvalue`](@ref) which only returns the value `vn` in `context`, +not recursively looking into its descendants. """ function getvalue_nested(context::AbstractContext, vn) return getvalue_nested(NodeTrait(getvalue_nested, context), context, vn)