From cd1c46de8bef1242f22ea9303ff1faca0bb972c1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:08:01 +0100 Subject: [PATCH 01/41] added ConditionContext and ContextualModel --- Project.toml | 1 + src/DynamicPPL.jl | 6 ++++ src/context_implementations.jl | 51 ++++++++++++++++++++++++++++++++++ src/contexts.jl | 42 ++++++++++++++++++++++++++++ src/model.jl | 14 ++++++---- src/varinfo.jl | 6 ++-- 6 files changed, 111 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 921dc054d..7cc3db31d 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.12.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..bfcc49e6b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,6 +6,7 @@ using Distributions using Bijectors using AbstractMCMC: AbstractMCMC +using BangBang: BangBang using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules @@ -67,6 +68,7 @@ export AbstractVarInfo, vectorize, # Model Model, + ContextualModel, getmissings, getargnames, generated_quantities, @@ -81,6 +83,7 @@ export AbstractVarInfo, PriorContext, MiniBatchContext, PrefixContext, + ConditionContext, assume, dot_assume, observe, @@ -99,6 +102,8 @@ export AbstractVarInfo, logprior, logjoint, pointwise_loglikelihoods, + condition, + decondition, # Convenience macros @addlogprob!, @submodel @@ -129,5 +134,6 @@ include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") +include("contextual_model.jl") end # module diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3d492f5b1..68165a64f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -133,6 +133,28 @@ function tilde_assume!(context, right, vn, inds, vi) return value end +function tilde_assume!(context::ConditionContext, right, vn, inds, vi) + value = if haskey(context, vn) + # Extract value. + if inds isa Tuple{} + getfield(context.values, getsym(vn)) + else + _getindex(getfield(context.values, getsym(vn)), inds) + end + + # Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + tilde_observe!(context.context, right, value, vn, inds, vi) + else + tilde_assume!(context.context, right, vn, inds, vi) + end + + return value +end + # observe """ tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) @@ -217,6 +239,10 @@ function tilde_observe!(context, right, left, vi) return left end +function tilde_observe!(context::ConditionContext, right, left, vi) + return tilde_observe!(context.context, right, left, vi) +end + function assume(rng, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end @@ -419,6 +445,28 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) return value end +function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) + value = if vn in context + # Extract value. + if inds isa Tuple{} + getfield(context.values, sym) + else + _getindex(getfield(context.values, sym), inds) + end + + # Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + dot_tilde_observe!(context.context, right, left, vn, inds, vi) + else + dot_tilde_assume!(context.context, right, left, vn, inds, vi) + end + + return value +end + # `dot_assume` function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi @@ -637,6 +685,9 @@ function dot_tilde_observe!(context, right, left, vi) acclogp!(vi, logp) return left end +function dot_tilde_observe!(context::ConditionContext, right, left, vi) + return dot_tilde_observe!(context.context, right, left, vi) +end # Falls back to non-sampler definition. function dot_observe(::AbstractSampler, dist, value, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 05ad8df0d..946c22092 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -106,3 +106,45 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end + + +struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext + values::Values + context::Ctx +end + +function ConditionContext(values::NamedTuple{Vars}) where {Vars} + return ConditionContext(values, DefaultContext()) +end + +function ConditionContext(values::NamedTuple{Vars}, context) where {Vars} + return ConditionContext{Vars,typeof(values),typeof(context)}(values, context) +end + +# Try to avoid nested `ConditionContext`. +function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) where {Vars} + # Note that this potentially overrides values from `context`, thus giving + # precedence to the outmost `ConditionContext`. + return ConditionContext(merge(context.values, values), context.context) +end + +function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars, sym} + # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. + return sym in vars +end + +# TODO: Can we maybe do this in a better way? +decondition(context::ConditionContext) = context +function decondition(context::ConditionContext, sym, syms...) + return decondition( + ConditionContext(BangBang.delete!!(context.values, sym), context.context), + syms... + ) +end +decondition(context::ConditionContext) = context +function decondition(context::ConditionContext, ::Val{sym}, syms...) where {sym} + return decondition( + ConditionContext(BangBang.delete!!(context.values, Val{sym}()), context.context), + syms... + ) +end diff --git a/src/model.jl b/src/model.jl index 9ec047a44..3ce707e2e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,3 +1,5 @@ +abstract type AbstractModel <: AbstractProbabilisticProgram end + """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} name::Symbol @@ -33,7 +35,7 @@ Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x ``` """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: - AbstractProbabilisticProgram + AbstractModel name::Symbol f::F args::NamedTuple{argnames,Targs} @@ -82,7 +84,7 @@ Sample from the `model` using the `sampler` with random number generator `rng` a The method resets the log joint probability of `varinfo` and increases the evaluation number of `sampler`. """ -function (model::Model)( +function (model::AbstractModel)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), @@ -91,7 +93,7 @@ function (model::Model)( return model(varinfo, SamplingContext(rng, sampler, context)) end -(model::Model)(context::AbstractContext) = model(VarInfo(), context) +(model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context) function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 return evaluate_threadunsafe(model, varinfo, context) @@ -100,17 +102,17 @@ function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) end end -function (model::Model)(args...) +function (model::AbstractModel)(args...) return model(Random.GLOBAL_RNG, args...) end # without VarInfo -function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) +function (model::AbstractModel)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) return model(rng, VarInfo(), sampler, args...) end # without VarInfo and without AbstractSampler -function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) +function (model::AbstractModel)(rng::Random.AbstractRNG, context::AbstractContext) return model(rng, VarInfo(), SampleFromPrior(), context) end diff --git a/src/varinfo.jl b/src/varinfo.jl index fe3262dd5..be6551939 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -124,7 +124,7 @@ end function VarInfo( rng::Random.AbstractRNG, - model::Model, + model::AbstractModel, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) @@ -132,10 +132,10 @@ function VarInfo( model(rng, varinfo, sampler, context) return TypedVarInfo(varinfo) end -VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) +VarInfo(model::AbstractModel, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) # without AbstractSampler -function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) +function VarInfo(rng::Random.AbstractRNG, model::AbstractModel, context::AbstractContext) return VarInfo(rng, model, SampleFromPrior(), context) end From b1106ee65b60a8eebffada99b5c770a869730169 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:10:12 +0100 Subject: [PATCH 02/41] removed redundant definition --- src/contexts.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 946c22092..5abae4168 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -141,7 +141,6 @@ function decondition(context::ConditionContext, sym, syms...) syms... ) end -decondition(context::ConditionContext) = context function decondition(context::ConditionContext, ::Val{sym}, syms...) where {sym} return decondition( ConditionContext(BangBang.delete!!(context.values, Val{sym}()), context.context), From 0f00771e8f391311cd6956ac33e3e31c770a2409 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:36:28 +0100 Subject: [PATCH 03/41] return condition model by default --- src/compiler.jl | 16 ++++++++++------ src/contexts.jl | 14 +++++++------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7466bc2c0..00be43551 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -23,7 +23,8 @@ function isassumption(expr::Union{Symbol,Expr}) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) + $(DynamicPPL.inmissings)($vn, __model__) || + (__context__ isa $(DynamicPPL.ConditionContext) && !$(Base.haskey)(__context__, $vn)) true else # Evaluate the LHS @@ -427,11 +428,14 @@ function build_output(modelinfo, linenumbernode) modeldef[:body] = MacroTools.@q begin $(linenumbernode) $evaluator = $(MacroTools.combinedef(evaluatordef)) - return $(DynamicPPL.Model)( - $(QuoteNode(modeldef[:name])), - $evaluator, - $allargs_namedtuple, - $defaults_namedtuple, + return $(DynamicPPL.condition)( + $(DynamicPPL.Model)( + $(QuoteNode(modeldef[:name])), + $evaluator, + $allargs_namedtuple, + $defaults_namedtuple, + ), + $allargs_namedtuple ) end diff --git a/src/contexts.jl b/src/contexts.jl index 5abae4168..0a0e977b9 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -134,16 +134,16 @@ function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {v end # TODO: Can we maybe do this in a better way? -decondition(context::ConditionContext) = context +# When no second argument is given, we remove _all_ conditioned variables. +# TODO: Should we remove this and just return `context.context`? +# That will work better if `Model` becomes like `ContextualModel`. +decondition(context::ConditionContext) = ConditionContext(NamedTuple(), context.context) +function decondition(context::ConditionContext, sym) + return ConditionContext(BangBang.delete!!(context.values, sym), context.context) +end function decondition(context::ConditionContext, sym, syms...) return decondition( ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... ) end -function decondition(context::ConditionContext, ::Val{sym}, syms...) where {sym} - return decondition( - ConditionContext(BangBang.delete!!(context.values, Val{sym}()), context.context), - syms... - ) -end From c754291411f3af1bef2468838493a750eaefa427 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:47:50 +0100 Subject: [PATCH 04/41] formatting --- src/compiler.jl | 9 ++++++--- src/contexts.jl | 6 ++---- src/model.jl | 3 +-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 00be43551..191152f4e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -23,8 +23,11 @@ function isassumption(expr::Union{Symbol,Expr}) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) || - (__context__ isa $(DynamicPPL.ConditionContext) && !$(Base.haskey)(__context__, $vn)) + $(DynamicPPL.inmissings)($vn, __model__) || + ( + __context__ isa $(DynamicPPL.ConditionContext) && + !$(Base.haskey)(__context__, $vn) + ) true else # Evaluate the LHS @@ -435,7 +438,7 @@ function build_output(modelinfo, linenumbernode) $allargs_namedtuple, $defaults_namedtuple, ), - $allargs_namedtuple + $allargs_namedtuple, ) end diff --git a/src/contexts.jl b/src/contexts.jl index 0a0e977b9..9831dde4a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -107,7 +107,6 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} end end - struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx @@ -128,7 +127,7 @@ function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) w return ConditionContext(merge(context.values, values), context.context) end -function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars, sym} +function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars end @@ -143,7 +142,6 @@ function decondition(context::ConditionContext, sym) end function decondition(context::ConditionContext, sym, syms...) return decondition( - ConditionContext(BangBang.delete!!(context.values, sym), context.context), - syms... + ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... ) end diff --git a/src/model.jl b/src/model.jl index 3ce707e2e..39a9f4fd3 100644 --- a/src/model.jl +++ b/src/model.jl @@ -34,8 +34,7 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: - AbstractModel +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel name::Symbol f::F args::NamedTuple{argnames,Targs} From 2d3f94cb6c137f6ce22683392cf4fcf0c579a47d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 22:15:23 +0100 Subject: [PATCH 05/41] forgot to include contextual model in previous commit --- src/contextual_model.jl | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 src/contextual_model.jl diff --git a/src/contextual_model.jl b/src/contextual_model.jl new file mode 100644 index 000000000..382c45034 --- /dev/null +++ b/src/contextual_model.jl @@ -0,0 +1,29 @@ +struct ContextualModel{Ctx<:AbstractContext,M<:Model} <: AbstractModel + context::Ctx + model::M +end + +function contextualize(model::AbstractModel, context::AbstractContext) + return ContextualModel(context, model) +end + +# TODO: What do we do for other contexts? Could handle this in general if we had a +# notion of wrapper-, primitive-context, etc. +function (cmodel::ContextualModel{<:ConditionContext})( + varinfo::AbstractVarInfo, context::AbstractContext +) + # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as + # `ConditionContext` child. + return cmodel.model(varinfo, ConditionContext(cmodel.context.values, context)) +end + +condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) +condition(model::AbstractModel; values...) = condition(model, (; values...)) +function condition(cmodel::ContextualModel{<:ConditionContext}, values) + return contextualize(cmodel.model, ConditionContext(values, cmodel.context)) +end + +decondition(model::AbstractModel, args...) = model +function decondition(cmodel::ContextualModel{<:ConditionContext}, syms...) + return contextualize(cmodel.model, decondition(cmodel.context, syms...)) +end From 3e5a79f7e1b14ba610f88377ef9b9f177777becb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 22:45:45 +0100 Subject: [PATCH 06/41] fixed typos --- src/context_implementations.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 68165a64f..052b4ab9c 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -134,9 +134,9 @@ function tilde_assume!(context, right, vn, inds, vi) end function tilde_assume!(context::ConditionContext, right, vn, inds, vi) - value = if haskey(context, vn) + if haskey(context, vn) # Extract value. - if inds isa Tuple{} + value = if inds isa Tuple{} getfield(context.values, getsym(vn)) else _getindex(getfield(context.values, getsym(vn)), inds) @@ -149,7 +149,7 @@ function tilde_assume!(context::ConditionContext, right, vn, inds, vi) tilde_observe!(context.context, right, value, vn, inds, vi) else - tilde_assume!(context.context, right, vn, inds, vi) + value = tilde_assume!(context.context, right, vn, inds, vi) end return value @@ -446,9 +446,9 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) end function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) - value = if vn in context + if vn in context # Extract value. - if inds isa Tuple{} + value = if inds isa Tuple{} getfield(context.values, sym) else _getindex(getfield(context.values, sym), inds) @@ -459,9 +459,9 @@ function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) vi[vn] = value end - dot_tilde_observe!(context.context, right, left, vn, inds, vi) + dot_tilde_observe!(context.context, right, value, vn, inds, vi) else - dot_tilde_assume!(context.context, right, left, vn, inds, vi) + value = dot_tilde_assume!(context.context, right, left, vn, inds, vi) end return value From d4e42387888a4d7c9b2f9e62541a8de70fbb6f0f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 23:06:59 +0100 Subject: [PATCH 07/41] added some niceties to ConditionContext --- src/contexts.jl | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 9831dde4a..e06d4c6eb 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -110,14 +110,38 @@ end struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx + + function ConditionContext{Values}( + values::Values, context::AbstractContext + ) where {names,Values<:NamedTuple{names}} + return new{names,typeof(values),typeof(context)}(values, context) + end +end + +@generated function drop_missings(nt::NamedTuple{names,values}) where {names,values} + names_expr = Expr(:tuple) + values_expr = Expr(:tuple) + + for (n, v) in zip(names, values.parameters) + if !(v <: Missing) + push!(names_expr.args, QuoteNode(n)) + push!(values_expr.args, :(nt.$n)) + end + end + + return :(NamedTuple{$names_expr}($values_expr)) end -function ConditionContext(values::NamedTuple{Vars}) where {Vars} +function ConditionContext(context::ConditionContext, child_context::AbstractContext) + return ConditionContext(context.values, child_context) +end +function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end -function ConditionContext(values::NamedTuple{Vars}, context) where {Vars} - return ConditionContext{Vars,typeof(values),typeof(context)}(values, context) +function ConditionContext(values::NamedTuple, context::AbstractContext) + values_wo_missing = drop_missings(values) + return ConditionContext{typeof(values_wo_missing)}(values_wo_missing, context) end # Try to avoid nested `ConditionContext`. From 85a47eb098498df027ebaf2efe124517507933cb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 16 Jul 2021 01:49:52 +0100 Subject: [PATCH 08/41] added support for vectors of VarName --- src/context_implementations.jl | 2 +- src/contexts.jl | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 052b4ab9c..813d40875 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -446,7 +446,7 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) end function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) - if vn in context + if haskey(context, vn) # Extract value. value = if inds isa Tuple{} getfield(context.values, sym) diff --git a/src/contexts.jl b/src/contexts.jl index e06d4c6eb..500d3d60b 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -156,6 +156,11 @@ function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {v return sym in vars end +function Base.haskey(context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}}) where {vars,sym} + # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. + return sym in vars +end + # TODO: Can we maybe do this in a better way? # When no second argument is given, we remove _all_ conditioned variables. # TODO: Should we remove this and just return `context.context`? From c7dae8daa1a66712328161be7a8b4fe94575deae Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 16 Jul 2021 01:52:17 +0100 Subject: [PATCH 09/41] Update src/contexts.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/contexts.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 500d3d60b..d0cd81287 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -156,7 +156,9 @@ function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {v return sym in vars end -function Base.haskey(context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}}) where {vars,sym} +function Base.haskey( + context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}} +) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars end From 75680b3d5715eb29239ca39ff7401d3ce966f194 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 05:12:00 +0100 Subject: [PATCH 10/41] upper-bound Distributions.jl in tests --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 146e92c55..37c509610 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -22,7 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" AbstractMCMC = "2.1, 3.0" AbstractPPL = "0.1.3" Bijectors = "0.9.5" -Distributions = "0.24, 0.25" +Distributions = "< 0.25.11" DistributionsAD = "0.6.3" Documenter = "0.26.1, 0.27" ForwardDiff = "0.10.12" From f0ae744e625adbe0776021a0991d0aeb14fd8790 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 05:33:42 +0100 Subject: [PATCH 11/41] make the isassumption check using context extensible and nicer --- src/compiler.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index d0485dba9..9eca8abae 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -24,10 +24,7 @@ function isassumption(expr::Union{Symbol,Expr}) # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || $(DynamicPPL.inmissings)($vn, __model__) || - ( - __context__ isa $(DynamicPPL.ConditionContext) && - !$(Base.haskey)(__context__, $vn) - ) + $(DynamicPPL.contextual_isassumption)(__context__, $vn) true else # Evaluate the LHS @@ -37,6 +34,9 @@ function isassumption(expr::Union{Symbol,Expr}) end end +contextual_isassumption(context::AbstractContext, vn) = false +contextual_isassumption(context::ConditionContext, vn::VarName) = !(haskey(context, vn)) + # failsafe: a literal is never an assumption isassumption(expr) = :(false) From b990bb05a3baec3f67c714fa11d529d974a97839 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 06:14:33 +0100 Subject: [PATCH 12/41] renamed type-parameter for ConditionContext --- src/contexts.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index d0cd81287..8bfc25a9b 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -107,7 +107,7 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} end end -struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext +struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx @@ -145,7 +145,7 @@ function ConditionContext(values::NamedTuple, context::AbstractContext) end # Try to avoid nested `ConditionContext`. -function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) where {Vars} +function ConditionContext(values::NamedTuple{Names}, context::ConditionContext) where {Names} # Note that this potentially overrides values from `context`, thus giving # precedence to the outmost `ConditionContext`. return ConditionContext(merge(context.values, values), context.context) From 8994cd78834cfa9408927dca7be95532a5a554c7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 06:27:08 +0100 Subject: [PATCH 13/41] Update src/contexts.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/contexts.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 8bfc25a9b..96bd43cd7 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -145,7 +145,9 @@ function ConditionContext(values::NamedTuple, context::AbstractContext) end # Try to avoid nested `ConditionContext`. -function ConditionContext(values::NamedTuple{Names}, context::ConditionContext) where {Names} +function ConditionContext( + values::NamedTuple{Names}, context::ConditionContext +) where {Names} # Note that this potentially overrides values from `context`, thus giving # precedence to the outmost `ConditionContext`. return ConditionContext(merge(context.values, values), context.context) From 94da453a489102f2a326e777993e2c45001d966a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 16:17:50 +0100 Subject: [PATCH 14/41] introduced convenient _getvalue method --- src/context_implementations.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 813d40875..9f9ca3004 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,8 +14,17 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) +_getindex(x, inds::Tuple) = _getindex(Base.maybeview(x, first(inds)...), Base.tail(inds)) _getindex(x, inds::Tuple{}) = x +_getvalue(x, vn::VarName{sym}) where {sym} = _getindex(getproperty(x, sym), vn.indexing) +function _getvalue(x, vns::AbstractVector{<:VarName{sym}}) where {sym} + val = getproperty(x, sym) + + # This should work with both cartesian and linear indexing. + return map(vns) do vn + _getindex(val, vn) + end +end # assume """ From 4e74cf8a6ab99de553b23ca898ce451a8d190aac Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 16:18:04 +0100 Subject: [PATCH 15/41] overload tilde_assume rather than tilde_assume! and others for ConditionContext --- src/context_implementations.jl | 117 +++++++++++++++++++-------------- 1 file changed, 69 insertions(+), 48 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9f9ca3004..54f6d5965 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -127,6 +127,36 @@ function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) end +# `ConditionContext` +function tilde_assume(context::ConditionContext, right, vn, inds, vi) + if !haskey(context, vn) + # Not conditioned on => defer to child context. + return tilde_assume(context.context, right, vn, inds, vi) + end + + # Extract value. + value = _getvalue(context.values, vn) + + # Update the value in `vi`. + # TODO: Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + logp = tilde_observe(context.context, right, value, vn, inds, vi) + return value, logp +end + +function tilde_assume(rng, context::ConditionContext, sampler, right, vn, inds, vi) + if haskey(context, vn) + # Defer to child context. + return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) + end + + # If we're conditioning, then we just fall back to non-rng impl. + return tilde_assume(context, right, vn, inds, vi) +end + """ tilde_assume!(context, right, vn, inds, vi) @@ -142,28 +172,6 @@ function tilde_assume!(context, right, vn, inds, vi) return value end -function tilde_assume!(context::ConditionContext, right, vn, inds, vi) - if haskey(context, vn) - # Extract value. - value = if inds isa Tuple{} - getfield(context.values, getsym(vn)) - else - _getindex(getfield(context.values, getsym(vn)), inds) - end - - # Should we even do this? - if haskey(vi, vn) - vi[vn] = value - end - - tilde_observe!(context.context, right, value, vn, inds, vi) - else - value = tilde_assume!(context.context, right, vn, inds, vi) - end - - return value -end - # observe """ tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) @@ -220,6 +228,14 @@ function tilde_observe(context::PrefixContext, right, left, vi) return tilde_observe(context.context, right, left, vi) end +# `ConditionContext` +function tilde_observe(context::ConditionContext, right, left, vname, vi) + return tilde_observe(context.context, right, left, vname, vi) +end +function tilde_observe(context::ConditionContext, right, left, vi) + return tilde_observe(context.context, right, left, vi) +end + """ tilde_observe!(context, right, left, vname, vinds, vi) @@ -248,10 +264,6 @@ function tilde_observe!(context, right, left, vi) return left end -function tilde_observe!(context::ConditionContext, right, left, vi) - return tilde_observe!(context.context, right, left, vi) -end - function assume(rng, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end @@ -440,6 +452,37 @@ function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, ) end +# `ConditionContext` +function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi) + if !haskey(context, vn) + # Not conditioned on => defer to child context. + return dot_tilde_assume(context.context, right, left, vn, inds, vi) + end + + # Extract value. + # FIXME: Handle the case where `vn` is actually `AbstractArray{<:VarName}`. + value = _getvalue(context.values, vn) + + # Update the value in `vi`. + # TODO: Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + logp = dot_tilde_observe(context.context, right, left, value, vn, inds, vi) + return value, logp +end + +function dot_tilde_assume(rng, context::ConditionContext, sampler, right, left, vn, inds, vi) + if !haskey(context, vn) + # Defer to child context. + return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) + end + + # If we're conditioning, then we just fall back to non-rng impl. + return dot_tilde_assume(context, right, left, vn, inds, vi) +end + """ dot_tilde_assume!(context, right, left, vn, inds, vi) @@ -454,28 +497,6 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) return value end -function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) - if haskey(context, vn) - # Extract value. - value = if inds isa Tuple{} - getfield(context.values, sym) - else - _getindex(getfield(context.values, sym), inds) - end - - # Should we even do this? - if haskey(vi, vn) - vi[vn] = value - end - - dot_tilde_observe!(context.context, right, value, vn, inds, vi) - else - value = dot_tilde_assume!(context.context, right, left, vn, inds, vi) - end - - return value -end - # `dot_assume` function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi From f9cdfa9910226a2ad790654f4ddeb390e531e95e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 16:21:42 +0100 Subject: [PATCH 16/41] added contextual_isassumption for PrefixContext --- src/compiler.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 9eca8abae..6bcd68c7a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -36,6 +36,9 @@ end contextual_isassumption(context::AbstractContext, vn) = false contextual_isassumption(context::ConditionContext, vn::VarName) = !(haskey(context, vn)) +function contextual_isassumption(context::PrefixContext, vn::VarName) + return contextual_isassumption(context.context, prefix(context, vn)) +end # failsafe: a literal is never an assumption isassumption(expr) = :(false) From 835a41ead81a2387f1113fb148d42a7851f2cd1d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 16:24:28 +0100 Subject: [PATCH 17/41] implemented contextual_isassumption for all contexts --- src/compiler.jl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 6bcd68c7a..3ddbb6340 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -35,10 +35,19 @@ function isassumption(expr::Union{Symbol,Expr}) end contextual_isassumption(context::AbstractContext, vn) = false -contextual_isassumption(context::ConditionContext, vn::VarName) = !(haskey(context, vn)) -function contextual_isassumption(context::PrefixContext, vn::VarName) +function contextual_isassumption(context::ConditionContext, vn) + # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`. + return !(haskey(context, vn)) || contextual_isassumption(context, vn) +end +function contextual_isassumption(context::PrefixContext, vn) return contextual_isassumption(context.context, prefix(context, vn)) end +function contextual_isassumption(context::MiniBatchContext, vn) + return contextual_isassumption(context.context, vn) +end +function contextual_isassumption(context::PointwiseLikelihoodContext, vn) + return contextual_isassumption(context.context, vn) +end # failsafe: a literal is never an assumption isassumption(expr) = :(false) From 5d110d5cb4bb885d44ed8b8d33574f41aae22a4b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 22:32:54 +0100 Subject: [PATCH 18/41] improved the way ConditionContext works signficantly --- src/compiler.jl | 72 ++++++++++++++++++++++++---------- src/context_implementations.jl | 51 ++---------------------- src/contexts.jl | 6 +++ src/loglikelihoods.jl | 8 ++++ 4 files changed, 70 insertions(+), 67 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 3ddbb6340..d640ef1d7 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -20,24 +20,52 @@ function isassumption(expr::Union{Symbol,Expr}) return quote let $vn = $(varname(expr)) - # This branch should compile nicely in all cases except for partial missing data - # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` - if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) || - $(DynamicPPL.contextual_isassumption)(__context__, $vn) - true + if $(DynamicPPL.contextual_isassumption)(__context__, $vn) + # Considered an assumption by `__context__` which means either: + # 1. We hit the default implementation, e.g. using `DefaultContext`, + # which in turn means that we haven't considered if it's one of + # the model arguments, hence we need to check this. + # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, + # i.e. we're trying to condition one of the latent variables. + # In this case, the below will return `true` since the first branch + # will be hit. + # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, + # i.e. we're trying to override the value. This is currently NOT supported. + # TODO: Support by adding context to model, and use `model.args` + # as the default conditioning. Then we no longer need to check `inargnames` + # since it will all be handled by `contextual_isassumption`. + if !($(DynamicPPL.inargnames)($vn, __model__)) + true + else + $expr === missing + end else - # Evaluate the LHS - $(maybe_view(expr)) === missing + false end end end end -contextual_isassumption(context::AbstractContext, vn) = false +""" + contextual_isassumption(context, vn) + +Return `true` if `vn` is considered an assumption by `context`. + +The default implementation for `AbstractContext` always returns `true`. +""" +contextual_isassumption(context::AbstractContext, vn) = true function contextual_isassumption(context::ConditionContext, vn) # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`. - return !(haskey(context, vn)) || contextual_isassumption(context, vn) + + # We have either of the following cases: + # 1. `context` considers `vn` as an observation, i.e. it has `vn` as a key, + # which means we have a value to replace with and we don't need to recurse. + # 2. One of the decendant contexts consider it as an observation, i.e. + # `contextual_isassumption` evaluates to `false`. + # The below then evaluates to `!(false || true) === false`. + # 3. Neither `context` nor any of it's decendants considers it an observation, + # in which case the below evaluates to `!(false || false) === true`. + return !(haskey(context, vn) || !contextual_isassumption(context.context, vn)) end function contextual_isassumption(context::PrefixContext, vn) return contextual_isassumption(context.context, prefix(context, vn)) @@ -45,9 +73,6 @@ end function contextual_isassumption(context::MiniBatchContext, vn) return contextual_isassumption(context.context, vn) end -function contextual_isassumption(context::PointwiseLikelihoodContext, vn) - return contextual_isassumption(context.context, vn) -end # failsafe: a literal is never an assumption isassumption(expr) = :(false) @@ -348,6 +373,11 @@ function generate_tilde(left, right) __varinfo__, ) else + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. + if !$(DynamicPPL.inargnames)($vn, __model__) + $left = $(DynamicPPL.getvalue)(__context__, $vn) + end + $(DynamicPPL.tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), @@ -392,6 +422,11 @@ function generate_dot_tilde(left, right) __varinfo__, ) else + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. + if !$(DynamicPPL.inargnames)($vn, __model__) + $left .= $(DynamicPPL.getvalue)(__context__, $vn) + end + $(DynamicPPL.dot_tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), @@ -453,14 +488,11 @@ function build_output(modelinfo, linenumbernode) modeldef[:body] = MacroTools.@q begin $(linenumbernode) $evaluator = $(MacroTools.combinedef(evaluatordef)) - return $(DynamicPPL.condition)( - $(DynamicPPL.Model)( - $(QuoteNode(modeldef[:name])), - $evaluator, - $allargs_namedtuple, - $defaults_namedtuple, - ), + return $(DynamicPPL.Model)( + $(QuoteNode(modeldef[:name])), + $evaluator, $allargs_namedtuple, + $defaults_namedtuple, ) end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 54f6d5965..e632a4c95 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -129,32 +129,11 @@ end # `ConditionContext` function tilde_assume(context::ConditionContext, right, vn, inds, vi) - if !haskey(context, vn) - # Not conditioned on => defer to child context. - return tilde_assume(context.context, right, vn, inds, vi) - end - - # Extract value. - value = _getvalue(context.values, vn) - - # Update the value in `vi`. - # TODO: Should we even do this? - if haskey(vi, vn) - vi[vn] = value - end - - logp = tilde_observe(context.context, right, value, vn, inds, vi) - return value, logp + return tilde_assume(context.context, right, vn, inds, vi) end function tilde_assume(rng, context::ConditionContext, sampler, right, vn, inds, vi) - if haskey(context, vn) - # Defer to child context. - return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) - end - - # If we're conditioning, then we just fall back to non-rng impl. - return tilde_assume(context, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) end """ @@ -454,33 +433,11 @@ end # `ConditionContext` function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi) - if !haskey(context, vn) - # Not conditioned on => defer to child context. - return dot_tilde_assume(context.context, right, left, vn, inds, vi) - end - - # Extract value. - # FIXME: Handle the case where `vn` is actually `AbstractArray{<:VarName}`. - value = _getvalue(context.values, vn) - - # Update the value in `vi`. - # TODO: Should we even do this? - if haskey(vi, vn) - vi[vn] = value - end - - logp = dot_tilde_observe(context.context, right, left, value, vn, inds, vi) - return value, logp + return dot_tilde_assume(context.context, right, left, vn, inds, vi) end function dot_tilde_assume(rng, context::ConditionContext, sampler, right, left, vn, inds, vi) - if !haskey(context, vn) - # Defer to child context. - return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) - end - - # If we're conditioning, then we just fall back to non-rng impl. - return dot_tilde_assume(context, right, left, vn, inds, vi) + return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) end """ diff --git a/src/contexts.jl b/src/contexts.jl index 96bd43cd7..f58242b5c 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -153,6 +153,12 @@ function ConditionContext( return ConditionContext(merge(context.values, values), context.context) end +function getvalue(context::ConditionContext, vn) + haskey(context, vn) ? _getvalue(context.values, vn) : getvalue(context.context, vn) +end +getvalue(context::AbstractContext, vn) = getvalue(context.context, vn) +getvalue(context::PrefixContext, vn) = getvalue(context.context, prefix(context, vn)) + function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 6c66e4ec4..2203eda13 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -13,6 +13,14 @@ function PointwiseLikelihoodContext( ) end +function contextual_isassumption(context::PointwiseLikelihoodContext, vn) + return contextual_isassumption(context.context, vn) +end + +function contextual_isobservation(context::PointwiseLikelihoodContext, vn) + return contextual_isobservation(context.context, vn) +end + function Base.push!( context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, vn::VarName, From 560ca83ecaf57eaa2b2cdb3d68b8b2f39e5854bb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 22:35:34 +0100 Subject: [PATCH 19/41] forgot impl of dot_tilde_observe for ConditionContext --- src/context_implementations.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e632a4c95..e464b69de 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -646,6 +646,11 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) return dot_tilde_observe(context.context, right, left, vi) end +# `ConditionContext` +function dot_tilde_observe(context::ConditionContext, right, left, vi) + return dot_tilde_observe(context.context, right, left, vi) +end + """ dot_tilde_observe!(context, right, left, vname, vinds, vi) @@ -672,9 +677,6 @@ function dot_tilde_observe!(context, right, left, vi) acclogp!(vi, logp) return left end -function dot_tilde_observe!(context::ConditionContext, right, left, vi) - return dot_tilde_observe!(context.context, right, left, vi) -end # Falls back to non-sampler definition. function dot_observe(::AbstractSampler, dist, value, vi) From 5635c3b530a2b3085b4787c6e491950e2ade722e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 22:39:24 +0100 Subject: [PATCH 20/41] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 4 +++- src/contexts.jl | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e464b69de..a0b3c8e77 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -436,7 +436,9 @@ function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi) return dot_tilde_assume(context.context, right, left, vn, inds, vi) end -function dot_tilde_assume(rng, context::ConditionContext, sampler, right, left, vn, inds, vi) +function dot_tilde_assume( + rng, context::ConditionContext, sampler, right, left, vn, inds, vi +) return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) end diff --git a/src/contexts.jl b/src/contexts.jl index f58242b5c..f48a5187c 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -154,7 +154,11 @@ function ConditionContext( end function getvalue(context::ConditionContext, vn) - haskey(context, vn) ? _getvalue(context.values, vn) : getvalue(context.context, vn) + return if haskey(context, vn) + _getvalue(context.values, vn) + else + getvalue(context.context, vn) + end end getvalue(context::AbstractContext, vn) = getvalue(context.context, vn) getvalue(context::PrefixContext, vn) = getvalue(context.context, prefix(context, vn)) From e78dc6546c08d960d5ae3e6a95ca3d0e5f4f056f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 22 Jul 2021 02:38:40 +0100 Subject: [PATCH 21/41] overload _evaluate rather than the model-call directly --- src/contextual_model.jl | 6 ++---- src/model.jl | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/contextual_model.jl b/src/contextual_model.jl index 382c45034..4b843ed0e 100644 --- a/src/contextual_model.jl +++ b/src/contextual_model.jl @@ -9,12 +9,10 @@ end # TODO: What do we do for other contexts? Could handle this in general if we had a # notion of wrapper-, primitive-context, etc. -function (cmodel::ContextualModel{<:ConditionContext})( - varinfo::AbstractVarInfo, context::AbstractContext -) +function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as # `ConditionContext` child. - return cmodel.model(varinfo, ConditionContext(cmodel.context.values, context)) + return _evaluate(cmodel.model, varinfo, ConditionContext(cmodel.context.values, context)) end condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) diff --git a/src/model.jl b/src/model.jl index e5b9beb24..7b5d8918a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -93,7 +93,7 @@ function (model::AbstractModel)( end (model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context) -function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) +function (model::AbstractModel)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 return evaluate_threadunsafe(model, varinfo, context) else From 1d3b11e9caa2d2e2184940b42eadf6a176b7ccb1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 06:51:20 +0100 Subject: [PATCH 22/41] drop now unneceesary impls for tilds for ConditionContext --- src/context_implementations.jl | 33 --------------------------------- src/contexts.jl | 3 +++ 2 files changed, 3 insertions(+), 33 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 8e2b02851..ce3f89739 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -127,15 +127,6 @@ function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) end -# `ConditionContext` -function tilde_assume(context::ConditionContext, right, vn, inds, vi) - return tilde_assume(context.context, right, vn, inds, vi) -end - -function tilde_assume(rng, context::ConditionContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) -end - """ tilde_assume!(context, right, vn, inds, vi) @@ -204,14 +195,6 @@ function tilde_observe(context::PrefixContext, right, left, vname, vi) return tilde_observe(context.context, right, left, prefix(context, vname), vi) end -# `ConditionContext` -function tilde_observe(context::ConditionContext, right, left, vname, vi) - return tilde_observe(context.context, right, left, vname, vi) -end -function tilde_observe(context::ConditionContext, right, left, vi) - return tilde_observe(context.context, right, left, vi) -end - """ tilde_observe!(context, right, left, vname, vinds, vi) @@ -428,17 +411,6 @@ function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, ) end -# `ConditionContext` -function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, left, vn, inds, vi) -end - -function dot_tilde_assume( - rng, context::ConditionContext, sampler, right, left, vn, inds, vi -) - return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) -end - """ dot_tilde_assume!(context, right, left, vn, inds, vi) @@ -640,11 +612,6 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) return dot_tilde_observe(context.context, right, left, vi) end -# `ConditionContext` -function dot_tilde_observe(context::ConditionContext, right, left, vi) - return dot_tilde_observe(context.context, right, left, vi) -end - """ dot_tilde_observe!(context, right, left, vname, vinds, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 98790a785..947f31538 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -212,3 +212,6 @@ function decondition(context::ConditionContext, sym, syms...) ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... ) end + +NodeTrait(context::ConditionContext) = IsParent() +childcontext(context::ConditionContext) = context.context From e1a7d38d33b91cfa0acb1b36e31d6fda0e50579f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 06:52:02 +0100 Subject: [PATCH 23/41] formatting --- src/contextual_model.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contextual_model.jl b/src/contextual_model.jl index 4b843ed0e..793491720 100644 --- a/src/contextual_model.jl +++ b/src/contextual_model.jl @@ -12,7 +12,9 @@ end function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as # `ConditionContext` child. - return _evaluate(cmodel.model, varinfo, ConditionContext(cmodel.context.values, context)) + return _evaluate( + cmodel.model, varinfo, ConditionContext(cmodel.context.values, context) + ) end condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) From b42c34f346c69c4b5ab615ef0306214e74e1a29b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:03:35 +0100 Subject: [PATCH 24/41] address issues using traits --- src/compiler.jl | 15 +++++++++------ src/contexts.jl | 34 ++++++++++++++++++++++++++-------- 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index d640ef1d7..eeced2701 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -53,7 +53,13 @@ Return `true` if `vn` is considered an assumption by `context`. The default implementation for `AbstractContext` always returns `true`. """ -contextual_isassumption(context::AbstractContext, vn) = true +contextual_isassumption(::IsLeaf, context, vn) = true +function contextual_isassumption(::IsParent, context, vn) + return contextual_isassumption(childcontext(context), vn) +end +function contextual_isassumption(context::AbstractContext, vn) + return contextual_isassumption(NodeTrait(context), context, vn) +end function contextual_isassumption(context::ConditionContext, vn) # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`. @@ -65,13 +71,10 @@ function contextual_isassumption(context::ConditionContext, vn) # The below then evaluates to `!(false || true) === false`. # 3. Neither `context` nor any of it's decendants considers it an observation, # in which case the below evaluates to `!(false || false) === true`. - return !(haskey(context, vn) || !contextual_isassumption(context.context, vn)) + return !(haskey(context, vn) || !contextual_isassumption(childcontext(context), vn)) end function contextual_isassumption(context::PrefixContext, vn) - return contextual_isassumption(context.context, prefix(context, vn)) -end -function contextual_isassumption(context::MiniBatchContext, vn) - return contextual_isassumption(context.context, vn) + return contextual_isassumption(childcontext(context), prefix(context, vn)) end # failsafe: a literal is never an assumption diff --git a/src/contexts.jl b/src/contexts.jl index 947f31538..38cbe213e 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -174,19 +174,34 @@ function ConditionContext( ) where {Names} # Note that this potentially overrides values from `context`, thus giving # precedence to the outmost `ConditionContext`. - return ConditionContext(merge(context.values, values), context.context) + return ConditionContext(merge(context.values, values), childcontext(context)) end +""" + getvalue(context, vn) + +Return the value of the parameter corresponding to `vn` from `context`. +If `context` does not contain the value for `vn`, then `nothing` is returned, +e.g. [`DefaultContext`](@ref) will always return `nothing`. +""" +getvalue(::IsLeaf, context, vn) = nothing +getvalue(::IsParent, context, vn) = getvalue(childcontext(context), vn) +getvalue(context::AbstractContext, vn) = getvalue(NodeTrait(getvalue, context), context, vn) +getvalue(context::PrefixContext, vn) = getvalue(childcontext(context), prefix(context, vn)) function getvalue(context::ConditionContext, vn) return if haskey(context, vn) _getvalue(context.values, vn) else - getvalue(context.context, vn) + getvalue(childcontext(context), vn) end end -getvalue(context::AbstractContext, vn) = getvalue(context.context, vn) -getvalue(context::PrefixContext, vn) = getvalue(context.context, prefix(context, vn)) +# General implementations of `haskey`. +Base.haskey(::IsLeaf, context, vn) = false +Base.haskey(::IsParent, context, vn) = Base.haskey(childcontext(context), vn) +Base.haskey(context::AbstractContext, vn) = Base.haskey(NodeTrait(context), context, vn) + +# Specific to `ConditionContext`. function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars @@ -201,15 +216,18 @@ end # TODO: Can we maybe do this in a better way? # When no second argument is given, we remove _all_ conditioned variables. -# TODO: Should we remove this and just return `context.context`? +# TODO: Should we remove this and just return `childcontext(context)`? # That will work better if `Model` becomes like `ContextualModel`. -decondition(context::ConditionContext) = ConditionContext(NamedTuple(), context.context) +function decondition(context::ConditionContext) + return ConditionContext(NamedTuple(), childcontext(context)) +end function decondition(context::ConditionContext, sym) - return ConditionContext(BangBang.delete!!(context.values, sym), context.context) + return ConditionContext(BangBang.delete!!(context.values, sym), childcontext(context)) end function decondition(context::ConditionContext, sym, syms...) return decondition( - ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... + ConditionContext(BangBang.delete!!(context.values, sym), childcontext(context)), + syms..., ) end From c7c60e6a435ef89fe7bebd39bfe76df08f0d8ea9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:23:59 +0100 Subject: [PATCH 25/41] added rewrap for contexts --- src/contexts.jl | 62 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 10 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 38cbe213e..2f74847e6 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -8,9 +8,44 @@ struct IsParent end NodeTrait(f, context) Specifies the role of `context` in the context-tree. + +The officially supported traits are: +- `IsLeaf`: `context` does not have any decendants. +- `IsParent`: `context` has a child context to which we often defer. + Expects the following methods to be implemented: + - [`childcontext`](@ref) + - [`rewrap`](@ref) """ NodeTrait(_, context) = NodeTrait(context) +""" + childcontext(context) + +Return the descendant context of `context`. +""" +childcontext + +""" + rewrap(parent::AbstractContext, child::AbstractContext) + +Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), +effectively updating the child context. + +# Examples +```jldoctest +julia> ctx = SamplingContext(); + +julia> DynamicPPL.childcontext(ctx) +DefaultContext() + +julia> ctx_prior = DynamicPPL.rewrap(ctx, PriorContext()); # only compute the logprior + +julia> DynamicPPL.childcontext(ctx_prior) +PriorContext{Nothing}(nothing) +``` +""" +rewrap + # Contexts """ SamplingContext(rng, sampler, context) @@ -25,8 +60,14 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte sampler::S context::C end +SamplingContext(sampler, context) = SamplingContext(Random.GLOBAL_RNG, sampler, context) +SamplingContext(context::AbstractContext) = SamplingContext(SampleFromPrior(), context) +SamplingContext(sampler::AbstractSampler) = SamplingContext(sampler, DefaultContext()) +SamplingContext() = SamplingContext(SampleFromPrior()) + NodeTrait(context::SamplingContext) = IsParent() childcontext(context::SamplingContext) = context.context +rewrap(parent::SamplingContext, child) = SamplingContext(parent.rng, parent.sampler, child) """ struct DefaultContext <: AbstractContext end @@ -87,6 +128,7 @@ function MiniBatchContext(context=DefaultContext(); batch_size, npoints) end NodeTrait(context::MiniBatchContext) = IsParent() childcontext(context::MiniBatchContext) = context.context +rewrap(parent::MiniBatchContext, child) = MiniBatchContext(child, parent.loglike_scalar) """ PrefixContext{Prefix}(context) @@ -108,6 +150,7 @@ end NodeTrait(context::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context +rewrap(parent::PrefixContext{Prefix}, child) where {Prefix} = PrefixContext{Prefix}(child) const PREFIX_SEPARATOR = Symbol(".") @@ -156,9 +199,7 @@ end return :(NamedTuple{$names_expr}($values_expr)) end -function ConditionContext(context::ConditionContext, child_context::AbstractContext) - return ConditionContext(context.values, child_context) -end +ConditionContext(; values...) = ConditionContext((; values...)) function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end @@ -177,6 +218,10 @@ function ConditionContext( return ConditionContext(merge(context.values, values), childcontext(context)) end +NodeTrait(context::ConditionContext) = IsParent() +childcontext(context::ConditionContext) = context.context +rewrap(parent::ConditionContext, child) = ConditionContext(parent.values, child) + """ getvalue(context, vn) @@ -214,10 +259,10 @@ function Base.haskey( return sym in vars end -# TODO: Can we maybe do this in a better way? -# When no second argument is given, we remove _all_ conditioned variables. -# TODO: Should we remove this and just return `childcontext(context)`? -# That will work better if `Model` becomes like `ContextualModel`. +# Recursively `decondition` the context. +decondition(::IsLeaf, context) = context +decondition(::IsParent, context) = rewrap(context, decondition(childcontext(context))) +decondition(context) = decondition(NodeTrait(context), context) function decondition(context::ConditionContext) return ConditionContext(NamedTuple(), childcontext(context)) end @@ -230,6 +275,3 @@ function decondition(context::ConditionContext, sym, syms...) syms..., ) end - -NodeTrait(context::ConditionContext) = IsParent() -childcontext(context::ConditionContext) = context.context From be67807948b72b0ef47a64d071835ed538384b0d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:31:47 +0100 Subject: [PATCH 26/41] do decondition properly --- src/contexts.jl | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 2f74847e6..648039694 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -260,18 +260,25 @@ function Base.haskey( end # Recursively `decondition` the context. -decondition(::IsLeaf, context) = context -decondition(::IsParent, context) = rewrap(context, decondition(childcontext(context))) -decondition(context) = decondition(NodeTrait(context), context) +decondition(::IsLeaf, context, args...) = context +function decondition(::IsParent, context, args...) + return rewrap(context, decondition(childcontext(context), args...)) +end +decondition(context, args...) = decondition(NodeTrait(context), context, args...) function decondition(context::ConditionContext) - return ConditionContext(NamedTuple(), childcontext(context)) + return ConditionContext(NamedTuple(), decondition(childcontext(context))) end function decondition(context::ConditionContext, sym) - return ConditionContext(BangBang.delete!!(context.values, sym), childcontext(context)) + return ConditionContext( + BangBang.delete!!(context.values, sym), childcontext(context, sym) + ) end function decondition(context::ConditionContext, sym, syms...) return decondition( - ConditionContext(BangBang.delete!!(context.values, sym), childcontext(context)), + ConditionContext( + BangBang.delete!!(context.values, sym), + decondition(childcontext(context), syms...), + ), syms..., ) end From 0468297876eaf773ef823adf665a510fec226a74 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:46:53 +0100 Subject: [PATCH 27/41] added some examples and decondition now removes ConditionContext --- src/contexts.jl | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 648039694..7857c1f92 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -199,7 +199,7 @@ end return :(NamedTuple{$names_expr}($values_expr)) end -ConditionContext(; values...) = ConditionContext((; values...)) +ConditionContext(context=DefaultContext(); values...) = ConditionContext((; values...), context) function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end @@ -259,14 +259,46 @@ function Base.haskey( return sym in vars end -# Recursively `decondition` the context. +""" + context([context::AbstractContext,] values::NamedTuple) + context([context::AbstractContext]; values...) + +Return `ConditionContext` with `values` and wrapping `context`. +""" +condition(context=DefaultContext(); values...) = ConditionContext(context; values...) + +""" + decondition(context::AbstractContext, syms...) + +Return `context` but with `syms` no longer conditioned on. + +Note that this recursively traverses contexts, deconditioning all along the way. + +# Examples +```jldoctest +julia> ctx = DefaultContext(); + +julia> decondition(ctx) === ctx # this is a no-op +true + +julia> ctx = ConditionContext(x = 1.0); + +julia> decondition(ctx) +DefaultContext() + +julia> ctx_nested = ConditionContext(SamplingContext(ConditionContext(y=2.0)), x=1.0); + +julia> decondition(ctx_nested) +SamplingContext{SampleFromPrior, DefaultContext, Random._GLOBAL_RNG}(Random._GLOBAL_RNG(), SampleFromPrior(), DefaultContext()) +``` +""" decondition(::IsLeaf, context, args...) = context function decondition(::IsParent, context, args...) return rewrap(context, decondition(childcontext(context), args...)) end decondition(context, args...) = decondition(NodeTrait(context), context, args...) function decondition(context::ConditionContext) - return ConditionContext(NamedTuple(), decondition(childcontext(context))) + return decondition(childcontext(context)) end function decondition(context::ConditionContext, sym) return ConditionContext( From 80e3d5fd48b3c46f4c11749ed380e51d08ed5fa0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:56:37 +0100 Subject: [PATCH 28/41] improved condition and decondition a bit further --- src/contexts.jl | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 7857c1f92..7bb437867 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -199,7 +199,9 @@ end return :(NamedTuple{$names_expr}($values_expr)) end -ConditionContext(context=DefaultContext(); values...) = ConditionContext((; values...), context) +function ConditionContext(context=DefaultContext(); values...) + return ConditionContext((; values...), context) +end function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end @@ -263,9 +265,18 @@ end context([context::AbstractContext,] values::NamedTuple) context([context::AbstractContext]; values...) -Return `ConditionContext` with `values` and wrapping `context`. +Return `ConditionContext` with `values` and `context` if `values` is non-empty, +otherwise return `context` which is [`DefaultContext`](@ref) by default. + +See also: [`decondition`](@ref) """ -condition(context=DefaultContext(); values...) = ConditionContext(context; values...) +condition() = decondition(ConditionContext()) +condition(values::NamedTuple) = condition(DefaultContext(), values) +condition(context::AbstractContext, values::NamedTuple{()}) = context +condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context) +function condition(context::AbstractContext=DefaultContext(); values...) + return ConditionContext(context; values...) +end """ decondition(context::AbstractContext, syms...) @@ -274,6 +285,8 @@ Return `context` but with `syms` no longer conditioned on. Note that this recursively traverses contexts, deconditioning all along the way. +See also: [`condition`](@ref) + # Examples ```jldoctest julia> ctx = DefaultContext(); @@ -301,15 +314,15 @@ function decondition(context::ConditionContext) return decondition(childcontext(context)) end function decondition(context::ConditionContext, sym) - return ConditionContext( - BangBang.delete!!(context.values, sym), childcontext(context, sym) + return condition( + decondition(childcontext(context), sym), BangBang.delete!!(context.values, sym) ) end function decondition(context::ConditionContext, sym, syms...) return decondition( - ConditionContext( - BangBang.delete!!(context.values, sym), + condition( decondition(childcontext(context), syms...), + BangBang.delete!!(context.values, sym), ), syms..., ) From 9419e76bbbe2ffb7f708c67ac030d9769273cef6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:56:53 +0100 Subject: [PATCH 29/41] use rewrap in _evaluate for ContextualModel --- src/contextual_model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contextual_model.jl b/src/contextual_model.jl index 4b843ed0e..883fc5379 100644 --- a/src/contextual_model.jl +++ b/src/contextual_model.jl @@ -12,7 +12,7 @@ end function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as # `ConditionContext` child. - return _evaluate(cmodel.model, varinfo, ConditionContext(cmodel.context.values, context)) + return _evaluate(cmodel.model, varinfo, rewrap(cmodel.context, context)) end condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) From 4e566f78d4deff430b2ebe065ff8909b2f7a3b64 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:57:54 +0100 Subject: [PATCH 30/41] remove the drop_missing as it is no longer needed --- src/contexts.jl | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index f48a5187c..1c54cf86a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -118,30 +118,14 @@ struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext end end -@generated function drop_missings(nt::NamedTuple{names,values}) where {names,values} - names_expr = Expr(:tuple) - values_expr = Expr(:tuple) - - for (n, v) in zip(names, values.parameters) - if !(v <: Missing) - push!(names_expr.args, QuoteNode(n)) - push!(values_expr.args, :(nt.$n)) - end - end - - return :(NamedTuple{$names_expr}($values_expr)) -end - function ConditionContext(context::ConditionContext, child_context::AbstractContext) return ConditionContext(context.values, child_context) end function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end - function ConditionContext(values::NamedTuple, context::AbstractContext) - values_wo_missing = drop_missings(values) - return ConditionContext{typeof(values_wo_missing)}(values_wo_missing, context) + return ConditionContext{typeof(values)}(values, context) end # Try to avoid nested `ConditionContext`. From b27228aec8a15310f9076cc9623ec0de59f43221 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:08:14 +0100 Subject: [PATCH 31/41] rename rewrap to setchildcontet --- src/contexts.jl | 18 +++++++++--------- src/contextual_model.jl | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 00d6598af..33da613a7 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -14,7 +14,7 @@ The officially supported traits are: - `IsParent`: `context` has a child context to which we often defer. Expects the following methods to be implemented: - [`childcontext`](@ref) - - [`rewrap`](@ref) + - [`setchildcontext`](@ref) """ NodeTrait(_, context) = NodeTrait(context) @@ -26,7 +26,7 @@ Return the descendant context of `context`. childcontext """ - rewrap(parent::AbstractContext, child::AbstractContext) + setchildcontext(parent::AbstractContext, child::AbstractContext) Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), effectively updating the child context. @@ -38,13 +38,13 @@ julia> ctx = SamplingContext(); julia> DynamicPPL.childcontext(ctx) DefaultContext() -julia> ctx_prior = DynamicPPL.rewrap(ctx, PriorContext()); # only compute the logprior +julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior julia> DynamicPPL.childcontext(ctx_prior) PriorContext{Nothing}(nothing) ``` """ -rewrap +setchildcontext # Contexts """ @@ -67,7 +67,7 @@ SamplingContext() = SamplingContext(SampleFromPrior()) NodeTrait(context::SamplingContext) = IsParent() childcontext(context::SamplingContext) = context.context -rewrap(parent::SamplingContext, child) = SamplingContext(parent.rng, parent.sampler, child) +setchildcontext(parent::SamplingContext, child) = SamplingContext(parent.rng, parent.sampler, child) """ struct DefaultContext <: AbstractContext end @@ -128,7 +128,7 @@ function MiniBatchContext(context=DefaultContext(); batch_size, npoints) end NodeTrait(context::MiniBatchContext) = IsParent() childcontext(context::MiniBatchContext) = context.context -rewrap(parent::MiniBatchContext, child) = MiniBatchContext(child, parent.loglike_scalar) +setchildcontext(parent::MiniBatchContext, child) = MiniBatchContext(child, parent.loglike_scalar) """ PrefixContext{Prefix}(context) @@ -150,7 +150,7 @@ end NodeTrait(context::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context -rewrap(parent::PrefixContext{Prefix}, child) where {Prefix} = PrefixContext{Prefix}(child) +setchildcontext(parent::PrefixContext{Prefix}, child) where {Prefix} = PrefixContext{Prefix}(child) const PREFIX_SEPARATOR = Symbol(".") @@ -206,7 +206,7 @@ end NodeTrait(context::ConditionContext) = IsParent() childcontext(context::ConditionContext) = context.context -rewrap(parent::ConditionContext, child) = ConditionContext(parent.values, child) +setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) """ getvalue(context, vn) @@ -291,7 +291,7 @@ SamplingContext{SampleFromPrior, DefaultContext, Random._GLOBAL_RNG}(Random._GLO """ decondition(::IsLeaf, context, args...) = context function decondition(::IsParent, context, args...) - return rewrap(context, decondition(childcontext(context), args...)) + return setchildcontext(context, decondition(childcontext(context), args...)) end decondition(context, args...) = decondition(NodeTrait(context), context, args...) function decondition(context::ConditionContext) diff --git a/src/contextual_model.jl b/src/contextual_model.jl index 883fc5379..3ab931b65 100644 --- a/src/contextual_model.jl +++ b/src/contextual_model.jl @@ -12,7 +12,7 @@ end function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as # `ConditionContext` child. - return _evaluate(cmodel.model, varinfo, rewrap(cmodel.context, context)) + return _evaluate(cmodel.model, varinfo, setchildcontext(cmodel.context, context)) end condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) From 4935d5c610be95bcb073f1776e3ff462160ef371 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:08:33 +0100 Subject: [PATCH 32/41] formatting --- src/contexts.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 33da613a7..36c6485ba 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -67,7 +67,9 @@ SamplingContext() = SamplingContext(SampleFromPrior()) NodeTrait(context::SamplingContext) = IsParent() childcontext(context::SamplingContext) = context.context -setchildcontext(parent::SamplingContext, child) = SamplingContext(parent.rng, parent.sampler, child) +function setchildcontext(parent::SamplingContext, child) + return SamplingContext(parent.rng, parent.sampler, child) +end """ struct DefaultContext <: AbstractContext end @@ -128,7 +130,9 @@ function MiniBatchContext(context=DefaultContext(); batch_size, npoints) end NodeTrait(context::MiniBatchContext) = IsParent() childcontext(context::MiniBatchContext) = context.context -setchildcontext(parent::MiniBatchContext, child) = MiniBatchContext(child, parent.loglike_scalar) +function setchildcontext(parent::MiniBatchContext, child) + return MiniBatchContext(child, parent.loglike_scalar) +end """ PrefixContext{Prefix}(context) @@ -150,7 +154,9 @@ end NodeTrait(context::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context -setchildcontext(parent::PrefixContext{Prefix}, child) where {Prefix} = PrefixContext{Prefix}(child) +function setchildcontext(parent::PrefixContext{Prefix}, child) where {Prefix} + return PrefixContext{Prefix}(child) +end const PREFIX_SEPARATOR = Symbol(".") From 61960832b75ff490862eaca5da6e186602fc3c93 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:11:51 +0100 Subject: [PATCH 33/41] made show a bit nicer for ConditionContext --- src/contexts.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/contexts.jl b/src/contexts.jl index 1c54cf86a..4a6834b32 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -137,6 +137,10 @@ function ConditionContext( return ConditionContext(merge(context.values, values), context.context) end +function Base.show(io::IO, context::ConditionContext) + println(io, "ConditionContext($(context.values), $(context.context))") +end + function getvalue(context::ConditionContext, vn) return if haskey(context, vn) _getvalue(context.values, vn) From 65048fc4c57e1ec884cddbcdcd880269cfdc33a3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:12:32 +0100 Subject: [PATCH 34/41] use print instead of println in show --- src/contexts.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 4a6834b32..fa0bf585a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -138,7 +138,7 @@ function ConditionContext( end function Base.show(io::IO, context::ConditionContext) - println(io, "ConditionContext($(context.values), $(context.context))") + print(io, "ConditionContext($(context.values), $(context.context))") end function getvalue(context::ConditionContext, vn) From 649af299a7f68c0def25026bbf06850aedbbfdbc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:14:15 +0100 Subject: [PATCH 35/41] formatting --- src/contexts.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index fa0bf585a..926dd0fe0 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -138,7 +138,7 @@ function ConditionContext( end function Base.show(io::IO, context::ConditionContext) - print(io, "ConditionContext($(context.values), $(context.context))") + return print(io, "ConditionContext($(context.values), $(context.context))") end function getvalue(context::ConditionContext, vn) From 21c08e5852f05d3ee9ef1bdaa832efb66fb96cc9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Jul 2021 19:03:31 +0100 Subject: [PATCH 36/41] dont overload haskey and improved contextual_isassumption check --- src/compiler.jl | 22 +++++++++++----------- src/contexts.jl | 16 ++++++++-------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index fb77d0323..c69a83803 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -61,17 +61,17 @@ function contextual_isassumption(context::AbstractContext, vn) return contextual_isassumption(NodeTrait(context), context, vn) end function contextual_isassumption(context::ConditionContext, vn) - # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`. - - # We have either of the following cases: - # 1. `context` considers `vn` as an observation, i.e. it has `vn` as a key, - # which means we have a value to replace with and we don't need to recurse. - # 2. One of the decendant contexts consider it as an observation, i.e. - # `contextual_isassumption` evaluates to `false`. - # The below then evaluates to `!(false || true) === false`. - # 3. Neither `context` nor any of it's decendants considers it an observation, - # in which case the below evaluates to `!(false || false) === true`. - return !(haskey(context, vn) || !contextual_isassumption(childcontext(context), vn)) + if hasvalue(context, vn) + val = getvalue(context, vn) + # TODO: Do we even need the `>: Missing` to help the compiler? + if eltype(val) >: Missing && val === missing + return true + end + end + + # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}` + # so we defer to `childcontext` if we haven't concluded that anything yet. + return contextual_isassumption(childcontext(context), vn) end function contextual_isassumption(context::PrefixContext, vn) return contextual_isassumption(childcontext(context), prefix(context, vn)) diff --git a/src/contexts.jl b/src/contexts.jl index fb26b2757..1c63df0c5 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -294,7 +294,7 @@ getvalue(context::AbstractContext, vn) = getvalue(NodeTrait(getvalue, context), getvalue(context::PrefixContext, vn) = getvalue(childcontext(context), prefix(context, vn)) function getvalue(context::ConditionContext, vn) - return if haskey(context, vn) + return if hasvalue(context, vn) _getvalue(context.values, vn) else getvalue(childcontext(context), vn) @@ -302,17 +302,17 @@ function getvalue(context::ConditionContext, vn) end # General implementations of `haskey`. -Base.haskey(::IsLeaf, context, vn) = false -Base.haskey(::IsParent, context, vn) = Base.haskey(childcontext(context), vn) -Base.haskey(context::AbstractContext, vn) = Base.haskey(NodeTrait(context), context, vn) +hasvalue(::IsLeaf, context, vn) = false +hasvalue(::IsParent, context, vn) = hasvalue(childcontext(context), vn) +hasvalue(context::AbstractContext, vn) = hasvalue(NodeTrait(context), context, vn) # Specific to `ConditionContext`. -function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} +function hasvalue(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars end -function Base.haskey( +function hasvalue( context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}} ) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. @@ -332,8 +332,8 @@ condition() = decondition(ConditionContext()) condition(values::NamedTuple) = condition(DefaultContext(), values) condition(context::AbstractContext, values::NamedTuple{()}) = context condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context) -function condition(context::AbstractContext=DefaultContext(); values...) - return ConditionContext(context; values...) +function condition(context::AbstractContext; values...) + return condition(context, (; values...)) end """ From b201ef3bb76ad20fb76dcc38c37b91ed486cb5c9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Jul 2021 19:15:51 +0100 Subject: [PATCH 37/41] removed leftovers --- Project.toml | 1 - src/DynamicPPL.jl | 2 -- 2 files changed, 3 deletions(-) diff --git a/Project.toml b/Project.toml index ecf47b50f..5678c050a 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ version = "0.13.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" -BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index bfcc49e6b..e7732d8c5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,7 +6,6 @@ using Distributions using Bijectors using AbstractMCMC: AbstractMCMC -using BangBang: BangBang using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules @@ -134,6 +133,5 @@ include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") -include("contextual_model.jl") end # module From 8cc3193d6909be89ad772ac7eb19cee121d128d5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:37:01 +0100 Subject: [PATCH 38/41] improved condition and fixed isassumption --- src/DynamicPPL.jl | 1 + src/compiler.jl | 2 +- src/contexts.jl | 63 ++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 53 insertions(+), 13 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e7732d8c5..2f0ef2617 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -9,6 +9,7 @@ using AbstractMCMC: AbstractMCMC using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules +using BangBang: BangBang using Random: Random diff --git a/src/compiler.jl b/src/compiler.jl index c69a83803..6e10c4285 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -34,7 +34,7 @@ function isassumption(expr::Union{Symbol,Expr}) # TODO: Support by adding context to model, and use `model.args` # as the default conditioning. Then we no longer need to check `inargnames` # since it will all be handled by `contextual_isassumption`. - if !($(DynamicPPL.inargnames)($vn, __model__)) + if !($(DynamicPPL.inargnames)($vn, __model__)) || $(DynamicPPL.inmissings)($vn, __model__) true else $expr === missing diff --git a/src/contexts.jl b/src/contexts.jl index 1c63df0c5..d1b2145c2 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -66,9 +66,9 @@ original leaf context of `left`. # Examples ```jldoctest -julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext +julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext -julia> struct ParentContext{C} +julia> struct ParentContext{C} <: AbstractContext context::C end @@ -328,14 +328,13 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default. See also: [`decondition`](@ref) """ -condition() = decondition(ConditionContext()) +function condition(; values...) + return isempty(values) ? decondition(ConditionContext()) : condition(DefaultContext(), (; values...)) +end condition(values::NamedTuple) = condition(DefaultContext(), values) condition(context::AbstractContext, values::NamedTuple{()}) = context condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context) -function condition(context::AbstractContext; values...) - return condition(context, (; values...)) -end - +condition(context::AbstractContext; values...) = condition(context, (; values...)) """ decondition(context::AbstractContext, syms...) @@ -347,20 +346,60 @@ See also: [`condition`](@ref) # Examples ```jldoctest -julia> ctx = DefaultContext(); +julia> using DynamicPPL: AbstractContext, leafcontext, setleafcontext, childcontext, setchildcontext + +julia> struct ParentContext{C} <: AbstractContext + context::C + end + +julia> DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() + +julia> DynamicPPL.childcontext(context::ParentContext) = context.context + +julia> DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) + +julia> Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") + +julia> ctx = DefaultContext() +DefaultContext() julia> decondition(ctx) === ctx # this is a no-op true -julia> ctx = ConditionContext(x = 1.0); +julia> ctx = condition(x = 1.0) # default "constructor" for `ConditionContext` +ConditionContext((x = 1.0,), DefaultContext()) -julia> decondition(ctx) +julia> decondition(ctx) # `decondition` without arguments drops all conditioning DefaultContext() -julia> ctx_nested = ConditionContext(SamplingContext(ConditionContext(y=2.0)), x=1.0); +julia> # Nested conditioning is supported. + ctx_nested = condition(ParentContext(condition(y=2.0)), x=1.0) +ConditionContext((x = 1.0,), ParentContext(ConditionContext((y = 2.0,), DefaultContext()))) + +julia> # We can also specify which variables to drop. + decondition(ctx_nested, :x) +ParentContext(ConditionContext((y = 2.0,), DefaultContext())) + +julia> # No matter the nested level. + decondition(ctx_nested, :y) +ConditionContext((x = 1.0,), ParentContext(DefaultContext())) + +julia> # Or specify multiple at in one call. + decondition(ctx_nested, :x, :y) +ParentContext(DefaultContext()) julia> decondition(ctx_nested) -SamplingContext{SampleFromPrior, DefaultContext, Random._GLOBAL_RNG}(Random._GLOBAL_RNG(), SampleFromPrior(), DefaultContext()) +ParentContext(DefaultContext()) + +julia> # `Val` is also supported. + decondition(ctx_nested, Val(:x)) +ParentContext(ConditionContext((y = 2.0,), DefaultContext())) + +julia> decondition(ctx_nested, Val(:y)) +ConditionContext((x = 1.0,), ParentContext(DefaultContext())) + +julia> decondition(ctx_nested, Val(:x), Val(:y)) +ParentContext(DefaultContext()) ``` """ decondition(::IsLeaf, context, args...) = context From 26edb2c80cc8562fdff1dfefde403ae02395d010 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:40:08 +0100 Subject: [PATCH 39/41] added BangBang as dep --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 5678c050a..fd514f836 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.13.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -15,6 +16,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "2, 3.0" AbstractPPL = "0.2" +BangBang = "0.3" Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9" ChainRulesCore = "0.9.7, 0.10" Distributions = "0.23.8, 0.24, 0.25" From e2f6fc5f8a2ddc9e6efb04b506c9fba50df28aca Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:46:36 +0100 Subject: [PATCH 40/41] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/compiler.jl | 3 ++- src/contexts.jl | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 6e10c4285..921062079 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -34,7 +34,8 @@ function isassumption(expr::Union{Symbol,Expr}) # TODO: Support by adding context to model, and use `model.args` # as the default conditioning. Then we no longer need to check `inargnames` # since it will all be handled by `contextual_isassumption`. - if !($(DynamicPPL.inargnames)($vn, __model__)) || $(DynamicPPL.inmissings)($vn, __model__) + if !($(DynamicPPL.inargnames)($vn, __model__)) || + $(DynamicPPL.inmissings)($vn, __model__) true else $expr === missing diff --git a/src/contexts.jl b/src/contexts.jl index de7a44207..b045a17e2 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -326,7 +326,11 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default. See also: [`decondition`](@ref) """ function condition(; values...) - return isempty(values) ? decondition(ConditionContext()) : condition(DefaultContext(), (; values...)) + return if isempty(values) + decondition(ConditionContext()) + else + condition(DefaultContext(), (; values...)) + end end condition(values::NamedTuple) = condition(DefaultContext(), values) condition(context::AbstractContext, values::NamedTuple{()}) = context From ce160d682c420127126ae7c53d9f2f6ca47e1373 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:53:18 +0100 Subject: [PATCH 41/41] fixed condition without arguments --- src/contexts.jl | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index b045a17e2..ca3c6b4c2 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -251,9 +251,6 @@ struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext end end -function ConditionContext(context::ConditionContext, child_context::AbstractContext) - return ConditionContext(context.values, child_context) -end function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end @@ -325,13 +322,7 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default. See also: [`decondition`](@ref) """ -function condition(; values...) - return if isempty(values) - decondition(ConditionContext()) - else - condition(DefaultContext(), (; values...)) - end -end +condition(; values...) = condition(DefaultContext(), (; values...)) condition(values::NamedTuple) = condition(DefaultContext(), values) condition(context::AbstractContext, values::NamedTuple{()}) = context condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context)