From cd1c46de8bef1242f22ea9303ff1faca0bb972c1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:08:01 +0100 Subject: [PATCH 01/86] added ConditionContext and ContextualModel --- Project.toml | 1 + src/DynamicPPL.jl | 6 ++++ src/context_implementations.jl | 51 ++++++++++++++++++++++++++++++++++ src/contexts.jl | 42 ++++++++++++++++++++++++++++ src/model.jl | 14 ++++++---- src/varinfo.jl | 6 ++-- 6 files changed, 111 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 921dc054d..7cc3db31d 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.12.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index a46c941a1..bfcc49e6b 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,6 +6,7 @@ using Distributions using Bijectors using AbstractMCMC: AbstractMCMC +using BangBang: BangBang using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules @@ -67,6 +68,7 @@ export AbstractVarInfo, vectorize, # Model Model, + ContextualModel, getmissings, getargnames, generated_quantities, @@ -81,6 +83,7 @@ export AbstractVarInfo, PriorContext, MiniBatchContext, PrefixContext, + ConditionContext, assume, dot_assume, observe, @@ -99,6 +102,8 @@ export AbstractVarInfo, logprior, logjoint, pointwise_loglikelihoods, + condition, + decondition, # Convenience macros @addlogprob!, @submodel @@ -129,5 +134,6 @@ include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") +include("contextual_model.jl") end # module diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3d492f5b1..68165a64f 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -133,6 +133,28 @@ function tilde_assume!(context, right, vn, inds, vi) return value end +function tilde_assume!(context::ConditionContext, right, vn, inds, vi) + value = if haskey(context, vn) + # Extract value. + if inds isa Tuple{} + getfield(context.values, getsym(vn)) + else + _getindex(getfield(context.values, getsym(vn)), inds) + end + + # Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + tilde_observe!(context.context, right, value, vn, inds, vi) + else + tilde_assume!(context.context, right, vn, inds, vi) + end + + return value +end + # observe """ tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) @@ -217,6 +239,10 @@ function tilde_observe!(context, right, left, vi) return left end +function tilde_observe!(context::ConditionContext, right, left, vi) + return tilde_observe!(context.context, right, left, vi) +end + function assume(rng, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end @@ -419,6 +445,28 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) return value end +function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) + value = if vn in context + # Extract value. + if inds isa Tuple{} + getfield(context.values, sym) + else + _getindex(getfield(context.values, sym), inds) + end + + # Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + dot_tilde_observe!(context.context, right, left, vn, inds, vi) + else + dot_tilde_assume!(context.context, right, left, vn, inds, vi) + end + + return value +end + # `dot_assume` function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi @@ -637,6 +685,9 @@ function dot_tilde_observe!(context, right, left, vi) acclogp!(vi, logp) return left end +function dot_tilde_observe!(context::ConditionContext, right, left, vi) + return dot_tilde_observe!(context.context, right, left, vi) +end # Falls back to non-sampler definition. function dot_observe(::AbstractSampler, dist, value, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 05ad8df0d..946c22092 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -106,3 +106,45 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end + + +struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext + values::Values + context::Ctx +end + +function ConditionContext(values::NamedTuple{Vars}) where {Vars} + return ConditionContext(values, DefaultContext()) +end + +function ConditionContext(values::NamedTuple{Vars}, context) where {Vars} + return ConditionContext{Vars,typeof(values),typeof(context)}(values, context) +end + +# Try to avoid nested `ConditionContext`. +function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) where {Vars} + # Note that this potentially overrides values from `context`, thus giving + # precedence to the outmost `ConditionContext`. + return ConditionContext(merge(context.values, values), context.context) +end + +function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars, sym} + # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. + return sym in vars +end + +# TODO: Can we maybe do this in a better way? +decondition(context::ConditionContext) = context +function decondition(context::ConditionContext, sym, syms...) + return decondition( + ConditionContext(BangBang.delete!!(context.values, sym), context.context), + syms... + ) +end +decondition(context::ConditionContext) = context +function decondition(context::ConditionContext, ::Val{sym}, syms...) where {sym} + return decondition( + ConditionContext(BangBang.delete!!(context.values, Val{sym}()), context.context), + syms... + ) +end diff --git a/src/model.jl b/src/model.jl index 9ec047a44..3ce707e2e 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,3 +1,5 @@ +abstract type AbstractModel <: AbstractProbabilisticProgram end + """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} name::Symbol @@ -33,7 +35,7 @@ Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x ``` """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: - AbstractProbabilisticProgram + AbstractModel name::Symbol f::F args::NamedTuple{argnames,Targs} @@ -82,7 +84,7 @@ Sample from the `model` using the `sampler` with random number generator `rng` a The method resets the log joint probability of `varinfo` and increases the evaluation number of `sampler`. """ -function (model::Model)( +function (model::AbstractModel)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), @@ -91,7 +93,7 @@ function (model::Model)( return model(varinfo, SamplingContext(rng, sampler, context)) end -(model::Model)(context::AbstractContext) = model(VarInfo(), context) +(model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context) function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 return evaluate_threadunsafe(model, varinfo, context) @@ -100,17 +102,17 @@ function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) end end -function (model::Model)(args...) +function (model::AbstractModel)(args...) return model(Random.GLOBAL_RNG, args...) end # without VarInfo -function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) +function (model::AbstractModel)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) return model(rng, VarInfo(), sampler, args...) end # without VarInfo and without AbstractSampler -function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) +function (model::AbstractModel)(rng::Random.AbstractRNG, context::AbstractContext) return model(rng, VarInfo(), SampleFromPrior(), context) end diff --git a/src/varinfo.jl b/src/varinfo.jl index fe3262dd5..be6551939 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -124,7 +124,7 @@ end function VarInfo( rng::Random.AbstractRNG, - model::Model, + model::AbstractModel, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) @@ -132,10 +132,10 @@ function VarInfo( model(rng, varinfo, sampler, context) return TypedVarInfo(varinfo) end -VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) +VarInfo(model::AbstractModel, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) # without AbstractSampler -function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) +function VarInfo(rng::Random.AbstractRNG, model::AbstractModel, context::AbstractContext) return VarInfo(rng, model, SampleFromPrior(), context) end From b1106ee65b60a8eebffada99b5c770a869730169 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:10:12 +0100 Subject: [PATCH 02/86] removed redundant definition --- src/contexts.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 946c22092..5abae4168 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -141,7 +141,6 @@ function decondition(context::ConditionContext, sym, syms...) syms... ) end -decondition(context::ConditionContext) = context function decondition(context::ConditionContext, ::Val{sym}, syms...) where {sym} return decondition( ConditionContext(BangBang.delete!!(context.values, Val{sym}()), context.context), From 0f00771e8f391311cd6956ac33e3e31c770a2409 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:36:28 +0100 Subject: [PATCH 03/86] return condition model by default --- src/compiler.jl | 16 ++++++++++------ src/contexts.jl | 14 +++++++------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7466bc2c0..00be43551 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -23,7 +23,8 @@ function isassumption(expr::Union{Symbol,Expr}) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) + $(DynamicPPL.inmissings)($vn, __model__) || + (__context__ isa $(DynamicPPL.ConditionContext) && !$(Base.haskey)(__context__, $vn)) true else # Evaluate the LHS @@ -427,11 +428,14 @@ function build_output(modelinfo, linenumbernode) modeldef[:body] = MacroTools.@q begin $(linenumbernode) $evaluator = $(MacroTools.combinedef(evaluatordef)) - return $(DynamicPPL.Model)( - $(QuoteNode(modeldef[:name])), - $evaluator, - $allargs_namedtuple, - $defaults_namedtuple, + return $(DynamicPPL.condition)( + $(DynamicPPL.Model)( + $(QuoteNode(modeldef[:name])), + $evaluator, + $allargs_namedtuple, + $defaults_namedtuple, + ), + $allargs_namedtuple ) end diff --git a/src/contexts.jl b/src/contexts.jl index 5abae4168..0a0e977b9 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -134,16 +134,16 @@ function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {v end # TODO: Can we maybe do this in a better way? -decondition(context::ConditionContext) = context +# When no second argument is given, we remove _all_ conditioned variables. +# TODO: Should we remove this and just return `context.context`? +# That will work better if `Model` becomes like `ContextualModel`. +decondition(context::ConditionContext) = ConditionContext(NamedTuple(), context.context) +function decondition(context::ConditionContext, sym) + return ConditionContext(BangBang.delete!!(context.values, sym), context.context) +end function decondition(context::ConditionContext, sym, syms...) return decondition( ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... ) end -function decondition(context::ConditionContext, ::Val{sym}, syms...) where {sym} - return decondition( - ConditionContext(BangBang.delete!!(context.values, Val{sym}()), context.context), - syms... - ) -end From c754291411f3af1bef2468838493a750eaefa427 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 20:47:50 +0100 Subject: [PATCH 04/86] formatting --- src/compiler.jl | 9 ++++++--- src/contexts.jl | 6 ++---- src/model.jl | 3 +-- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 00be43551..191152f4e 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -23,8 +23,11 @@ function isassumption(expr::Union{Symbol,Expr}) # This branch should compile nicely in all cases except for partial missing data # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) || - (__context__ isa $(DynamicPPL.ConditionContext) && !$(Base.haskey)(__context__, $vn)) + $(DynamicPPL.inmissings)($vn, __model__) || + ( + __context__ isa $(DynamicPPL.ConditionContext) && + !$(Base.haskey)(__context__, $vn) + ) true else # Evaluate the LHS @@ -435,7 +438,7 @@ function build_output(modelinfo, linenumbernode) $allargs_namedtuple, $defaults_namedtuple, ), - $allargs_namedtuple + $allargs_namedtuple, ) end diff --git a/src/contexts.jl b/src/contexts.jl index 0a0e977b9..9831dde4a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -107,7 +107,6 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} end end - struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx @@ -128,7 +127,7 @@ function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) w return ConditionContext(merge(context.values, values), context.context) end -function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars, sym} +function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars end @@ -143,7 +142,6 @@ function decondition(context::ConditionContext, sym) end function decondition(context::ConditionContext, sym, syms...) return decondition( - ConditionContext(BangBang.delete!!(context.values, sym), context.context), - syms... + ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... ) end diff --git a/src/model.jl b/src/model.jl index 3ce707e2e..39a9f4fd3 100644 --- a/src/model.jl +++ b/src/model.jl @@ -34,8 +34,7 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: - AbstractModel +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel name::Symbol f::F args::NamedTuple{argnames,Targs} From 2d3f94cb6c137f6ce22683392cf4fcf0c579a47d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 22:15:23 +0100 Subject: [PATCH 05/86] forgot to include contextual model in previous commit --- src/contextual_model.jl | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 src/contextual_model.jl diff --git a/src/contextual_model.jl b/src/contextual_model.jl new file mode 100644 index 000000000..382c45034 --- /dev/null +++ b/src/contextual_model.jl @@ -0,0 +1,29 @@ +struct ContextualModel{Ctx<:AbstractContext,M<:Model} <: AbstractModel + context::Ctx + model::M +end + +function contextualize(model::AbstractModel, context::AbstractContext) + return ContextualModel(context, model) +end + +# TODO: What do we do for other contexts? Could handle this in general if we had a +# notion of wrapper-, primitive-context, etc. +function (cmodel::ContextualModel{<:ConditionContext})( + varinfo::AbstractVarInfo, context::AbstractContext +) + # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as + # `ConditionContext` child. + return cmodel.model(varinfo, ConditionContext(cmodel.context.values, context)) +end + +condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) +condition(model::AbstractModel; values...) = condition(model, (; values...)) +function condition(cmodel::ContextualModel{<:ConditionContext}, values) + return contextualize(cmodel.model, ConditionContext(values, cmodel.context)) +end + +decondition(model::AbstractModel, args...) = model +function decondition(cmodel::ContextualModel{<:ConditionContext}, syms...) + return contextualize(cmodel.model, decondition(cmodel.context, syms...)) +end From 3e5a79f7e1b14ba610f88377ef9b9f177777becb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 22:45:45 +0100 Subject: [PATCH 06/86] fixed typos --- src/context_implementations.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 68165a64f..052b4ab9c 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -134,9 +134,9 @@ function tilde_assume!(context, right, vn, inds, vi) end function tilde_assume!(context::ConditionContext, right, vn, inds, vi) - value = if haskey(context, vn) + if haskey(context, vn) # Extract value. - if inds isa Tuple{} + value = if inds isa Tuple{} getfield(context.values, getsym(vn)) else _getindex(getfield(context.values, getsym(vn)), inds) @@ -149,7 +149,7 @@ function tilde_assume!(context::ConditionContext, right, vn, inds, vi) tilde_observe!(context.context, right, value, vn, inds, vi) else - tilde_assume!(context.context, right, vn, inds, vi) + value = tilde_assume!(context.context, right, vn, inds, vi) end return value @@ -446,9 +446,9 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) end function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) - value = if vn in context + if vn in context # Extract value. - if inds isa Tuple{} + value = if inds isa Tuple{} getfield(context.values, sym) else _getindex(getfield(context.values, sym), inds) @@ -459,9 +459,9 @@ function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) vi[vn] = value end - dot_tilde_observe!(context.context, right, left, vn, inds, vi) + dot_tilde_observe!(context.context, right, value, vn, inds, vi) else - dot_tilde_assume!(context.context, right, left, vn, inds, vi) + value = dot_tilde_assume!(context.context, right, left, vn, inds, vi) end return value From d4e42387888a4d7c9b2f9e62541a8de70fbb6f0f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 15 Jul 2021 23:06:59 +0100 Subject: [PATCH 07/86] added some niceties to ConditionContext --- src/contexts.jl | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 9831dde4a..e06d4c6eb 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -110,14 +110,38 @@ end struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx + + function ConditionContext{Values}( + values::Values, context::AbstractContext + ) where {names,Values<:NamedTuple{names}} + return new{names,typeof(values),typeof(context)}(values, context) + end +end + +@generated function drop_missings(nt::NamedTuple{names,values}) where {names,values} + names_expr = Expr(:tuple) + values_expr = Expr(:tuple) + + for (n, v) in zip(names, values.parameters) + if !(v <: Missing) + push!(names_expr.args, QuoteNode(n)) + push!(values_expr.args, :(nt.$n)) + end + end + + return :(NamedTuple{$names_expr}($values_expr)) end -function ConditionContext(values::NamedTuple{Vars}) where {Vars} +function ConditionContext(context::ConditionContext, child_context::AbstractContext) + return ConditionContext(context.values, child_context) +end +function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end -function ConditionContext(values::NamedTuple{Vars}, context) where {Vars} - return ConditionContext{Vars,typeof(values),typeof(context)}(values, context) +function ConditionContext(values::NamedTuple, context::AbstractContext) + values_wo_missing = drop_missings(values) + return ConditionContext{typeof(values_wo_missing)}(values_wo_missing, context) end # Try to avoid nested `ConditionContext`. From 85a47eb098498df027ebaf2efe124517507933cb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 16 Jul 2021 01:49:52 +0100 Subject: [PATCH 08/86] added support for vectors of VarName --- src/context_implementations.jl | 2 +- src/contexts.jl | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 052b4ab9c..813d40875 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -446,7 +446,7 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) end function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) - if vn in context + if haskey(context, vn) # Extract value. value = if inds isa Tuple{} getfield(context.values, sym) diff --git a/src/contexts.jl b/src/contexts.jl index e06d4c6eb..500d3d60b 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -156,6 +156,11 @@ function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {v return sym in vars end +function Base.haskey(context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}}) where {vars,sym} + # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. + return sym in vars +end + # TODO: Can we maybe do this in a better way? # When no second argument is given, we remove _all_ conditioned variables. # TODO: Should we remove this and just return `context.context`? From c7dae8daa1a66712328161be7a8b4fe94575deae Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 16 Jul 2021 01:52:17 +0100 Subject: [PATCH 09/86] Update src/contexts.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/contexts.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 500d3d60b..d0cd81287 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -156,7 +156,9 @@ function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {v return sym in vars end -function Base.haskey(context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}}) where {vars,sym} +function Base.haskey( + context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}} +) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars end From 75680b3d5715eb29239ca39ff7401d3ce966f194 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 05:12:00 +0100 Subject: [PATCH 10/86] upper-bound Distributions.jl in tests --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 146e92c55..37c509610 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -22,7 +22,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" AbstractMCMC = "2.1, 3.0" AbstractPPL = "0.1.3" Bijectors = "0.9.5" -Distributions = "0.24, 0.25" +Distributions = "< 0.25.11" DistributionsAD = "0.6.3" Documenter = "0.26.1, 0.27" ForwardDiff = "0.10.12" From f0ae744e625adbe0776021a0991d0aeb14fd8790 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 05:33:42 +0100 Subject: [PATCH 11/86] make the isassumption check using context extensible and nicer --- src/compiler.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index d0485dba9..9eca8abae 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -24,10 +24,7 @@ function isassumption(expr::Union{Symbol,Expr}) # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` if !$(DynamicPPL.inargnames)($vn, __model__) || $(DynamicPPL.inmissings)($vn, __model__) || - ( - __context__ isa $(DynamicPPL.ConditionContext) && - !$(Base.haskey)(__context__, $vn) - ) + $(DynamicPPL.contextual_isassumption)(__context__, $vn) true else # Evaluate the LHS @@ -37,6 +34,9 @@ function isassumption(expr::Union{Symbol,Expr}) end end +contextual_isassumption(context::AbstractContext, vn) = false +contextual_isassumption(context::ConditionContext, vn::VarName) = !(haskey(context, vn)) + # failsafe: a literal is never an assumption isassumption(expr) = :(false) From b990bb05a3baec3f67c714fa11d529d974a97839 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 06:14:33 +0100 Subject: [PATCH 12/86] renamed type-parameter for ConditionContext --- src/contexts.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index d0cd81287..8bfc25a9b 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -107,7 +107,7 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} end end -struct ConditionContext{Vars,Values,Ctx<:AbstractContext} <: AbstractContext +struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext values::Values context::Ctx @@ -145,7 +145,7 @@ function ConditionContext(values::NamedTuple, context::AbstractContext) end # Try to avoid nested `ConditionContext`. -function ConditionContext(values::NamedTuple{Vars}, context::ConditionContext) where {Vars} +function ConditionContext(values::NamedTuple{Names}, context::ConditionContext) where {Names} # Note that this potentially overrides values from `context`, thus giving # precedence to the outmost `ConditionContext`. return ConditionContext(merge(context.values, values), context.context) From 8994cd78834cfa9408927dca7be95532a5a554c7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jul 2021 06:27:08 +0100 Subject: [PATCH 13/86] Update src/contexts.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/contexts.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 8bfc25a9b..96bd43cd7 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -145,7 +145,9 @@ function ConditionContext(values::NamedTuple, context::AbstractContext) end # Try to avoid nested `ConditionContext`. -function ConditionContext(values::NamedTuple{Names}, context::ConditionContext) where {Names} +function ConditionContext( + values::NamedTuple{Names}, context::ConditionContext +) where {Names} # Note that this potentially overrides values from `context`, thus giving # precedence to the outmost `ConditionContext`. return ConditionContext(merge(context.values, values), context.context) From 94da453a489102f2a326e777993e2c45001d966a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 16:17:50 +0100 Subject: [PATCH 14/86] introduced convenient _getvalue method --- src/context_implementations.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 813d40875..9f9ca3004 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,8 +14,17 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -_getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) +_getindex(x, inds::Tuple) = _getindex(Base.maybeview(x, first(inds)...), Base.tail(inds)) _getindex(x, inds::Tuple{}) = x +_getvalue(x, vn::VarName{sym}) where {sym} = _getindex(getproperty(x, sym), vn.indexing) +function _getvalue(x, vns::AbstractVector{<:VarName{sym}}) where {sym} + val = getproperty(x, sym) + + # This should work with both cartesian and linear indexing. + return map(vns) do vn + _getindex(val, vn) + end +end # assume """ From 4e74cf8a6ab99de553b23ca898ce451a8d190aac Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 16:18:04 +0100 Subject: [PATCH 15/86] overload tilde_assume rather than tilde_assume! and others for ConditionContext --- src/context_implementations.jl | 117 +++++++++++++++++++-------------- 1 file changed, 69 insertions(+), 48 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9f9ca3004..54f6d5965 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -127,6 +127,36 @@ function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) end +# `ConditionContext` +function tilde_assume(context::ConditionContext, right, vn, inds, vi) + if !haskey(context, vn) + # Not conditioned on => defer to child context. + return tilde_assume(context.context, right, vn, inds, vi) + end + + # Extract value. + value = _getvalue(context.values, vn) + + # Update the value in `vi`. + # TODO: Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + logp = tilde_observe(context.context, right, value, vn, inds, vi) + return value, logp +end + +function tilde_assume(rng, context::ConditionContext, sampler, right, vn, inds, vi) + if haskey(context, vn) + # Defer to child context. + return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) + end + + # If we're conditioning, then we just fall back to non-rng impl. + return tilde_assume(context, right, vn, inds, vi) +end + """ tilde_assume!(context, right, vn, inds, vi) @@ -142,28 +172,6 @@ function tilde_assume!(context, right, vn, inds, vi) return value end -function tilde_assume!(context::ConditionContext, right, vn, inds, vi) - if haskey(context, vn) - # Extract value. - value = if inds isa Tuple{} - getfield(context.values, getsym(vn)) - else - _getindex(getfield(context.values, getsym(vn)), inds) - end - - # Should we even do this? - if haskey(vi, vn) - vi[vn] = value - end - - tilde_observe!(context.context, right, value, vn, inds, vi) - else - value = tilde_assume!(context.context, right, vn, inds, vi) - end - - return value -end - # observe """ tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) @@ -220,6 +228,14 @@ function tilde_observe(context::PrefixContext, right, left, vi) return tilde_observe(context.context, right, left, vi) end +# `ConditionContext` +function tilde_observe(context::ConditionContext, right, left, vname, vi) + return tilde_observe(context.context, right, left, vname, vi) +end +function tilde_observe(context::ConditionContext, right, left, vi) + return tilde_observe(context.context, right, left, vi) +end + """ tilde_observe!(context, right, left, vname, vinds, vi) @@ -248,10 +264,6 @@ function tilde_observe!(context, right, left, vi) return left end -function tilde_observe!(context::ConditionContext, right, left, vi) - return tilde_observe!(context.context, right, left, vi) -end - function assume(rng, spl::Sampler, dist) return error("DynamicPPL.assume: unmanaged inference algorithm: $(typeof(spl))") end @@ -440,6 +452,37 @@ function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, ) end +# `ConditionContext` +function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi) + if !haskey(context, vn) + # Not conditioned on => defer to child context. + return dot_tilde_assume(context.context, right, left, vn, inds, vi) + end + + # Extract value. + # FIXME: Handle the case where `vn` is actually `AbstractArray{<:VarName}`. + value = _getvalue(context.values, vn) + + # Update the value in `vi`. + # TODO: Should we even do this? + if haskey(vi, vn) + vi[vn] = value + end + + logp = dot_tilde_observe(context.context, right, left, value, vn, inds, vi) + return value, logp +end + +function dot_tilde_assume(rng, context::ConditionContext, sampler, right, left, vn, inds, vi) + if !haskey(context, vn) + # Defer to child context. + return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) + end + + # If we're conditioning, then we just fall back to non-rng impl. + return dot_tilde_assume(context, right, left, vn, inds, vi) +end + """ dot_tilde_assume!(context, right, left, vn, inds, vi) @@ -454,28 +497,6 @@ function dot_tilde_assume!(context, right, left, vn, inds, vi) return value end -function dot_tilde_assume!(context::ConditionContext, right, left, vn, inds, vi) - if haskey(context, vn) - # Extract value. - value = if inds isa Tuple{} - getfield(context.values, sym) - else - _getindex(getfield(context.values, sym), inds) - end - - # Should we even do this? - if haskey(vi, vn) - vi[vn] = value - end - - dot_tilde_observe!(context.context, right, value, vn, inds, vi) - else - value = dot_tilde_assume!(context.context, right, left, vn, inds, vi) - end - - return value -end - # `dot_assume` function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi From f9cdfa9910226a2ad790654f4ddeb390e531e95e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 16:21:42 +0100 Subject: [PATCH 16/86] added contextual_isassumption for PrefixContext --- src/compiler.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 9eca8abae..6bcd68c7a 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -36,6 +36,9 @@ end contextual_isassumption(context::AbstractContext, vn) = false contextual_isassumption(context::ConditionContext, vn::VarName) = !(haskey(context, vn)) +function contextual_isassumption(context::PrefixContext, vn::VarName) + return contextual_isassumption(context.context, prefix(context, vn)) +end # failsafe: a literal is never an assumption isassumption(expr) = :(false) From 835a41ead81a2387f1113fb148d42a7851f2cd1d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 16:24:28 +0100 Subject: [PATCH 17/86] implemented contextual_isassumption for all contexts --- src/compiler.jl | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 6bcd68c7a..3ddbb6340 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -35,10 +35,19 @@ function isassumption(expr::Union{Symbol,Expr}) end contextual_isassumption(context::AbstractContext, vn) = false -contextual_isassumption(context::ConditionContext, vn::VarName) = !(haskey(context, vn)) -function contextual_isassumption(context::PrefixContext, vn::VarName) +function contextual_isassumption(context::ConditionContext, vn) + # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`. + return !(haskey(context, vn)) || contextual_isassumption(context, vn) +end +function contextual_isassumption(context::PrefixContext, vn) return contextual_isassumption(context.context, prefix(context, vn)) end +function contextual_isassumption(context::MiniBatchContext, vn) + return contextual_isassumption(context.context, vn) +end +function contextual_isassumption(context::PointwiseLikelihoodContext, vn) + return contextual_isassumption(context.context, vn) +end # failsafe: a literal is never an assumption isassumption(expr) = :(false) From 5d110d5cb4bb885d44ed8b8d33574f41aae22a4b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 22:32:54 +0100 Subject: [PATCH 18/86] improved the way ConditionContext works signficantly --- src/compiler.jl | 72 ++++++++++++++++++++++++---------- src/context_implementations.jl | 51 ++---------------------- src/contexts.jl | 6 +++ src/loglikelihoods.jl | 8 ++++ 4 files changed, 70 insertions(+), 67 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 3ddbb6340..d640ef1d7 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -20,24 +20,52 @@ function isassumption(expr::Union{Symbol,Expr}) return quote let $vn = $(varname(expr)) - # This branch should compile nicely in all cases except for partial missing data - # For example, when `expr` is `:(x[i])` and `x isa Vector{Union{Missing, Float64}}` - if !$(DynamicPPL.inargnames)($vn, __model__) || - $(DynamicPPL.inmissings)($vn, __model__) || - $(DynamicPPL.contextual_isassumption)(__context__, $vn) - true + if $(DynamicPPL.contextual_isassumption)(__context__, $vn) + # Considered an assumption by `__context__` which means either: + # 1. We hit the default implementation, e.g. using `DefaultContext`, + # which in turn means that we haven't considered if it's one of + # the model arguments, hence we need to check this. + # 2. We are working with a `ConditionContext` _and_ it's NOT in the model arguments, + # i.e. we're trying to condition one of the latent variables. + # In this case, the below will return `true` since the first branch + # will be hit. + # 3. We are working with a `ConditionContext` _and_ it's in the model arguments, + # i.e. we're trying to override the value. This is currently NOT supported. + # TODO: Support by adding context to model, and use `model.args` + # as the default conditioning. Then we no longer need to check `inargnames` + # since it will all be handled by `contextual_isassumption`. + if !($(DynamicPPL.inargnames)($vn, __model__)) + true + else + $expr === missing + end else - # Evaluate the LHS - $(maybe_view(expr)) === missing + false end end end end -contextual_isassumption(context::AbstractContext, vn) = false +""" + contextual_isassumption(context, vn) + +Return `true` if `vn` is considered an assumption by `context`. + +The default implementation for `AbstractContext` always returns `true`. +""" +contextual_isassumption(context::AbstractContext, vn) = true function contextual_isassumption(context::ConditionContext, vn) # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`. - return !(haskey(context, vn)) || contextual_isassumption(context, vn) + + # We have either of the following cases: + # 1. `context` considers `vn` as an observation, i.e. it has `vn` as a key, + # which means we have a value to replace with and we don't need to recurse. + # 2. One of the decendant contexts consider it as an observation, i.e. + # `contextual_isassumption` evaluates to `false`. + # The below then evaluates to `!(false || true) === false`. + # 3. Neither `context` nor any of it's decendants considers it an observation, + # in which case the below evaluates to `!(false || false) === true`. + return !(haskey(context, vn) || !contextual_isassumption(context.context, vn)) end function contextual_isassumption(context::PrefixContext, vn) return contextual_isassumption(context.context, prefix(context, vn)) @@ -45,9 +73,6 @@ end function contextual_isassumption(context::MiniBatchContext, vn) return contextual_isassumption(context.context, vn) end -function contextual_isassumption(context::PointwiseLikelihoodContext, vn) - return contextual_isassumption(context.context, vn) -end # failsafe: a literal is never an assumption isassumption(expr) = :(false) @@ -348,6 +373,11 @@ function generate_tilde(left, right) __varinfo__, ) else + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. + if !$(DynamicPPL.inargnames)($vn, __model__) + $left = $(DynamicPPL.getvalue)(__context__, $vn) + end + $(DynamicPPL.tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), @@ -392,6 +422,11 @@ function generate_dot_tilde(left, right) __varinfo__, ) else + # If `vn` is not in `argnames`, we need to make sure that the variable is defined. + if !$(DynamicPPL.inargnames)($vn, __model__) + $left .= $(DynamicPPL.getvalue)(__context__, $vn) + end + $(DynamicPPL.dot_tilde_observe!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), @@ -453,14 +488,11 @@ function build_output(modelinfo, linenumbernode) modeldef[:body] = MacroTools.@q begin $(linenumbernode) $evaluator = $(MacroTools.combinedef(evaluatordef)) - return $(DynamicPPL.condition)( - $(DynamicPPL.Model)( - $(QuoteNode(modeldef[:name])), - $evaluator, - $allargs_namedtuple, - $defaults_namedtuple, - ), + return $(DynamicPPL.Model)( + $(QuoteNode(modeldef[:name])), + $evaluator, $allargs_namedtuple, + $defaults_namedtuple, ) end diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 54f6d5965..e632a4c95 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -129,32 +129,11 @@ end # `ConditionContext` function tilde_assume(context::ConditionContext, right, vn, inds, vi) - if !haskey(context, vn) - # Not conditioned on => defer to child context. - return tilde_assume(context.context, right, vn, inds, vi) - end - - # Extract value. - value = _getvalue(context.values, vn) - - # Update the value in `vi`. - # TODO: Should we even do this? - if haskey(vi, vn) - vi[vn] = value - end - - logp = tilde_observe(context.context, right, value, vn, inds, vi) - return value, logp + return tilde_assume(context.context, right, vn, inds, vi) end function tilde_assume(rng, context::ConditionContext, sampler, right, vn, inds, vi) - if haskey(context, vn) - # Defer to child context. - return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) - end - - # If we're conditioning, then we just fall back to non-rng impl. - return tilde_assume(context, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) end """ @@ -454,33 +433,11 @@ end # `ConditionContext` function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi) - if !haskey(context, vn) - # Not conditioned on => defer to child context. - return dot_tilde_assume(context.context, right, left, vn, inds, vi) - end - - # Extract value. - # FIXME: Handle the case where `vn` is actually `AbstractArray{<:VarName}`. - value = _getvalue(context.values, vn) - - # Update the value in `vi`. - # TODO: Should we even do this? - if haskey(vi, vn) - vi[vn] = value - end - - logp = dot_tilde_observe(context.context, right, left, value, vn, inds, vi) - return value, logp + return dot_tilde_assume(context.context, right, left, vn, inds, vi) end function dot_tilde_assume(rng, context::ConditionContext, sampler, right, left, vn, inds, vi) - if !haskey(context, vn) - # Defer to child context. - return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) - end - - # If we're conditioning, then we just fall back to non-rng impl. - return dot_tilde_assume(context, right, left, vn, inds, vi) + return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) end """ diff --git a/src/contexts.jl b/src/contexts.jl index 96bd43cd7..f58242b5c 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -153,6 +153,12 @@ function ConditionContext( return ConditionContext(merge(context.values, values), context.context) end +function getvalue(context::ConditionContext, vn) + haskey(context, vn) ? _getvalue(context.values, vn) : getvalue(context.context, vn) +end +getvalue(context::AbstractContext, vn) = getvalue(context.context, vn) +getvalue(context::PrefixContext, vn) = getvalue(context.context, prefix(context, vn)) + function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 6c66e4ec4..2203eda13 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -13,6 +13,14 @@ function PointwiseLikelihoodContext( ) end +function contextual_isassumption(context::PointwiseLikelihoodContext, vn) + return contextual_isassumption(context.context, vn) +end + +function contextual_isobservation(context::PointwiseLikelihoodContext, vn) + return contextual_isobservation(context.context, vn) +end + function Base.push!( context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, vn::VarName, From 560ca83ecaf57eaa2b2cdb3d68b8b2f39e5854bb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 22:35:34 +0100 Subject: [PATCH 19/86] forgot impl of dot_tilde_observe for ConditionContext --- src/context_implementations.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e632a4c95..e464b69de 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -646,6 +646,11 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) return dot_tilde_observe(context.context, right, left, vi) end +# `ConditionContext` +function dot_tilde_observe(context::ConditionContext, right, left, vi) + return dot_tilde_observe(context.context, right, left, vi) +end + """ dot_tilde_observe!(context, right, left, vname, vinds, vi) @@ -672,9 +677,6 @@ function dot_tilde_observe!(context, right, left, vi) acclogp!(vi, logp) return left end -function dot_tilde_observe!(context::ConditionContext, right, left, vi) - return dot_tilde_observe!(context.context, right, left, vi) -end # Falls back to non-sampler definition. function dot_observe(::AbstractSampler, dist, value, vi) From 5635c3b530a2b3085b4787c6e491950e2ade722e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 21 Jul 2021 22:39:24 +0100 Subject: [PATCH 20/86] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 4 +++- src/contexts.jl | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e464b69de..a0b3c8e77 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -436,7 +436,9 @@ function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi) return dot_tilde_assume(context.context, right, left, vn, inds, vi) end -function dot_tilde_assume(rng, context::ConditionContext, sampler, right, left, vn, inds, vi) +function dot_tilde_assume( + rng, context::ConditionContext, sampler, right, left, vn, inds, vi +) return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) end diff --git a/src/contexts.jl b/src/contexts.jl index f58242b5c..f48a5187c 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -154,7 +154,11 @@ function ConditionContext( end function getvalue(context::ConditionContext, vn) - haskey(context, vn) ? _getvalue(context.values, vn) : getvalue(context.context, vn) + return if haskey(context, vn) + _getvalue(context.values, vn) + else + getvalue(context.context, vn) + end end getvalue(context::AbstractContext, vn) = getvalue(context.context, vn) getvalue(context::PrefixContext, vn) = getvalue(context.context, prefix(context, vn)) From e78dc6546c08d960d5ae3e6a95ca3d0e5f4f056f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 22 Jul 2021 02:38:40 +0100 Subject: [PATCH 21/86] overload _evaluate rather than the model-call directly --- src/contextual_model.jl | 6 ++---- src/model.jl | 2 +- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/src/contextual_model.jl b/src/contextual_model.jl index 382c45034..4b843ed0e 100644 --- a/src/contextual_model.jl +++ b/src/contextual_model.jl @@ -9,12 +9,10 @@ end # TODO: What do we do for other contexts? Could handle this in general if we had a # notion of wrapper-, primitive-context, etc. -function (cmodel::ContextualModel{<:ConditionContext})( - varinfo::AbstractVarInfo, context::AbstractContext -) +function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as # `ConditionContext` child. - return cmodel.model(varinfo, ConditionContext(cmodel.context.values, context)) + return _evaluate(cmodel.model, varinfo, ConditionContext(cmodel.context.values, context)) end condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) diff --git a/src/model.jl b/src/model.jl index e5b9beb24..7b5d8918a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -93,7 +93,7 @@ function (model::AbstractModel)( end (model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context) -function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) +function (model::AbstractModel)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 return evaluate_threadunsafe(model, varinfo, context) else From f9e753ac2ede7d511b46c4bfb833b9dbb19a1481 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 23 Jul 2021 17:59:38 +0100 Subject: [PATCH 22/86] initial work --- src/context_implementations.jl | 104 +++++++++++++++------------------ src/contexts.jl | 13 +++++ 2 files changed, 59 insertions(+), 58 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 3d492f5b1..94f5f3014 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -35,12 +35,25 @@ function tilde_assume(context::SamplingContext, right, vn, inds, vi) end # Leaf contexts -tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, vi) -function tilde_assume( - rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi -) +function tilde_assume(context::AbstractContext, args...) + return tilde_assume(NodeTrait(tilde_assume, context), context, args...) +end +function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vinds, vi) + return assume(right, vn, vi) +end +function tilde_assume(::IsParent, context::AbstractContext, args...) + return tilde_assume(childcontext(context), args...) +end + +function tilde_assume(rng, context::AbstractContext, args...) + return tilde_assume(NodeTrait(tilde_assume, context), rng, context, args...) +end +function tilde_assume(::IsLeaf, rng, context::AbstractContext, sampler, right, vn, vinds, vi) return assume(rng, sampler, right, vn, vi) end +function tilde_assume(::IsParent, rng, context::AbstractContext, args...) + return tilde_assume(rng, childcontext(context), args...) +end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) @@ -64,12 +77,6 @@ function tilde_assume( end return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) end -function tilde_assume(::PriorContext, right, vn, inds, vi) - return assume(right, vn, vi) -end -function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) - return assume(rng, sampler, right, vn, vi) -end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) @@ -102,18 +109,9 @@ function tilde_assume( return assume(rng, sampler, NoDist(right), vn, vi) end -function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) - return tilde_assume(context.context, right, vn, inds, vi) -end - -function tilde_assume(rng, context::MiniBatchContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) -end - function tilde_assume(context::PrefixContext, right, vn, inds, vi) return tilde_assume(context.context, right, prefix(context, vn), inds, vi) end - function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) end @@ -162,16 +160,16 @@ function tilde_observe(context::SamplingContext, right, left, vi) end # Leaf contexts -tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) -function tilde_observe(::DefaultContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +function tilde_observe(context::AbstractContext, args...) + return tilde_observe(NodeTrait(tilde_observe, context), context, args...) +end +tilde_observe(::IsLeaf, context::AbstractContext, args...) = observe(args...) +function tilde_observe(::IsParent, context::AbstractContext, args...) + return tilde_observe(childcontext(context), args...) end + tilde_observe(::PriorContext, right, left, vi) = 0 tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) -function tilde_observe(::LikelihoodContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) -end # `MiniBatchContext` function tilde_observe(context::MiniBatchContext, right, left, vi) @@ -185,9 +183,6 @@ end function tilde_observe(context::PrefixContext, right, left, vname, vi) return tilde_observe(context.context, right, left, prefix(context, vname), vi) end -function tilde_observe(context::PrefixContext, right, left, vi) - return tilde_observe(context.context, right, left, vi) -end """ tilde_observe!(context, right, left, vname, vinds, vi) @@ -291,9 +286,26 @@ function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) end # `DefaultContext` -function dot_tilde_assume(::DefaultContext, right, left, vns, inds, vi) +function dot_tilde_assume(context::AbstractContext, args...) + return dot_tilde_assume(NodeTrait(dot_tilde_assume, context), context, args...) +end +function dot_tilde_assume(rng, context::AbstractContext, args...) + return dot_tilde_assume(rng, NodeTrait(dot_tilde_assume, context), context, args...) +end + +function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, inds, vi) return dot_assume(right, left, vns, vi) end +function dot_tilde_assume(::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, inds, vi) + return dot_assume(rng, sampler, right, vns, left, vi) +end + +function dot_tilde_assume(::IsParent, context::AbstractContext, args...) + return dot_tilde_assume(childcontext(context), args...) +end +function dot_tilde_assume(rng, ::IsParent, context::AbstractContext, args...) + return dot_tilde_assume(rng, childcontext(context), args...) +end function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(rng, sampler, right, vns, left, vi) @@ -374,25 +386,6 @@ function dot_tilde_assume( dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) end end -function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) - return dot_assume(right, left, vn, vi) -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi -) - return dot_assume(rng, sampler, right, vn, left, vi) -end - -# `MiniBatchContext` -function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, left, vn, inds, vi) -end - -function dot_tilde_assume( - rng, context::MiniBatchContext, sampler, right, left, vn, inds, vi -) - return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) -end # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) @@ -588,18 +581,13 @@ function dot_tilde_observe(context::SamplingContext, right, left, vi) end # Leaf contexts -dot_tilde_observe(::DefaultContext, right, left, vi) = dot_observe(right, left, vi) -function dot_tilde_observe(::DefaultContext, sampler, right, left, vi) - return dot_observe(sampler, right, left, vi) +dot_tilde_observe(::IsLeaf, ::AbstractContext, args...) = dot_observe(args...) +function dot_tilde_observe(::IsParent, context::AbstractContext, args...) + return dot_tilde_observe(childcontext(context), args...) end + dot_tilde_observe(::PriorContext, right, left, vi) = 0 dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -function dot_tilde_observe(context::LikelihoodContext, right, left, vi) - return dot_observe(right, left, vi) -end -function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) - return dot_observe(sampler, right, left, vi) -end # `MiniBatchContext` function dot_tilde_observe(context::MiniBatchContext, right, left, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 05ad8df0d..b27258932 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,3 +1,11 @@ +# Fallback traits +# TODO: Should this instead be `NoChildren()`, `HasChild()`, etc. so we allow plural too, e.g. `HasChildren()`? +struct IsLeaf end +struct IsParent end + +NodeTrait(::Any, context) = NodeTrait(context) + +# Contexts """ SamplingContext(rng, sampler, context) @@ -11,6 +19,7 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte sampler::S context::C end +NodeTrait(context::SamplingContext) = IsParent() """ struct DefaultContext <: AbstractContext end @@ -19,6 +28,7 @@ The `DefaultContext` is used by default to compute log the joint probability of and parameters when running the model. """ struct DefaultContext <: AbstractContext end +NodeTrait(context::DefaultContext) = IsLeaf() """ struct PriorContext{Tvars} <: AbstractContext @@ -32,6 +42,7 @@ struct PriorContext{Tvars} <: AbstractContext vars::Tvars end PriorContext() = PriorContext(nothing) +NodeTrait(context::PriorContext) = IsLeaf() """ struct LikelihoodContext{Tvars} <: AbstractContext @@ -46,6 +57,7 @@ struct LikelihoodContext{Tvars} <: AbstractContext vars::Tvars end LikelihoodContext() = LikelihoodContext(nothing) +NodeTrait(context::LikelihoodContext) = IsLeaf() """ struct MiniBatchContext{Tctx, T} <: AbstractContext @@ -66,6 +78,7 @@ end function MiniBatchContext(context=DefaultContext(); batch_size, npoints) return MiniBatchContext(context, npoints / batch_size) end +NodeTrait(context::MiniBatchContext) = IsParent() """ PrefixContext{Prefix}(context) From f090ff50ec04d9b8eaa0b84efe7eb4479e9ae142 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 05:15:18 +0100 Subject: [PATCH 23/86] added some missing implementations --- src/contexts.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index b27258932..78759de61 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -3,7 +3,13 @@ struct IsLeaf end struct IsParent end -NodeTrait(::Any, context) = NodeTrait(context) +""" + NodeTrait(context) + NodeTrait(f, context) + +Specifies the role of `context` in the context-tree. +""" +NodeTrait(_, context) = NodeTrait(context) # Contexts """ @@ -20,6 +26,7 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte context::C end NodeTrait(context::SamplingContext) = IsParent() +childcontext(context::SamplingContext) = context.context """ struct DefaultContext <: AbstractContext end @@ -79,6 +86,7 @@ function MiniBatchContext(context=DefaultContext(); batch_size, npoints) return MiniBatchContext(context, npoints / batch_size) end NodeTrait(context::MiniBatchContext) = IsParent() +childcontext(context::MiniBatchContext) = context.context """ PrefixContext{Prefix}(context) @@ -98,6 +106,9 @@ function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} return PrefixContext{Prefix,typeof(context)}(context) end +NodeTrait(context::PrefixContext) = IsParent() +childcontext(context::PrefixContext) = context.context + const PREFIX_SEPARATOR = Symbol(".") function PrefixContext{PrefixInner}( From 4e36c55fb7ae028177763dca51cf8dff68d98601 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 05:23:22 +0100 Subject: [PATCH 24/86] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 94f5f3014..0d35aba45 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -48,7 +48,9 @@ end function tilde_assume(rng, context::AbstractContext, args...) return tilde_assume(NodeTrait(tilde_assume, context), rng, context, args...) end -function tilde_assume(::IsLeaf, rng, context::AbstractContext, sampler, right, vn, vinds, vi) +function tilde_assume( + ::IsLeaf, rng, context::AbstractContext, sampler, right, vn, vinds, vi +) return assume(rng, sampler, right, vn, vi) end function tilde_assume(::IsParent, rng, context::AbstractContext, args...) @@ -296,7 +298,9 @@ end function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, inds, vi) return dot_assume(right, left, vns, vi) end -function dot_tilde_assume(::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, inds, vi) +function dot_tilde_assume( + ::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, inds, vi +) return dot_assume(rng, sampler, right, vns, left, vi) end From 1d3b11e9caa2d2e2184940b42eadf6a176b7ccb1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 06:51:20 +0100 Subject: [PATCH 25/86] drop now unneceesary impls for tilds for ConditionContext --- src/context_implementations.jl | 33 --------------------------------- src/contexts.jl | 3 +++ 2 files changed, 3 insertions(+), 33 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 8e2b02851..ce3f89739 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -127,15 +127,6 @@ function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) end -# `ConditionContext` -function tilde_assume(context::ConditionContext, right, vn, inds, vi) - return tilde_assume(context.context, right, vn, inds, vi) -end - -function tilde_assume(rng, context::ConditionContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) -end - """ tilde_assume!(context, right, vn, inds, vi) @@ -204,14 +195,6 @@ function tilde_observe(context::PrefixContext, right, left, vname, vi) return tilde_observe(context.context, right, left, prefix(context, vname), vi) end -# `ConditionContext` -function tilde_observe(context::ConditionContext, right, left, vname, vi) - return tilde_observe(context.context, right, left, vname, vi) -end -function tilde_observe(context::ConditionContext, right, left, vi) - return tilde_observe(context.context, right, left, vi) -end - """ tilde_observe!(context, right, left, vname, vinds, vi) @@ -428,17 +411,6 @@ function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, ) end -# `ConditionContext` -function dot_tilde_assume(context::ConditionContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, left, vn, inds, vi) -end - -function dot_tilde_assume( - rng, context::ConditionContext, sampler, right, left, vn, inds, vi -) - return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) -end - """ dot_tilde_assume!(context, right, left, vn, inds, vi) @@ -640,11 +612,6 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) return dot_tilde_observe(context.context, right, left, vi) end -# `ConditionContext` -function dot_tilde_observe(context::ConditionContext, right, left, vi) - return dot_tilde_observe(context.context, right, left, vi) -end - """ dot_tilde_observe!(context, right, left, vname, vinds, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 98790a785..947f31538 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -212,3 +212,6 @@ function decondition(context::ConditionContext, sym, syms...) ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... ) end + +NodeTrait(context::ConditionContext) = IsParent() +childcontext(context::ConditionContext) = context.context From e1a7d38d33b91cfa0acb1b36e31d6fda0e50579f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 06:52:02 +0100 Subject: [PATCH 26/86] formatting --- src/contextual_model.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contextual_model.jl b/src/contextual_model.jl index 4b843ed0e..793491720 100644 --- a/src/contextual_model.jl +++ b/src/contextual_model.jl @@ -12,7 +12,9 @@ end function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as # `ConditionContext` child. - return _evaluate(cmodel.model, varinfo, ConditionContext(cmodel.context.values, context)) + return _evaluate( + cmodel.model, varinfo, ConditionContext(cmodel.context.values, context) + ) end condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) From b42c34f346c69c4b5ab615ef0306214e74e1a29b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:03:35 +0100 Subject: [PATCH 27/86] address issues using traits --- src/compiler.jl | 15 +++++++++------ src/contexts.jl | 34 ++++++++++++++++++++++++++-------- 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index d640ef1d7..eeced2701 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -53,7 +53,13 @@ Return `true` if `vn` is considered an assumption by `context`. The default implementation for `AbstractContext` always returns `true`. """ -contextual_isassumption(context::AbstractContext, vn) = true +contextual_isassumption(::IsLeaf, context, vn) = true +function contextual_isassumption(::IsParent, context, vn) + return contextual_isassumption(childcontext(context), vn) +end +function contextual_isassumption(context::AbstractContext, vn) + return contextual_isassumption(NodeTrait(context), context, vn) +end function contextual_isassumption(context::ConditionContext, vn) # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`. @@ -65,13 +71,10 @@ function contextual_isassumption(context::ConditionContext, vn) # The below then evaluates to `!(false || true) === false`. # 3. Neither `context` nor any of it's decendants considers it an observation, # in which case the below evaluates to `!(false || false) === true`. - return !(haskey(context, vn) || !contextual_isassumption(context.context, vn)) + return !(haskey(context, vn) || !contextual_isassumption(childcontext(context), vn)) end function contextual_isassumption(context::PrefixContext, vn) - return contextual_isassumption(context.context, prefix(context, vn)) -end -function contextual_isassumption(context::MiniBatchContext, vn) - return contextual_isassumption(context.context, vn) + return contextual_isassumption(childcontext(context), prefix(context, vn)) end # failsafe: a literal is never an assumption diff --git a/src/contexts.jl b/src/contexts.jl index 947f31538..38cbe213e 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -174,19 +174,34 @@ function ConditionContext( ) where {Names} # Note that this potentially overrides values from `context`, thus giving # precedence to the outmost `ConditionContext`. - return ConditionContext(merge(context.values, values), context.context) + return ConditionContext(merge(context.values, values), childcontext(context)) end +""" + getvalue(context, vn) + +Return the value of the parameter corresponding to `vn` from `context`. +If `context` does not contain the value for `vn`, then `nothing` is returned, +e.g. [`DefaultContext`](@ref) will always return `nothing`. +""" +getvalue(::IsLeaf, context, vn) = nothing +getvalue(::IsParent, context, vn) = getvalue(childcontext(context), vn) +getvalue(context::AbstractContext, vn) = getvalue(NodeTrait(getvalue, context), context, vn) +getvalue(context::PrefixContext, vn) = getvalue(childcontext(context), prefix(context, vn)) function getvalue(context::ConditionContext, vn) return if haskey(context, vn) _getvalue(context.values, vn) else - getvalue(context.context, vn) + getvalue(childcontext(context), vn) end end -getvalue(context::AbstractContext, vn) = getvalue(context.context, vn) -getvalue(context::PrefixContext, vn) = getvalue(context.context, prefix(context, vn)) +# General implementations of `haskey`. +Base.haskey(::IsLeaf, context, vn) = false +Base.haskey(::IsParent, context, vn) = Base.haskey(childcontext(context), vn) +Base.haskey(context::AbstractContext, vn) = Base.haskey(NodeTrait(context), context, vn) + +# Specific to `ConditionContext`. function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars @@ -201,15 +216,18 @@ end # TODO: Can we maybe do this in a better way? # When no second argument is given, we remove _all_ conditioned variables. -# TODO: Should we remove this and just return `context.context`? +# TODO: Should we remove this and just return `childcontext(context)`? # That will work better if `Model` becomes like `ContextualModel`. -decondition(context::ConditionContext) = ConditionContext(NamedTuple(), context.context) +function decondition(context::ConditionContext) + return ConditionContext(NamedTuple(), childcontext(context)) +end function decondition(context::ConditionContext, sym) - return ConditionContext(BangBang.delete!!(context.values, sym), context.context) + return ConditionContext(BangBang.delete!!(context.values, sym), childcontext(context)) end function decondition(context::ConditionContext, sym, syms...) return decondition( - ConditionContext(BangBang.delete!!(context.values, sym), context.context), syms... + ConditionContext(BangBang.delete!!(context.values, sym), childcontext(context)), + syms..., ) end From c7c60e6a435ef89fe7bebd39bfe76df08f0d8ea9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:23:59 +0100 Subject: [PATCH 28/86] added rewrap for contexts --- src/contexts.jl | 62 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 10 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 38cbe213e..2f74847e6 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -8,9 +8,44 @@ struct IsParent end NodeTrait(f, context) Specifies the role of `context` in the context-tree. + +The officially supported traits are: +- `IsLeaf`: `context` does not have any decendants. +- `IsParent`: `context` has a child context to which we often defer. + Expects the following methods to be implemented: + - [`childcontext`](@ref) + - [`rewrap`](@ref) """ NodeTrait(_, context) = NodeTrait(context) +""" + childcontext(context) + +Return the descendant context of `context`. +""" +childcontext + +""" + rewrap(parent::AbstractContext, child::AbstractContext) + +Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), +effectively updating the child context. + +# Examples +```jldoctest +julia> ctx = SamplingContext(); + +julia> DynamicPPL.childcontext(ctx) +DefaultContext() + +julia> ctx_prior = DynamicPPL.rewrap(ctx, PriorContext()); # only compute the logprior + +julia> DynamicPPL.childcontext(ctx_prior) +PriorContext{Nothing}(nothing) +``` +""" +rewrap + # Contexts """ SamplingContext(rng, sampler, context) @@ -25,8 +60,14 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte sampler::S context::C end +SamplingContext(sampler, context) = SamplingContext(Random.GLOBAL_RNG, sampler, context) +SamplingContext(context::AbstractContext) = SamplingContext(SampleFromPrior(), context) +SamplingContext(sampler::AbstractSampler) = SamplingContext(sampler, DefaultContext()) +SamplingContext() = SamplingContext(SampleFromPrior()) + NodeTrait(context::SamplingContext) = IsParent() childcontext(context::SamplingContext) = context.context +rewrap(parent::SamplingContext, child) = SamplingContext(parent.rng, parent.sampler, child) """ struct DefaultContext <: AbstractContext end @@ -87,6 +128,7 @@ function MiniBatchContext(context=DefaultContext(); batch_size, npoints) end NodeTrait(context::MiniBatchContext) = IsParent() childcontext(context::MiniBatchContext) = context.context +rewrap(parent::MiniBatchContext, child) = MiniBatchContext(child, parent.loglike_scalar) """ PrefixContext{Prefix}(context) @@ -108,6 +150,7 @@ end NodeTrait(context::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context +rewrap(parent::PrefixContext{Prefix}, child) where {Prefix} = PrefixContext{Prefix}(child) const PREFIX_SEPARATOR = Symbol(".") @@ -156,9 +199,7 @@ end return :(NamedTuple{$names_expr}($values_expr)) end -function ConditionContext(context::ConditionContext, child_context::AbstractContext) - return ConditionContext(context.values, child_context) -end +ConditionContext(; values...) = ConditionContext((; values...)) function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end @@ -177,6 +218,10 @@ function ConditionContext( return ConditionContext(merge(context.values, values), childcontext(context)) end +NodeTrait(context::ConditionContext) = IsParent() +childcontext(context::ConditionContext) = context.context +rewrap(parent::ConditionContext, child) = ConditionContext(parent.values, child) + """ getvalue(context, vn) @@ -214,10 +259,10 @@ function Base.haskey( return sym in vars end -# TODO: Can we maybe do this in a better way? -# When no second argument is given, we remove _all_ conditioned variables. -# TODO: Should we remove this and just return `childcontext(context)`? -# That will work better if `Model` becomes like `ContextualModel`. +# Recursively `decondition` the context. +decondition(::IsLeaf, context) = context +decondition(::IsParent, context) = rewrap(context, decondition(childcontext(context))) +decondition(context) = decondition(NodeTrait(context), context) function decondition(context::ConditionContext) return ConditionContext(NamedTuple(), childcontext(context)) end @@ -230,6 +275,3 @@ function decondition(context::ConditionContext, sym, syms...) syms..., ) end - -NodeTrait(context::ConditionContext) = IsParent() -childcontext(context::ConditionContext) = context.context From be67807948b72b0ef47a64d071835ed538384b0d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:31:47 +0100 Subject: [PATCH 29/86] do decondition properly --- src/contexts.jl | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 2f74847e6..648039694 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -260,18 +260,25 @@ function Base.haskey( end # Recursively `decondition` the context. -decondition(::IsLeaf, context) = context -decondition(::IsParent, context) = rewrap(context, decondition(childcontext(context))) -decondition(context) = decondition(NodeTrait(context), context) +decondition(::IsLeaf, context, args...) = context +function decondition(::IsParent, context, args...) + return rewrap(context, decondition(childcontext(context), args...)) +end +decondition(context, args...) = decondition(NodeTrait(context), context, args...) function decondition(context::ConditionContext) - return ConditionContext(NamedTuple(), childcontext(context)) + return ConditionContext(NamedTuple(), decondition(childcontext(context))) end function decondition(context::ConditionContext, sym) - return ConditionContext(BangBang.delete!!(context.values, sym), childcontext(context)) + return ConditionContext( + BangBang.delete!!(context.values, sym), childcontext(context, sym) + ) end function decondition(context::ConditionContext, sym, syms...) return decondition( - ConditionContext(BangBang.delete!!(context.values, sym), childcontext(context)), + ConditionContext( + BangBang.delete!!(context.values, sym), + decondition(childcontext(context), syms...), + ), syms..., ) end From 0468297876eaf773ef823adf665a510fec226a74 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:46:53 +0100 Subject: [PATCH 30/86] added some examples and decondition now removes ConditionContext --- src/contexts.jl | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 648039694..7857c1f92 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -199,7 +199,7 @@ end return :(NamedTuple{$names_expr}($values_expr)) end -ConditionContext(; values...) = ConditionContext((; values...)) +ConditionContext(context=DefaultContext(); values...) = ConditionContext((; values...), context) function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end @@ -259,14 +259,46 @@ function Base.haskey( return sym in vars end -# Recursively `decondition` the context. +""" + context([context::AbstractContext,] values::NamedTuple) + context([context::AbstractContext]; values...) + +Return `ConditionContext` with `values` and wrapping `context`. +""" +condition(context=DefaultContext(); values...) = ConditionContext(context; values...) + +""" + decondition(context::AbstractContext, syms...) + +Return `context` but with `syms` no longer conditioned on. + +Note that this recursively traverses contexts, deconditioning all along the way. + +# Examples +```jldoctest +julia> ctx = DefaultContext(); + +julia> decondition(ctx) === ctx # this is a no-op +true + +julia> ctx = ConditionContext(x = 1.0); + +julia> decondition(ctx) +DefaultContext() + +julia> ctx_nested = ConditionContext(SamplingContext(ConditionContext(y=2.0)), x=1.0); + +julia> decondition(ctx_nested) +SamplingContext{SampleFromPrior, DefaultContext, Random._GLOBAL_RNG}(Random._GLOBAL_RNG(), SampleFromPrior(), DefaultContext()) +``` +""" decondition(::IsLeaf, context, args...) = context function decondition(::IsParent, context, args...) return rewrap(context, decondition(childcontext(context), args...)) end decondition(context, args...) = decondition(NodeTrait(context), context, args...) function decondition(context::ConditionContext) - return ConditionContext(NamedTuple(), decondition(childcontext(context))) + return decondition(childcontext(context)) end function decondition(context::ConditionContext, sym) return ConditionContext( From 80e3d5fd48b3c46f4c11749ed380e51d08ed5fa0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:56:37 +0100 Subject: [PATCH 31/86] improved condition and decondition a bit further --- src/contexts.jl | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 7857c1f92..7bb437867 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -199,7 +199,9 @@ end return :(NamedTuple{$names_expr}($values_expr)) end -ConditionContext(context=DefaultContext(); values...) = ConditionContext((; values...), context) +function ConditionContext(context=DefaultContext(); values...) + return ConditionContext((; values...), context) +end function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end @@ -263,9 +265,18 @@ end context([context::AbstractContext,] values::NamedTuple) context([context::AbstractContext]; values...) -Return `ConditionContext` with `values` and wrapping `context`. +Return `ConditionContext` with `values` and `context` if `values` is non-empty, +otherwise return `context` which is [`DefaultContext`](@ref) by default. + +See also: [`decondition`](@ref) """ -condition(context=DefaultContext(); values...) = ConditionContext(context; values...) +condition() = decondition(ConditionContext()) +condition(values::NamedTuple) = condition(DefaultContext(), values) +condition(context::AbstractContext, values::NamedTuple{()}) = context +condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context) +function condition(context::AbstractContext=DefaultContext(); values...) + return ConditionContext(context; values...) +end """ decondition(context::AbstractContext, syms...) @@ -274,6 +285,8 @@ Return `context` but with `syms` no longer conditioned on. Note that this recursively traverses contexts, deconditioning all along the way. +See also: [`condition`](@ref) + # Examples ```jldoctest julia> ctx = DefaultContext(); @@ -301,15 +314,15 @@ function decondition(context::ConditionContext) return decondition(childcontext(context)) end function decondition(context::ConditionContext, sym) - return ConditionContext( - BangBang.delete!!(context.values, sym), childcontext(context, sym) + return condition( + decondition(childcontext(context), sym), BangBang.delete!!(context.values, sym) ) end function decondition(context::ConditionContext, sym, syms...) return decondition( - ConditionContext( - BangBang.delete!!(context.values, sym), + condition( decondition(childcontext(context), syms...), + BangBang.delete!!(context.values, sym), ), syms..., ) From 9419e76bbbe2ffb7f708c67ac030d9769273cef6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:56:53 +0100 Subject: [PATCH 32/86] use rewrap in _evaluate for ContextualModel --- src/contextual_model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contextual_model.jl b/src/contextual_model.jl index 4b843ed0e..883fc5379 100644 --- a/src/contextual_model.jl +++ b/src/contextual_model.jl @@ -12,7 +12,7 @@ end function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as # `ConditionContext` child. - return _evaluate(cmodel.model, varinfo, ConditionContext(cmodel.context.values, context)) + return _evaluate(cmodel.model, varinfo, rewrap(cmodel.context, context)) end condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) From 4e566f78d4deff430b2ebe065ff8909b2f7a3b64 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 07:57:54 +0100 Subject: [PATCH 33/86] remove the drop_missing as it is no longer needed --- src/contexts.jl | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index f48a5187c..1c54cf86a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -118,30 +118,14 @@ struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext end end -@generated function drop_missings(nt::NamedTuple{names,values}) where {names,values} - names_expr = Expr(:tuple) - values_expr = Expr(:tuple) - - for (n, v) in zip(names, values.parameters) - if !(v <: Missing) - push!(names_expr.args, QuoteNode(n)) - push!(values_expr.args, :(nt.$n)) - end - end - - return :(NamedTuple{$names_expr}($values_expr)) -end - function ConditionContext(context::ConditionContext, child_context::AbstractContext) return ConditionContext(context.values, child_context) end function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end - function ConditionContext(values::NamedTuple, context::AbstractContext) - values_wo_missing = drop_missings(values) - return ConditionContext{typeof(values_wo_missing)}(values_wo_missing, context) + return ConditionContext{typeof(values)}(values, context) end # Try to avoid nested `ConditionContext`. From b27228aec8a15310f9076cc9623ec0de59f43221 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:08:14 +0100 Subject: [PATCH 34/86] rename rewrap to setchildcontet --- src/contexts.jl | 18 +++++++++--------- src/contextual_model.jl | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 00d6598af..33da613a7 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -14,7 +14,7 @@ The officially supported traits are: - `IsParent`: `context` has a child context to which we often defer. Expects the following methods to be implemented: - [`childcontext`](@ref) - - [`rewrap`](@ref) + - [`setchildcontext`](@ref) """ NodeTrait(_, context) = NodeTrait(context) @@ -26,7 +26,7 @@ Return the descendant context of `context`. childcontext """ - rewrap(parent::AbstractContext, child::AbstractContext) + setchildcontext(parent::AbstractContext, child::AbstractContext) Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), effectively updating the child context. @@ -38,13 +38,13 @@ julia> ctx = SamplingContext(); julia> DynamicPPL.childcontext(ctx) DefaultContext() -julia> ctx_prior = DynamicPPL.rewrap(ctx, PriorContext()); # only compute the logprior +julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior julia> DynamicPPL.childcontext(ctx_prior) PriorContext{Nothing}(nothing) ``` """ -rewrap +setchildcontext # Contexts """ @@ -67,7 +67,7 @@ SamplingContext() = SamplingContext(SampleFromPrior()) NodeTrait(context::SamplingContext) = IsParent() childcontext(context::SamplingContext) = context.context -rewrap(parent::SamplingContext, child) = SamplingContext(parent.rng, parent.sampler, child) +setchildcontext(parent::SamplingContext, child) = SamplingContext(parent.rng, parent.sampler, child) """ struct DefaultContext <: AbstractContext end @@ -128,7 +128,7 @@ function MiniBatchContext(context=DefaultContext(); batch_size, npoints) end NodeTrait(context::MiniBatchContext) = IsParent() childcontext(context::MiniBatchContext) = context.context -rewrap(parent::MiniBatchContext, child) = MiniBatchContext(child, parent.loglike_scalar) +setchildcontext(parent::MiniBatchContext, child) = MiniBatchContext(child, parent.loglike_scalar) """ PrefixContext{Prefix}(context) @@ -150,7 +150,7 @@ end NodeTrait(context::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context -rewrap(parent::PrefixContext{Prefix}, child) where {Prefix} = PrefixContext{Prefix}(child) +setchildcontext(parent::PrefixContext{Prefix}, child) where {Prefix} = PrefixContext{Prefix}(child) const PREFIX_SEPARATOR = Symbol(".") @@ -206,7 +206,7 @@ end NodeTrait(context::ConditionContext) = IsParent() childcontext(context::ConditionContext) = context.context -rewrap(parent::ConditionContext, child) = ConditionContext(parent.values, child) +setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) """ getvalue(context, vn) @@ -291,7 +291,7 @@ SamplingContext{SampleFromPrior, DefaultContext, Random._GLOBAL_RNG}(Random._GLO """ decondition(::IsLeaf, context, args...) = context function decondition(::IsParent, context, args...) - return rewrap(context, decondition(childcontext(context), args...)) + return setchildcontext(context, decondition(childcontext(context), args...)) end decondition(context, args...) = decondition(NodeTrait(context), context, args...) function decondition(context::ConditionContext) diff --git a/src/contextual_model.jl b/src/contextual_model.jl index 883fc5379..3ab931b65 100644 --- a/src/contextual_model.jl +++ b/src/contextual_model.jl @@ -12,7 +12,7 @@ end function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as # `ConditionContext` child. - return _evaluate(cmodel.model, varinfo, rewrap(cmodel.context, context)) + return _evaluate(cmodel.model, varinfo, setchildcontext(cmodel.context, context)) end condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) From 4935d5c610be95bcb073f1776e3ff462160ef371 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:08:33 +0100 Subject: [PATCH 35/86] formatting --- src/contexts.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 33da613a7..36c6485ba 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -67,7 +67,9 @@ SamplingContext() = SamplingContext(SampleFromPrior()) NodeTrait(context::SamplingContext) = IsParent() childcontext(context::SamplingContext) = context.context -setchildcontext(parent::SamplingContext, child) = SamplingContext(parent.rng, parent.sampler, child) +function setchildcontext(parent::SamplingContext, child) + return SamplingContext(parent.rng, parent.sampler, child) +end """ struct DefaultContext <: AbstractContext end @@ -128,7 +130,9 @@ function MiniBatchContext(context=DefaultContext(); batch_size, npoints) end NodeTrait(context::MiniBatchContext) = IsParent() childcontext(context::MiniBatchContext) = context.context -setchildcontext(parent::MiniBatchContext, child) = MiniBatchContext(child, parent.loglike_scalar) +function setchildcontext(parent::MiniBatchContext, child) + return MiniBatchContext(child, parent.loglike_scalar) +end """ PrefixContext{Prefix}(context) @@ -150,7 +154,9 @@ end NodeTrait(context::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context -setchildcontext(parent::PrefixContext{Prefix}, child) where {Prefix} = PrefixContext{Prefix}(child) +function setchildcontext(parent::PrefixContext{Prefix}, child) where {Prefix} + return PrefixContext{Prefix}(child) +end const PREFIX_SEPARATOR = Symbol(".") From 61960832b75ff490862eaca5da6e186602fc3c93 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:11:51 +0100 Subject: [PATCH 36/86] made show a bit nicer for ConditionContext --- src/contexts.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/contexts.jl b/src/contexts.jl index 1c54cf86a..4a6834b32 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -137,6 +137,10 @@ function ConditionContext( return ConditionContext(merge(context.values, values), context.context) end +function Base.show(io::IO, context::ConditionContext) + println(io, "ConditionContext($(context.values), $(context.context))") +end + function getvalue(context::ConditionContext, vn) return if haskey(context, vn) _getvalue(context.values, vn) From 65048fc4c57e1ec884cddbcdcd880269cfdc33a3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:12:32 +0100 Subject: [PATCH 37/86] use print instead of println in show --- src/contexts.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 4a6834b32..fa0bf585a 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -138,7 +138,7 @@ function ConditionContext( end function Base.show(io::IO, context::ConditionContext) - println(io, "ConditionContext($(context.values), $(context.context))") + print(io, "ConditionContext($(context.values), $(context.context))") end function getvalue(context::ConditionContext, vn) From 649af299a7f68c0def25026bbf06850aedbbfdbc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 24 Jul 2021 08:14:15 +0100 Subject: [PATCH 38/86] formatting --- src/contexts.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index fa0bf585a..926dd0fe0 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -138,7 +138,7 @@ function ConditionContext( end function Base.show(io::IO, context::ConditionContext) - print(io, "ConditionContext($(context.values), $(context.context))") + return print(io, "ConditionContext($(context.values), $(context.context))") end function getvalue(context::ConditionContext, vn) From 4010ab8be8a55e0b4bc60a1194800e6d0bb38850 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 26 Jul 2021 00:02:26 +0100 Subject: [PATCH 39/86] make Model a contextual model --- src/contextual_model.jl | 4 +++- src/model.jl | 28 ++++++++++++++++++++-------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/contextual_model.jl b/src/contextual_model.jl index 3ab931b65..61d5fa6b6 100644 --- a/src/contextual_model.jl +++ b/src/contextual_model.jl @@ -4,7 +4,7 @@ struct ContextualModel{Ctx<:AbstractContext,M<:Model} <: AbstractModel end function contextualize(model::AbstractModel, context::AbstractContext) - return ContextualModel(context, model) + return Model(model.name, model.f, model.args, model.defaults, context) end # TODO: What do we do for other contexts? Could handle this in general if we had a @@ -15,6 +15,8 @@ function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context return _evaluate(cmodel.model, varinfo, setchildcontext(cmodel.context, context)) end + +Base.:|(model::AbstractModel, values) = condition(model, values) condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) condition(model::AbstractModel; values...) = condition(model, (; values...)) function condition(cmodel::ContextualModel{<:ConditionContext}, values) diff --git a/src/model.jl b/src/model.jl index 7b5d8918a..ef3f89807 100644 --- a/src/model.jl +++ b/src/model.jl @@ -34,11 +34,13 @@ julia> Model{(:y,)}(f, (x = 1.0, y = 2.0), (x = 42,)) # with special definition Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x = 1.0, y = 2.0), (x = 42,)) ``` """ -struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel +struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: + AbstractModel name::Symbol f::F args::NamedTuple{argnames,Targs} defaults::NamedTuple{defaultnames,Tdefaults} + context::Ctx @doc """ Model{missings}(name::Symbol, f, args::NamedTuple, defaults::NamedTuple) @@ -51,9 +53,10 @@ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} <: AbstractModel f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple{defaultnames,Tdefaults}, - ) where {missings,F,argnames,Targs,defaultnames,Tdefaults} - return new{F,argnames,defaultnames,missings,Targs,Tdefaults}( - name, f, args, defaults + context::Ctx=DefaultContext(), + ) where {missings,F,argnames,Targs,defaultnames,Tdefaults,Ctx} + return new{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx}( + name, f, args, defaults, context ) end end @@ -68,10 +71,14 @@ Default arguments `defaults` are used internally when constructing instances of model with different arguments. """ @generated function Model( - name::Symbol, f::F, args::NamedTuple{argnames,Targs}, defaults::NamedTuple=NamedTuple() + name::Symbol, + f::F, + args::NamedTuple{argnames,Targs}, + defaults::NamedTuple=NamedTuple(), + context::AbstractContext=DefaultContext(), ) where {F,argnames,Targs} missings = Tuple(name for (name, typ) in zip(argnames, Targs.types) if typ <: Missing) - return :(Model{$missings}(name, f, args, defaults)) + return :(Model{$missings}(name, f, args, defaults, context)) end """ @@ -157,8 +164,13 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf @generated function _evaluate( model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} - unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] - return :(model.f(model, varinfo, context, $(unwrap_args...))) + unwrap_args = [ + :($matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames + ] + return quote + context_new = insertcontext(context, model.context) + model.f(model, varinfo, context_new, $(unwrap_args...)) + end end """ From ee035cb388d43e38a7130182302ef43f0775dcb9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Jul 2021 18:59:48 +0100 Subject: [PATCH 40/86] added some more functionality for context traits --- src/contexts.jl | 112 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/src/contexts.jl b/src/contexts.jl index 78759de61..3653c68d3 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -8,9 +8,107 @@ struct IsParent end NodeTrait(f, context) Specifies the role of `context` in the context-tree. + +The officially supported traits are: +- `IsLeaf`: `context` does not have any decendants. +- `IsParent`: `context` has a child context to which we often defer. + Expects the following methods to be implemented: + - [`childcontext`](@ref) + - [`setchildcontext`](@ref) """ NodeTrait(_, context) = NodeTrait(context) +""" + childcontext(context) + +Return the descendant context of `context`. +""" +childcontext + +""" + setchildcontext(parent::AbstractContext, child::AbstractContext) + +Reconstruct `parent` but now using `child` is its [`childcontext`](@ref), +effectively updating the child context. + +# Examples +```jldoctest +julia> ctx = SamplingContext(); + +julia> DynamicPPL.childcontext(ctx) +DefaultContext() + +julia> ctx_prior = DynamicPPL.setchildcontext(ctx, PriorContext()); # only compute the logprior + +julia> DynamicPPL.childcontext(ctx_prior) +PriorContext{Nothing}(nothing) +``` +""" +setchildcontext + +""" + leafcontext(context) + +Return the leaf of `context`, i.e. the first descendant context that `IsLeaf`. +""" +leafcontext(context) = leafcontext(NodeTrait(leafcontext, context), context) +leafcontext(::IsLeaf, context) = context +leafcontext(::IsParent, context) = leafcontext(childcontext(context)) + +""" + setleafcontext(left, right) + +Return `left` but now with its leaf context replaced by `right`. + +Note that this also works even if `right` is not a leaf context, +in which case effectively append `right` to `left`, dropping the +original leaf context of `left`. + +# Examples +```jldoctest +julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext + +julia> struct ParentContext{C} + context::C + end + +julia> DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() + +julia> DynamicPPL.childcontext(context::ParentContext) = context.context + +julia> DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) + +julia> Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") + +julia> ctx = ParentContext(ParentContext(DefaultContext())) +ParentContext(ParentContext(DefaultContext())) + +julia> # Replace the leaf context with another leaf. + leafcontext(setleafcontext(ctx, PriorContext())) +PriorContext{Nothing}(nothing) + +julia> # Append another parent context. + setleafcontext(ctx, ParentContext(DefaultContext())) +ParentContext(ParentContext(ParentContext(DefaultContext()))) +``` +""" +function setleafcontext(left, right) + return setleafcontext( + NodeTrait(setleafcontext, left), + NodeTrait(setleafcontext, right), + left, + right + ) +end +function setleafcontext(::IsParent, ::IsParent, left, right) + return setchildcontext(left, setleafcontext(childcontext(left), right)) +end +function setleafcontext(::IsParent, ::IsLeaf, left, right) + return setchildcontext(left, setleafcontext(childcontext(left), right)) +end +setleafcontext(::IsLeaf, ::IsParent, left, right) = right +setleafcontext(::IsLeaf, ::IsLeaf, left, right) = right + # Contexts """ SamplingContext(rng, sampler, context) @@ -25,8 +123,16 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte sampler::S context::C end +SamplingContext(sampler, context) = SamplingContext(Random.GLOBAL_RNG, sampler, context) +SamplingContext(context::AbstractContext) = SamplingContext(SampleFromPrior(), context) +SamplingContext(sampler::AbstractSampler) = SamplingContext(sampler, DefaultContext()) +SamplingContext() = SamplingContext(SampleFromPrior()) + NodeTrait(context::SamplingContext) = IsParent() childcontext(context::SamplingContext) = context.context +function setchildcontext(parent::SamplingContext, child) + return SamplingContext(parent.rng, parent.sampler, child) +end """ struct DefaultContext <: AbstractContext end @@ -87,6 +193,9 @@ function MiniBatchContext(context=DefaultContext(); batch_size, npoints) end NodeTrait(context::MiniBatchContext) = IsParent() childcontext(context::MiniBatchContext) = context.context +function setchildcontext(parent::MiniBatchContext, child) + return MiniBatchContext(child, parent.loglike_scalar) +end """ PrefixContext{Prefix}(context) @@ -108,6 +217,9 @@ end NodeTrait(context::PrefixContext) = IsParent() childcontext(context::PrefixContext) = context.context +function setchildcontext(parent::PrefixContext{Prefix}, child) where {Prefix} + return PrefixContext{Prefix}(child) +end const PREFIX_SEPARATOR = Symbol(".") From 21c08e5852f05d3ee9ef1bdaa832efb66fb96cc9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Jul 2021 19:03:31 +0100 Subject: [PATCH 41/86] dont overload haskey and improved contextual_isassumption check --- src/compiler.jl | 22 +++++++++++----------- src/contexts.jl | 16 ++++++++-------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index fb77d0323..c69a83803 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -61,17 +61,17 @@ function contextual_isassumption(context::AbstractContext, vn) return contextual_isassumption(NodeTrait(context), context, vn) end function contextual_isassumption(context::ConditionContext, vn) - # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}`. - - # We have either of the following cases: - # 1. `context` considers `vn` as an observation, i.e. it has `vn` as a key, - # which means we have a value to replace with and we don't need to recurse. - # 2. One of the decendant contexts consider it as an observation, i.e. - # `contextual_isassumption` evaluates to `false`. - # The below then evaluates to `!(false || true) === false`. - # 3. Neither `context` nor any of it's decendants considers it an observation, - # in which case the below evaluates to `!(false || false) === true`. - return !(haskey(context, vn) || !contextual_isassumption(childcontext(context), vn)) + if hasvalue(context, vn) + val = getvalue(context, vn) + # TODO: Do we even need the `>: Missing` to help the compiler? + if eltype(val) >: Missing && val === missing + return true + end + end + + # We might have nested contexts, e.g. `ContextionContext{.., <:PrefixContext{..., <:ConditionContext}}` + # so we defer to `childcontext` if we haven't concluded that anything yet. + return contextual_isassumption(childcontext(context), vn) end function contextual_isassumption(context::PrefixContext, vn) return contextual_isassumption(childcontext(context), prefix(context, vn)) diff --git a/src/contexts.jl b/src/contexts.jl index fb26b2757..1c63df0c5 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -294,7 +294,7 @@ getvalue(context::AbstractContext, vn) = getvalue(NodeTrait(getvalue, context), getvalue(context::PrefixContext, vn) = getvalue(childcontext(context), prefix(context, vn)) function getvalue(context::ConditionContext, vn) - return if haskey(context, vn) + return if hasvalue(context, vn) _getvalue(context.values, vn) else getvalue(childcontext(context), vn) @@ -302,17 +302,17 @@ function getvalue(context::ConditionContext, vn) end # General implementations of `haskey`. -Base.haskey(::IsLeaf, context, vn) = false -Base.haskey(::IsParent, context, vn) = Base.haskey(childcontext(context), vn) -Base.haskey(context::AbstractContext, vn) = Base.haskey(NodeTrait(context), context, vn) +hasvalue(::IsLeaf, context, vn) = false +hasvalue(::IsParent, context, vn) = hasvalue(childcontext(context), vn) +hasvalue(context::AbstractContext, vn) = hasvalue(NodeTrait(context), context, vn) # Specific to `ConditionContext`. -function Base.haskey(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} +function hasvalue(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars end -function Base.haskey( +function hasvalue( context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}} ) where {vars,sym} # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. @@ -332,8 +332,8 @@ condition() = decondition(ConditionContext()) condition(values::NamedTuple) = condition(DefaultContext(), values) condition(context::AbstractContext, values::NamedTuple{()}) = context condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context) -function condition(context::AbstractContext=DefaultContext(); values...) - return ConditionContext(context; values...) +function condition(context::AbstractContext; values...) + return condition(context, (; values...)) end """ From 35fad3691e4653ecc812dc8e6b9b865526e5aaa7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Jul 2021 19:07:40 +0100 Subject: [PATCH 42/86] remove the now unnecessary ContextualModel and AbstractModel --- src/contextual_model.jl | 29 ----------------------------- src/model.jl | 33 ++++++++++++++++++++++++--------- 2 files changed, 24 insertions(+), 38 deletions(-) delete mode 100644 src/contextual_model.jl diff --git a/src/contextual_model.jl b/src/contextual_model.jl deleted file mode 100644 index 61d5fa6b6..000000000 --- a/src/contextual_model.jl +++ /dev/null @@ -1,29 +0,0 @@ -struct ContextualModel{Ctx<:AbstractContext,M<:Model} <: AbstractModel - context::Ctx - model::M -end - -function contextualize(model::AbstractModel, context::AbstractContext) - return Model(model.name, model.f, model.args, model.defaults, context) -end - -# TODO: What do we do for other contexts? Could handle this in general if we had a -# notion of wrapper-, primitive-context, etc. -function _evaluate(cmodel::ContextualModel{<:ConditionContext}, varinfo, context) - # Wrap `context` in the model-associated `ConditionContext`, but now using `context` as - # `ConditionContext` child. - return _evaluate(cmodel.model, varinfo, setchildcontext(cmodel.context, context)) -end - - -Base.:|(model::AbstractModel, values) = condition(model, values) -condition(model::AbstractModel, values) = contextualize(model, ConditionContext(values)) -condition(model::AbstractModel; values...) = condition(model, (; values...)) -function condition(cmodel::ContextualModel{<:ConditionContext}, values) - return contextualize(cmodel.model, ConditionContext(values, cmodel.context)) -end - -decondition(model::AbstractModel, args...) = model -function decondition(cmodel::ContextualModel{<:ConditionContext}, syms...) - return contextualize(cmodel.model, decondition(cmodel.context, syms...)) -end diff --git a/src/model.jl b/src/model.jl index ef3f89807..a725bff8b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -1,5 +1,3 @@ -abstract type AbstractModel <: AbstractProbabilisticProgram end - """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults} name::Symbol @@ -35,7 +33,7 @@ Model{typeof(f),(:x, :y),(:x,),(:y,),Tuple{Float64,Float64},Tuple{Int64}}(f, (x ``` """ struct Model{F,argnames,defaultnames,missings,Targs,Tdefaults,Ctx<:AbstractContext} <: - AbstractModel + AbstractProbabilisticProgram name::Symbol f::F args::NamedTuple{argnames,Targs} @@ -81,6 +79,23 @@ model with different arguments. return :(Model{$missings}(name, f, args, defaults, context)) end +function contextualize(model::Model, context::AbstractContext) + return Model(model.name, model.f, model.args, model.defaults, context) +end + +Base.:|(model::Model, values) = condition(model, values) + +condition(model::Model, values) = contextualize(model, ConditionContext(values)) +condition(model::Model; values...) = condition(model, (; values...)) +function condition(model::Model, values) + return contextualize(model, condition(model.context, values)) +end + +function decondition(model::Model, syms...) + return contextualize(model, decondition(model.context, syms...)) +end + + """ (model::Model)([rng, varinfo, sampler, context]) @@ -90,7 +105,7 @@ Sample from the `model` using the `sampler` with random number generator `rng` a The method resets the log joint probability of `varinfo` and increases the evaluation number of `sampler`. """ -function (model::AbstractModel)( +function (model::Model)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), @@ -99,8 +114,8 @@ function (model::AbstractModel)( return model(varinfo, SamplingContext(rng, sampler, context)) end -(model::AbstractModel)(context::AbstractContext) = model(VarInfo(), context) -function (model::AbstractModel)(varinfo::AbstractVarInfo, context::AbstractContext) +(model::Model)(context::AbstractContext) = model(VarInfo(), context) +function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 return evaluate_threadunsafe(model, varinfo, context) else @@ -108,17 +123,17 @@ function (model::AbstractModel)(varinfo::AbstractVarInfo, context::AbstractConte end end -function (model::AbstractModel)(args...) +function (model::Model)(args...) return model(Random.GLOBAL_RNG, args...) end # without VarInfo -function (model::AbstractModel)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) +function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) return model(rng, VarInfo(), sampler, args...) end # without VarInfo and without AbstractSampler -function (model::AbstractModel)(rng::Random.AbstractRNG, context::AbstractContext) +function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) return model(rng, VarInfo(), SampleFromPrior(), context) end From 9297e612592580a891994dd418ee394ab8d2e2bb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Jul 2021 19:14:14 +0100 Subject: [PATCH 43/86] removed some leftovers --- Project.toml | 1 - src/DynamicPPL.jl | 1 - src/varinfo.jl | 6 +++--- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index ecf47b50f..5678c050a 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ version = "0.13.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" -BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index bfcc49e6b..aa8b8ca58 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,7 +6,6 @@ using Distributions using Bijectors using AbstractMCMC: AbstractMCMC -using BangBang: BangBang using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules diff --git a/src/varinfo.jl b/src/varinfo.jl index ba6e157e9..64c122dc2 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -124,7 +124,7 @@ end function VarInfo( rng::Random.AbstractRNG, - model::AbstractModel, + model::Model, sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) @@ -132,10 +132,10 @@ function VarInfo( model(rng, varinfo, sampler, context) return TypedVarInfo(varinfo) end -VarInfo(model::AbstractModel, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) +VarInfo(model::Model, args...) = VarInfo(Random.GLOBAL_RNG, model, args...) # without AbstractSampler -function VarInfo(rng::Random.AbstractRNG, model::AbstractModel, context::AbstractContext) +function VarInfo(rng::Random.AbstractRNG, model::Model, context::AbstractContext) return VarInfo(rng, model, SampleFromPrior(), context) end From 65f80944f20b08f6b821af5422c33386a9863f1f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Jul 2021 19:14:53 +0100 Subject: [PATCH 44/86] removed now gone include --- src/DynamicPPL.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index aa8b8ca58..e7732d8c5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -133,6 +133,5 @@ include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") -include("contextual_model.jl") end # module From b201ef3bb76ad20fb76dcc38c37b91ed486cb5c9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 28 Jul 2021 19:15:51 +0100 Subject: [PATCH 45/86] removed leftovers --- Project.toml | 1 - src/DynamicPPL.jl | 2 -- 2 files changed, 3 deletions(-) diff --git a/Project.toml b/Project.toml index ecf47b50f..5678c050a 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,6 @@ version = "0.13.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" -BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index bfcc49e6b..e7732d8c5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -6,7 +6,6 @@ using Distributions using Bijectors using AbstractMCMC: AbstractMCMC -using BangBang: BangBang using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules @@ -134,6 +133,5 @@ include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") -include("contextual_model.jl") end # module From 9ebdd0e9f2b5313a25565ae06d353a206363b3a8 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 02:09:37 +0100 Subject: [PATCH 46/86] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/model.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index a725bff8b..92b773ab5 100644 --- a/src/model.jl +++ b/src/model.jl @@ -95,7 +95,6 @@ function decondition(model::Model, syms...) return contextualize(model, decondition(model.context, syms...)) end - """ (model::Model)([rng, varinfo, sampler, context]) From 274ad23758d14a2a20ed3f5c2836892f9a03ccdb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 02:41:45 +0100 Subject: [PATCH 47/86] fixed typo --- src/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index a725bff8b..8f9057459 100644 --- a/src/model.jl +++ b/src/model.jl @@ -183,7 +183,7 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf :($matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames ] return quote - context_new = insertcontext(context, model.context) + context_new = setleafcontext(context, model.context) model.f(model, varinfo, context_new, $(unwrap_args...)) end end From b7998bc84ddab1f99d84b7fc9e79806082976574 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:33:15 +0100 Subject: [PATCH 48/86] fixed PointwiseLikelihood --- src/context_implementations.jl | 3 +++ src/loglikelihoods.jl | 14 ++++++-------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 08ec7144d..d15966fb6 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -583,6 +583,9 @@ function dot_tilde_observe(context::SamplingContext, right, left, vi) end # Leaf contexts +function dot_tilde_observe(context::AbstractContext, args...) + return dot_tilde_observe(NodeTrait(tilde_observe, context), context, args...) +end dot_tilde_observe(::IsLeaf, ::AbstractContext, args...) = dot_observe(args...) function dot_tilde_observe(::IsParent, context::AbstractContext, args...) return dot_tilde_observe(childcontext(context), args...) diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 2901432d1..0cac29219 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -13,6 +13,12 @@ function PointwiseLikelihoodContext( ) end +NodeTrait(::PointwiseLikelihoodContext) = IsParent() +childcontext(context::PointwiseLikelihoodContext) = context.context +function setchildcontext(context::PointwiseLikelihoodContext, child) + return PointwiseLikelihoodContext(context.loglikelihoods, child) +end + function Base.push!( context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, vn::VarName, @@ -61,14 +67,6 @@ function Base.push!( return context.loglikelihoods[vn] = logp end -function tilde_assume(context::PointwiseLikelihoodContext, right, vn, inds, vi) - return tilde_assume(context.context, right, vn, inds, vi) -end - -function dot_tilde_assume(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, left, vn, inds, vi) -end - function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) # Defer literal `observe` to child-context. return tilde_observe!(context.context, right, left, vi) From 8cc3193d6909be89ad772ac7eb19cee121d128d5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:37:01 +0100 Subject: [PATCH 49/86] improved condition and fixed isassumption --- src/DynamicPPL.jl | 1 + src/compiler.jl | 2 +- src/contexts.jl | 63 ++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 53 insertions(+), 13 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index e7732d8c5..2f0ef2617 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -9,6 +9,7 @@ using AbstractMCMC: AbstractMCMC using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules +using BangBang: BangBang using Random: Random diff --git a/src/compiler.jl b/src/compiler.jl index c69a83803..6e10c4285 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -34,7 +34,7 @@ function isassumption(expr::Union{Symbol,Expr}) # TODO: Support by adding context to model, and use `model.args` # as the default conditioning. Then we no longer need to check `inargnames` # since it will all be handled by `contextual_isassumption`. - if !($(DynamicPPL.inargnames)($vn, __model__)) + if !($(DynamicPPL.inargnames)($vn, __model__)) || $(DynamicPPL.inmissings)($vn, __model__) true else $expr === missing diff --git a/src/contexts.jl b/src/contexts.jl index 1c63df0c5..d1b2145c2 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -66,9 +66,9 @@ original leaf context of `left`. # Examples ```jldoctest -julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext +julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext -julia> struct ParentContext{C} +julia> struct ParentContext{C} <: AbstractContext context::C end @@ -328,14 +328,13 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default. See also: [`decondition`](@ref) """ -condition() = decondition(ConditionContext()) +function condition(; values...) + return isempty(values) ? decondition(ConditionContext()) : condition(DefaultContext(), (; values...)) +end condition(values::NamedTuple) = condition(DefaultContext(), values) condition(context::AbstractContext, values::NamedTuple{()}) = context condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context) -function condition(context::AbstractContext; values...) - return condition(context, (; values...)) -end - +condition(context::AbstractContext; values...) = condition(context, (; values...)) """ decondition(context::AbstractContext, syms...) @@ -347,20 +346,60 @@ See also: [`condition`](@ref) # Examples ```jldoctest -julia> ctx = DefaultContext(); +julia> using DynamicPPL: AbstractContext, leafcontext, setleafcontext, childcontext, setchildcontext + +julia> struct ParentContext{C} <: AbstractContext + context::C + end + +julia> DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() + +julia> DynamicPPL.childcontext(context::ParentContext) = context.context + +julia> DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) + +julia> Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") + +julia> ctx = DefaultContext() +DefaultContext() julia> decondition(ctx) === ctx # this is a no-op true -julia> ctx = ConditionContext(x = 1.0); +julia> ctx = condition(x = 1.0) # default "constructor" for `ConditionContext` +ConditionContext((x = 1.0,), DefaultContext()) -julia> decondition(ctx) +julia> decondition(ctx) # `decondition` without arguments drops all conditioning DefaultContext() -julia> ctx_nested = ConditionContext(SamplingContext(ConditionContext(y=2.0)), x=1.0); +julia> # Nested conditioning is supported. + ctx_nested = condition(ParentContext(condition(y=2.0)), x=1.0) +ConditionContext((x = 1.0,), ParentContext(ConditionContext((y = 2.0,), DefaultContext()))) + +julia> # We can also specify which variables to drop. + decondition(ctx_nested, :x) +ParentContext(ConditionContext((y = 2.0,), DefaultContext())) + +julia> # No matter the nested level. + decondition(ctx_nested, :y) +ConditionContext((x = 1.0,), ParentContext(DefaultContext())) + +julia> # Or specify multiple at in one call. + decondition(ctx_nested, :x, :y) +ParentContext(DefaultContext()) julia> decondition(ctx_nested) -SamplingContext{SampleFromPrior, DefaultContext, Random._GLOBAL_RNG}(Random._GLOBAL_RNG(), SampleFromPrior(), DefaultContext()) +ParentContext(DefaultContext()) + +julia> # `Val` is also supported. + decondition(ctx_nested, Val(:x)) +ParentContext(ConditionContext((y = 2.0,), DefaultContext())) + +julia> decondition(ctx_nested, Val(:y)) +ConditionContext((x = 1.0,), ParentContext(DefaultContext())) + +julia> decondition(ctx_nested, Val(:x), Val(:y)) +ParentContext(DefaultContext()) ``` """ decondition(::IsLeaf, context, args...) = context From 26edb2c80cc8562fdff1dfefde403ae02395d010 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:40:08 +0100 Subject: [PATCH 50/86] added BangBang as dep --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index 5678c050a..fd514f836 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.13.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" +BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" @@ -15,6 +16,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "2, 3.0" AbstractPPL = "0.2" +BangBang = "0.3" Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9" ChainRulesCore = "0.9.7, 0.10" Distributions = "0.23.8, 0.24, 0.25" From be61ef1224cb31c43998a7720c12c219504b0624 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:41:28 +0100 Subject: [PATCH 51/86] fixed the _evaluate --- src/model.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/model.jl b/src/model.jl index 45865ef69..600ba162b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -85,7 +85,6 @@ end Base.:|(model::Model, values) = condition(model, values) -condition(model::Model, values) = contextualize(model, ConditionContext(values)) condition(model::Model; values...) = condition(model, (; values...)) function condition(model::Model, values) return contextualize(model, condition(model.context, values)) @@ -181,8 +180,15 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf unwrap_args = [ :($matchingvalue(context_new, varinfo, model.args.$var)) for var in argnames ] + # We want to give `context` precedence over `model.context` while also + # preserving the leaf context of `context`. We can do this by + # 1. Set the leaf context of `model.context` to `leafcontext(context)`. + # 2. Set leaf context of `context` to the context resulting from (1). + # The result is: + # `context` -> `childcontext(context)` -> ... -> `model.context` + # -> `childcontext(model.context)` -> ... -> `leafcontext(context)` return quote - context_new = setleafcontext(context, model.context) + context_new = setleafcontext(context, setleafcontext(model.context, leafcontext(context))) model.f(model, varinfo, context_new, $(unwrap_args...)) end end From 4e3e08fae8c9a73db5ed777535234a6d02e021e4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:44:30 +0100 Subject: [PATCH 52/86] fixed a doctest --- src/contexts.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 3653c68d3..1345e1c93 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -66,9 +66,9 @@ original leaf context of `left`. # Examples ```jldoctest -julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext +julia> using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext -julia> struct ParentContext{C} +julia> struct ParentContext{C} <: AbstractContext context::C end From 9e23d4d3ecf405646d06076031611244813ee756 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:45:52 +0100 Subject: [PATCH 53/86] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/contexts.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 1345e1c93..4bda07009 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -94,10 +94,7 @@ ParentContext(ParentContext(ParentContext(DefaultContext()))) """ function setleafcontext(left, right) return setleafcontext( - NodeTrait(setleafcontext, left), - NodeTrait(setleafcontext, right), - left, - right + NodeTrait(setleafcontext, left), NodeTrait(setleafcontext, right), left, right ) end function setleafcontext(::IsParent, ::IsParent, left, right) From e2f6fc5f8a2ddc9e6efb04b506c9fba50df28aca Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:46:36 +0100 Subject: [PATCH 54/86] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/compiler.jl | 3 ++- src/contexts.jl | 6 +++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 6e10c4285..921062079 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -34,7 +34,8 @@ function isassumption(expr::Union{Symbol,Expr}) # TODO: Support by adding context to model, and use `model.args` # as the default conditioning. Then we no longer need to check `inargnames` # since it will all be handled by `contextual_isassumption`. - if !($(DynamicPPL.inargnames)($vn, __model__)) || $(DynamicPPL.inmissings)($vn, __model__) + if !($(DynamicPPL.inargnames)($vn, __model__)) || + $(DynamicPPL.inmissings)($vn, __model__) true else $expr === missing diff --git a/src/contexts.jl b/src/contexts.jl index de7a44207..b045a17e2 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -326,7 +326,11 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default. See also: [`decondition`](@ref) """ function condition(; values...) - return isempty(values) ? decondition(ConditionContext()) : condition(DefaultContext(), (; values...)) + return if isempty(values) + decondition(ConditionContext()) + else + condition(DefaultContext(), (; values...)) + end end condition(values::NamedTuple) = condition(DefaultContext(), values) condition(context::AbstractContext, values::NamedTuple{()}) = context From a361a5ea5caa9d50a3fbc989ed3b35bd3e1d592f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:47:09 +0100 Subject: [PATCH 55/86] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/model.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index 600ba162b..d55f9d5b2 100644 --- a/src/model.jl +++ b/src/model.jl @@ -188,7 +188,9 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf # `context` -> `childcontext(context)` -> ... -> `model.context` # -> `childcontext(model.context)` -> ... -> `leafcontext(context)` return quote - context_new = setleafcontext(context, setleafcontext(model.context, leafcontext(context))) + context_new = setleafcontext( + context, setleafcontext(model.context, leafcontext(context)) + ) model.f(model, varinfo, context_new, $(unwrap_args...)) end end From ce160d682c420127126ae7c53d9f2f6ca47e1373 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 29 Jul 2021 04:53:18 +0100 Subject: [PATCH 56/86] fixed condition without arguments --- src/contexts.jl | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index b045a17e2..ca3c6b4c2 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -251,9 +251,6 @@ struct ConditionContext{Names,Values,Ctx<:AbstractContext} <: AbstractContext end end -function ConditionContext(context::ConditionContext, child_context::AbstractContext) - return ConditionContext(context.values, child_context) -end function ConditionContext(values::NamedTuple) return ConditionContext(values, DefaultContext()) end @@ -325,13 +322,7 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default. See also: [`decondition`](@ref) """ -function condition(; values...) - return if isempty(values) - decondition(ConditionContext()) - else - condition(DefaultContext(), (; values...)) - end -end +condition(; values...) = condition(DefaultContext(), (; values...)) condition(values::NamedTuple) = condition(DefaultContext(), values) condition(context::AbstractContext, values::NamedTuple{()}) = context condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context) From b492041baba4abcc04d5b29c6ca1fb1cdd3d5ae0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 07:55:56 +0100 Subject: [PATCH 57/86] had forgotten a return --- src/compiler.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 921062079..ae0c2fed0 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -67,6 +67,8 @@ function contextual_isassumption(context::ConditionContext, vn) # TODO: Do we even need the `>: Missing` to help the compiler? if eltype(val) >: Missing && val === missing return true + else + return false end end From 7ae9e3e871034ad963800237f2d406b7d21f79ce Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 08:23:09 +0100 Subject: [PATCH 58/86] added methods for extracting the conditioned/observed variables and values --- src/DynamicPPL.jl | 3 +++ src/contexts.jl | 55 +++++++++++++++++++++++++++++++++++++++++++++++ src/model.jl | 4 ++++ 3 files changed, 62 insertions(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 2f0ef2617..f054517bb 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -104,6 +104,9 @@ export AbstractVarInfo, pointwise_loglikelihoods, condition, decondition, + contextualize, + observations, + conditioned, # Convenience macros @addlogprob!, @submodel diff --git a/src/contexts.jl b/src/contexts.jl index ca3c6b4c2..409f666d1 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -416,3 +416,58 @@ function decondition(context::ConditionContext, sym, syms...) syms..., ) end + +""" + conditioned(model::Model) + conditioned(context::AbstractContext) + +Return `NamedTuple` of values that are conditioned on under `model`/`context`. + +# Examples +```jldoctest +julia> @model function demo() + m ~ Normal() + x ~ Normal(m, 1) + end +demo (generic function with 1 methods) + +julia> m = demo(); + +julia> # Returns all the variables we have conditioned on + their values. + conditioned(condition(m, x=100.0, m=1.0)) +(x = 100.0, m = 1.0) + +julia> # Nested ones also work (note that `PrefixContext` does nothing to the result). + cm = condition(contextualize(m, PrefixContext{:a}(condition(m=1.0))), x=100.0); + +julia> conditioned(cm) +(x = 100.0, m = 1.0) + +julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, + # `a.m` is treated as a random variable. + keys(VarInfo(cm)) +1-element Vector{VarName{Symbol("a.m"), Tuple{}}}: + a.m + +julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation. + cm = condition(contextualize(m, PrefixContext{:a}(condition(var"a.m"=1.0))), x=100.0); + +julia> conditioned(cm) +(x = 100.0, a.m = 1.0) + +julia> keys(VarInfo(cm)) # <= no variables are sampled +Any[] +``` +""" +function conditioned(context::AbstractContext) + conditioned(NodeTrait(conditioned, context), context) +end +conditioned(::IsLeaf, context) = () +conditioned(::IsParent, context) = conditioned(childcontext(context)) +function conditioned(context::ConditionContext) + # Note the order of arguments to `merge`. The behavior of the rest of DPPL + # is that the outermost `context` takes precendence, hence when resolving + # the `conditioned` variables we need to ensure that `context.values` takes + # precedence over decendants of `context`. + return merge(context.values, conditioned(childcontext(context))) +end diff --git a/src/model.jl b/src/model.jl index d55f9d5b2..105878a61 100644 --- a/src/model.jl +++ b/src/model.jl @@ -94,6 +94,10 @@ function decondition(model::Model, syms...) return contextualize(model, decondition(model.context, syms...)) end +observations(model::Model) = conditioned(model) +conditioned(model::Model) = conditioned(model.context) + + """ (model::Model)([rng, varinfo, sampler, context]) From 3b31d35ff072cf093973fbc8cd906f154db4c9f6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 08:23:28 +0100 Subject: [PATCH 59/86] fixed a couple of tilde statements --- src/context_implementations.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 9464098d7..83f57fe70 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -171,6 +171,8 @@ function tilde_observe(context::SamplingContext, right, left, vi) end # Leaf contexts +# TODO: Should we maybe not do `args...` here but instead be explicit? +# Could help avoid stealthy bugs. function tilde_observe(context::AbstractContext, args...) return tilde_observe(NodeTrait(tilde_observe, context), context, args...) end @@ -186,13 +188,13 @@ tilde_observe(::PriorContext, sampler, right, left, vi) = 0 function tilde_observe(context::MiniBatchContext, right, left, vi) return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end -function tilde_observe(context::MiniBatchContext, right, left, vname, vi) - return context.loglike_scalar * tilde_observe(context.context, right, left, vname, vi) +function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) + return context.loglike_scalar * tilde_observe(context.context, sampler, right, left, vname, vi) end # `PrefixContext` -function tilde_observe(context::PrefixContext, right, left, vname, vi) - return tilde_observe(context.context, right, left, prefix(context, vname), vi) +function tilde_observe(context::PrefixContext, right, left, vi) + return tilde_observe(context.context, right, left, vi) end """ From cefb443c1d3a9c5c389c198024862e8905b8c1e3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 1 Aug 2021 08:25:07 +0100 Subject: [PATCH 60/86] added docstring to observations --- src/model.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/model.jl b/src/model.jl index 105878a61..3495b6a05 100644 --- a/src/model.jl +++ b/src/model.jl @@ -94,6 +94,11 @@ function decondition(model::Model, syms...) return contextualize(model, decondition(model.context, syms...)) end +""" + observations(model::Model) + +Alias for [`conditioned`](@ref). +""" observations(model::Model) = conditioned(model) conditioned(model::Model) = conditioned(model.context) From 6086734bba15452cd8cc17211565e438308d9997 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Mon, 2 Aug 2021 18:38:39 +0100 Subject: [PATCH 61/86] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 3 ++- src/contexts.jl | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 83f57fe70..54a198acc 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -189,7 +189,8 @@ function tilde_observe(context::MiniBatchContext, right, left, vi) return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * tilde_observe(context.context, sampler, right, left, vname, vi) + return context.loglike_scalar * + tilde_observe(context.context, sampler, right, left, vname, vi) end # `PrefixContext` diff --git a/src/contexts.jl b/src/contexts.jl index 409f666d1..ca072453b 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -460,7 +460,7 @@ Any[] ``` """ function conditioned(context::AbstractContext) - conditioned(NodeTrait(conditioned, context), context) + return conditioned(NodeTrait(conditioned, context), context) end conditioned(::IsLeaf, context) = () conditioned(::IsParent, context) = conditioned(childcontext(context)) From e57902001bf139fab6378dec70fe9c9fe6907868 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 2 Aug 2021 20:38:55 +0100 Subject: [PATCH 62/86] Apply suggestions from code review Co-authored-by: Hong Ge --- src/DynamicPPL.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index f054517bb..7d7a14584 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -68,7 +68,6 @@ export AbstractVarInfo, vectorize, # Model Model, - ContextualModel, getmissings, getargnames, generated_quantities, From 7c9dc5e9b7dd4ff16bfdf42b08d4cd5c35345dae Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 02:20:15 +0100 Subject: [PATCH 63/86] make NodeTrait an abstract type --- src/contexts.jl | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 4bda07009..9b8f7ac07 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,7 +1,5 @@ # Fallback traits # TODO: Should this instead be `NoChildren()`, `HasChild()`, etc. so we allow plural too, e.g. `HasChildren()`? -struct IsLeaf end -struct IsParent end """ NodeTrait(context) @@ -16,8 +14,22 @@ The officially supported traits are: - [`childcontext`](@ref) - [`setchildcontext`](@ref) """ +abstract type NodeTrait end NodeTrait(_, context) = NodeTrait(context) +""" + IsLeaf + +Specifies that the context is a leaf in the context-tree. +""" +struct IsLeaf <: NodeTrait end +""" + IsParent + +Specifies that the context is a parent in the context-tree. +""" +struct IsParent <: NodeTrait end + """ childcontext(context) From e67089f4ca262ca34d948947020eaf2b50a872e3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 02:21:28 +0100 Subject: [PATCH 64/86] make matchingvalue work nicely with contexts --- src/compiler.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/compiler.jl b/src/compiler.jl index 6de2f0945..d344717aa 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -501,8 +501,14 @@ function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType) end function matchingvalue(context::AbstractContext, vi, value) + return matchingvalue(NodeTrait(matchingvalue, context), context, vi, value) +end +function matchingvalue(::IsLeaf, context::AbstractContext, vi, value) return matchingvalue(SampleFromPrior(), vi, value) end +function matchingvalue(::IsParent, context::AbstractContext, vi, value) + return matchingvalue(childcontext(context), vi, value) +end function matchingvalue(context::SamplingContext, vi, value) return matchingvalue(context.sampler, vi, value) end From 9479ce474afe5e52ef3e1909bf8a974c58db4faf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 02:21:40 +0100 Subject: [PATCH 65/86] added a bunch of tests for the new trait system for contexts --- test/contexts.jl | 94 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/test/contexts.jl b/test/contexts.jl index d9bcd2ef9..ae1ef100b 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,4 +1,98 @@ +using Test, DynamicPPL +using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, NodeTrait, IsLeaf, IsParent, PointwiseLikelihoodContext + +struct ParentContext{C<:AbstractContext} <: AbstractContext + context::C +end +ParentContext() = ParentContext(DefaultContext()) +DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(context::ParentContext) = context.context +DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) +Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") + @testset "contexts.jl" begin + child_contexts = [ + DefaultContext(), + PriorContext(), + LikelihoodContext(), + ] + + parent_contexts = [ + ParentContext(DefaultContext()), + SamplingContext(), + MiniBatchContext(DefaultContext(), 0.0), + PrefixContext{:x}(DefaultContext()), + PointwiseLikelihoodContext() + ] + + contexts = vcat(child_contexts, parent_contexts) + + @testset "NodeTrait" begin + @testset "$context" for context in contexts + # Every `context` should have a `NodeTrait`. + @test NodeTrait(context) isa NodeTrait + end + end + + @testset "leafcontext" begin + @testset "$context" for context in child_contexts + @test leafcontext(context) === context + end + + @testset "$context" for context in parent_contexts + @test NodeTrait(leafcontext(context)) isa IsLeaf + end + end + + @testset "setleafcontext" begin + @testset "$context" for context in child_contexts + # Setting to itself should return itself. + @test setleafcontext(context, context) === context + + # Setting to a different context should return that context. + new_leaf = context isa DefaultContext ? PriorContext() : DefaultContext() + @test setleafcontext(context, new_leaf) === new_leaf + + # Also works for parent contexts. + new_leaf = ParentContext(context) + @test setleafcontext(context, new_leaf) === new_leaf + end + + @testset "$context" for context in parent_contexts + # Leaf contexts. + new_leaf = leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() + @test leafcontext(setleafcontext(context, new_leaf)) === new_leaf + + # Setting parent contexts as "leaf" means that the new leaf should be + # the leaf of the parent context we just set as the leaf. + new_leaf = ParentContext((leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext())) + @test leafcontext(setleafcontext(context, new_leaf)) === leafcontext(new_leaf) + end + end + + # `IsParent` interface. + @testset "childcontext" begin + @testset "$context" for context in parent_contexts + @test childcontext(context) isa AbstractContext + end + end + + @testset "setchildcontext" begin + @testset "nested contexts" begin + # Both of the following should result in the same context. + context1 = ParentContext(ParentContext(ParentContext())) + context2 = setchildcontext(ParentContext(), setchildcontext(ParentContext(), ParentContext())) + @test context1 === context2 + end + + @testset "$context" for context in parent_contexts + # Setting the child context to a leaf should now change the `leafcontext` accordingly. + new_leaf = leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() + new_context = setchildcontext(context, new_leaf) + @test childcontext(new_context) === leafcontext(new_context) === new_leaf + end + end + @testset "PrefixContext" begin ctx = @inferred PrefixContext{:f}( PrefixContext{:e}( From 1315d886d56328f52c852410803d5da859d2e8ac Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 02:22:39 +0100 Subject: [PATCH 66/86] formatting --- test/contexts.jl | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/test/contexts.jl b/test/contexts.jl index ae1ef100b..7793dfc74 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,5 +1,14 @@ using Test, DynamicPPL -using DynamicPPL: leafcontext, setleafcontext, childcontext, setchildcontext, AbstractContext, NodeTrait, IsLeaf, IsParent, PointwiseLikelihoodContext +using DynamicPPL: + leafcontext, + setleafcontext, + childcontext, + setchildcontext, + AbstractContext, + NodeTrait, + IsLeaf, + IsParent, + PointwiseLikelihoodContext struct ParentContext{C<:AbstractContext} <: AbstractContext context::C @@ -11,18 +20,14 @@ DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") @testset "contexts.jl" begin - child_contexts = [ - DefaultContext(), - PriorContext(), - LikelihoodContext(), - ] + child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] parent_contexts = [ ParentContext(DefaultContext()), SamplingContext(), MiniBatchContext(DefaultContext(), 0.0), PrefixContext{:x}(DefaultContext()), - PointwiseLikelihoodContext() + PointwiseLikelihoodContext(), ] contexts = vcat(child_contexts, parent_contexts) @@ -60,18 +65,21 @@ Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c @testset "$context" for context in parent_contexts # Leaf contexts. - new_leaf = leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() + new_leaf = + leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() @test leafcontext(setleafcontext(context, new_leaf)) === new_leaf # Setting parent contexts as "leaf" means that the new leaf should be # the leaf of the parent context we just set as the leaf. - new_leaf = ParentContext((leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext())) + new_leaf = ParentContext(( + leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() + )) @test leafcontext(setleafcontext(context, new_leaf)) === leafcontext(new_leaf) end end # `IsParent` interface. - @testset "childcontext" begin + @testset "childcontext" begin @testset "$context" for context in parent_contexts @test childcontext(context) isa AbstractContext end @@ -81,13 +89,16 @@ Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c @testset "nested contexts" begin # Both of the following should result in the same context. context1 = ParentContext(ParentContext(ParentContext())) - context2 = setchildcontext(ParentContext(), setchildcontext(ParentContext(), ParentContext())) + context2 = setchildcontext( + ParentContext(), setchildcontext(ParentContext(), ParentContext()) + ) @test context1 === context2 end @testset "$context" for context in parent_contexts # Setting the child context to a leaf should now change the `leafcontext` accordingly. - new_leaf = leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() + new_leaf = + leafcontext(context) isa DefaultContext ? PriorContext() : DefaultContext() new_context = setchildcontext(context, new_leaf) @test childcontext(new_context) === leafcontext(new_context) === new_leaf end From 1d3fa2b2a269af89fb3df59f6ac7be8b758468a5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 02:27:52 +0100 Subject: [PATCH 67/86] dont export contextualize, observations and conditioned --- src/DynamicPPL.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 7d7a14584..ac2734b47 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -103,9 +103,6 @@ export AbstractVarInfo, pointwise_loglikelihoods, condition, decondition, - contextualize, - observations, - conditioned, # Convenience macros @addlogprob!, @submodel From da34dff45a553ee7b8cfcc56589bef67a8034732 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 05:01:08 +0100 Subject: [PATCH 68/86] added getvalue_nested and hasvalue_nested to be more explicit --- src/compiler.jl | 4 +-- src/contexts.jl | 76 +++++++++++++++++++++++++++++++++---------------- 2 files changed, 53 insertions(+), 27 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7a1556419..e55caac42 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -400,7 +400,7 @@ function generate_tilde(left, right) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) - $left = $(DynamicPPL.getvalue)(__context__, $vn) + $left = $(DynamicPPL.getvalue_nested)(__context__, $vn) end $(DynamicPPL.tilde_observe!)( @@ -449,7 +449,7 @@ function generate_dot_tilde(left, right) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) - $left .= $(DynamicPPL.getvalue)(__context__, $vn) + $left .= $(DynamicPPL.getvalue_nested)(__context__, $vn) end $(DynamicPPL.dot_tilde_observe!)( diff --git a/src/contexts.jl b/src/contexts.jl index d0b1230b0..884529b36 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -288,43 +288,69 @@ childcontext(context::ConditionContext) = context.context setchildcontext(parent::ConditionContext, child) = ConditionContext(parent.values, child) """ - getvalue(context, vn) + hasvalue(context, vn) -Return the value of the parameter corresponding to `vn` from `context`. -If `context` does not contain the value for `vn`, then `nothing` is returned, -e.g. [`DefaultContext`](@ref) will always return `nothing`. +Return `true` if `vn` is found in `context`. """ -getvalue(::IsLeaf, context, vn) = nothing -getvalue(::IsParent, context, vn) = getvalue(childcontext(context), vn) -getvalue(context::AbstractContext, vn) = getvalue(NodeTrait(getvalue, context), context, vn) -getvalue(context::PrefixContext, vn) = getvalue(childcontext(context), prefix(context, vn)) - -function getvalue(context::ConditionContext, vn) - return if hasvalue(context, vn) - _getvalue(context.values, vn) - else - getvalue(childcontext(context), vn) - end -end +hasvalue(context, vn) = false -# General implementations of `haskey`. -hasvalue(::IsLeaf, context, vn) = false -hasvalue(::IsParent, context, vn) = hasvalue(childcontext(context), vn) -hasvalue(context::AbstractContext, vn) = hasvalue(NodeTrait(context), context, vn) - -# Specific to `ConditionContext`. function hasvalue(context::ConditionContext{vars}, vn::VarName{sym}) where {vars,sym} - # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars end - function hasvalue( context::ConditionContext{vars}, vn::AbstractArray{<:VarName{sym}} ) where {vars,sym} - # TODO: Add possibility of indexed variables, e.g. `x[1]`, etc. return sym in vars end +""" + getvalue(context, vn) + +Return value of `vn` in `context`. +""" +function getvalue(context::AbstractContext, vn) + return error("context $(context) does not contain value for $vn") +end +getvalue(context::ConditionContext, vn) = _getvalue(context.values, vn) + +""" + hasvalue_nested(context, vn) + +Return `true` if `vn` is found in `context` or any of its descendants. +""" +function hasvalue_nested(context::AbstractContext, vn) + return hasvalue_nested(NodeTrait(hasvalue_nested, context), context, vn) +end +hasvalue_nested(::IsLeaf, context, vn) = hasvalue(context, vn) +function hasvalue_nested(::IsParent, context, vn) + return hasvalue(context, vn) || hasvalue_nested(childcontext(context), vn) +end +function hasvalue_nested(context::PrefixContext, vn) + return hasvalue_nested(childcontext(context), prefix(context, vn)) +end + +""" + getvalue_nested(context, vn) + +Return the value of the parameter corresponding to `vn` from `context` or its descendants. +""" +function getvalue_nested(context::AbstractContext, vn) + return getvalue_nested(NodeTrait(getvalue_nested, context), context, vn) +end +function getvalue_nested(::IsLeaf, context, vn) + return error("context $(context) does not contain value for $vn") +end +function getvalue_nested(context::PrefixContext, vn) + return getvalue_nested(childcontext(context), prefix(context, vn)) +end +function getvalue_nested(::IsParent, context, vn) + return if hasvalue(context, vn) + getvalue(context, vn) + else + getvalue_nested(childcontext(context), vn) + end +end + """ context([context::AbstractContext,] values::NamedTuple) context([context::AbstractContext]; values...) From 5e15b25a1cb1c53cf64a53978f37ee7c94121b9e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 05:01:49 +0100 Subject: [PATCH 69/86] added a substantial amount of testing for ConditionContext --- src/model.jl | 1 - test/contexts.jl | 125 ++++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 123 insertions(+), 3 deletions(-) diff --git a/src/model.jl b/src/model.jl index 3495b6a05..9f163910d 100644 --- a/src/model.jl +++ b/src/model.jl @@ -102,7 +102,6 @@ Alias for [`conditioned`](@ref). observations(model::Model) = conditioned(model) conditioned(model::Model) = conditioned(model.context) - """ (model::Model)([rng, varinfo, sampler, context]) diff --git a/test/contexts.jl b/test/contexts.jl index 7793dfc74..5f9f511b9 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -8,8 +8,15 @@ using DynamicPPL: NodeTrait, IsLeaf, IsParent, - PointwiseLikelihoodContext - + PointwiseLikelihoodContext, + contextual_isassumption, + ConditionContext, + hasvalue, + getvalue, + hasvalue_nested, + getvalue_nested + +# Dummy context to test nested behaviors. struct ParentContext{C<:AbstractContext} <: AbstractContext context::C end @@ -19,6 +26,50 @@ DynamicPPL.childcontext(context::ParentContext) = context.context DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") +# TODO: Should we maybe put this in DPPL itself? +function Base.iterate(context::AbstractContext) + if NodeTrait(context) isa IsLeaf + return nothing + end + + return context, context +end +function Base.iterate(_::AbstractContext, context::AbstractContext) + return _iterate(NodeTrait(context), context) +end +_iterate(::IsLeaf, context) = nothing +function _iterate(::IsParent, context) + child = childcontext(context) + return child, child +end + +Base.IteratorSize(::Type{<:AbstractContext}) = Base.SizeUnknown() +Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() + +""" + remove_prefix(vn::VarName) + +Return `vn` but now with the prefix removed. +""" +remove_prefix(vn::VarName) = VarName{Symbol(split(string(vn), ".")[end])}(vn.indexing) + +""" + varnames(vn::VarName, val) + +Return iterator over all varnames that are represented by `vn` on `val`, +e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. +""" +varnames(vn::VarName, val::Real) = [vn] +function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) + return (VarName(vn, (vn.indexing..., Tuple(I))) for I in CartesianIndices(val)) +end +function varnames(vn::VarName, val::AbstractArray) + return Iterators.flatten( + varnames(VarName(vn, (vn.indexing..., Tuple(I))), val[I]) for + I in CartesianIndices(val) + ) +end + @testset "contexts.jl" begin child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] @@ -28,6 +79,10 @@ Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c MiniBatchContext(DefaultContext(), 0.0), PrefixContext{:x}(DefaultContext()), PointwiseLikelihoodContext(), + ConditionContext((x=1.0,)), + ConditionContext((x=1.0,), ParentContext(ConditionContext((y=2.0,)))), + ConditionContext((x=1.0,), PrefixContext{:a}(ConditionContext((var"a.y"=2.0,)))), + ConditionContext((x=[1.0, missing],)), ] contexts = vcat(child_contexts, parent_contexts) @@ -104,6 +159,72 @@ Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c end end + @testset "contextual_isassumption" begin + @testset "$context" for context in contexts + # Any `context` should return `true` by default. + @test contextual_isassumption(context, VarName{gensym(:x)}()) + + if any(Base.Fix2(isa, ConditionContext), context) + # We have a `ConditionContext` among us. + # Let's first extract the conditioned variables. + conditioned_values = DynamicPPL.conditioned(context) + + for (sym, val) in pairs(conditioned_values) + vn = VarName{sym}() + + # We need to drop the prefix of `var` since in `contextual_isassumption` + # it will be threaded through the `PrefixContext` before it reaches + # `ConditionContext` with the conditioned variable. + vn_without_prefix = remove_prefix(vn) + + # Let's check elementwise. + for vn_child in varnames(vn_without_prefix, val) + if DynamicPPL._getindex(val, vn_child.indexing) === missing + @test contextual_isassumption(context, vn_child) + else + @test !contextual_isassumption(context, vn_child) + end + end + end + end + end + end + + @testset "getvalue_nested & hasvalue_nested" begin + @testset "$context" for context in contexts + fake_vn = VarName{gensym(:x)}() + @test !hasvalue_nested(context, fake_vn) + @test_throws ErrorException getvalue_nested(context, fake_vn) + + if any(Base.Fix2(isa, ConditionContext), context) + # `ConditionContext` specific. + + # Let's first extract the conditioned variables. + conditioned_values = DynamicPPL.conditioned(context) + + for (sym, val) in pairs(conditioned_values) + vn = VarName{sym}() + + # We need to drop the prefix of `var` since in `contextual_isassumption` + # it will be threaded through the `PrefixContext` before it reaches + # `ConditionContext` with the conditioned variable. + vn_without_prefix = remove_prefix(vn) + + for vn_child in varnames(vn_without_prefix, val) + # `vn_child` should be in `context`. + @test hasvalue_nested(context, vn_child) + if !hasvalue_nested(context, vn_child) + @info "" context vn_child + end + # Value should be the same as extracted above. + @test getvalue_nested(context, vn_child) === + DynamicPPL._getindex(val, vn_child.indexing) + end + end + end + end + end + @testset "PrefixContext" begin ctx = @inferred PrefixContext{:f}( PrefixContext{:e}( From 386e9852e4d5165082dd92d1801fb35df4e58470 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 05:02:34 +0100 Subject: [PATCH 70/86] removed some debugging code --- test/contexts.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/contexts.jl b/test/contexts.jl index 5f9f511b9..fa96a0ec4 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -213,9 +213,6 @@ end for vn_child in varnames(vn_without_prefix, val) # `vn_child` should be in `context`. @test hasvalue_nested(context, vn_child) - if !hasvalue_nested(context, vn_child) - @info "" context vn_child - end # Value should be the same as extracted above. @test getvalue_nested(context, vn_child) === DynamicPPL._getindex(val, vn_child.indexing) From 52c9f044fe9b94cfaee52865d6e6c3a2c29dda48 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 05:18:18 +0100 Subject: [PATCH 71/86] use maybe_view in isassumption check --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index e55caac42..7880b9f1c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -38,7 +38,7 @@ function isassumption(expr::Union{Symbol,Expr}) $(DynamicPPL.inmissings)($vn, __model__) true else - $expr === missing + $(maybe_view(expr)) === missing end else false From 65e7f7128bb9b0ce66888e714d85725448dde3ee Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 4 Aug 2021 05:20:29 +0100 Subject: [PATCH 72/86] added more extensive docstring for getvalue_nested and hasvalue_nested --- src/contexts.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/contexts.jl b/src/contexts.jl index 884529b36..55b070c6d 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -317,6 +317,9 @@ getvalue(context::ConditionContext, vn) = _getvalue(context.values, vn) hasvalue_nested(context, vn) Return `true` if `vn` is found in `context` or any of its descendants. + +This is contrast to [`hasvalue`](@ref) which only checks for `vn` in `context`, +not recursively checking if `vn` is in any of its descendants. """ function hasvalue_nested(context::AbstractContext, vn) return hasvalue_nested(NodeTrait(hasvalue_nested, context), context, vn) @@ -333,6 +336,9 @@ end getvalue_nested(context, vn) Return the value of the parameter corresponding to `vn` from `context` or its descendants. + +This is contrast to [`getvalue`](@ref) which only returns the value `vn` in `context`, +not recursively looking into its descendants. """ function getvalue_nested(context::AbstractContext, vn) return getvalue_nested(NodeTrait(getvalue_nested, context), context, vn) From 0054ffb65e60b6fd67b87d3e17921de5e3515ba6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 12 Aug 2021 04:19:10 +0100 Subject: [PATCH 73/86] some minor style changes --- src/contexts.jl | 4 ++-- src/model.jl | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 55b070c6d..ea31592a5 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -366,11 +366,11 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default. See also: [`decondition`](@ref) """ -condition(; values...) = condition(DefaultContext(), (; values...)) +condition(; values...) = condition(DefaultContext(), NamedTuple(values)) condition(values::NamedTuple) = condition(DefaultContext(), values) condition(context::AbstractContext, values::NamedTuple{()}) = context condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context) -condition(context::AbstractContext; values...) = condition(context, (; values...)) +condition(context::AbstractContext; values...) = condition(context, NamedTuple(values)) """ decondition(context::AbstractContext, syms...) diff --git a/src/model.jl b/src/model.jl index 9f163910d..5c3e555f4 100644 --- a/src/model.jl +++ b/src/model.jl @@ -85,7 +85,7 @@ end Base.:|(model::Model, values) = condition(model, values) -condition(model::Model; values...) = condition(model, (; values...)) +condition(model::Model; values...) = condition(model, NamedTuple(values)) function condition(model::Model, values) return contextualize(model, condition(model.context, values)) end From 14a94f0f47007830c0266ca12bcbbc7ed112137f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 12 Aug 2021 05:20:54 +0100 Subject: [PATCH 74/86] added docs and doctests for condition and decondition for model --- src/contexts.jl | 103 ++----------------- src/model.jl | 262 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 268 insertions(+), 97 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index ea31592a5..98eb4b85d 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -358,8 +358,8 @@ function getvalue_nested(::IsParent, context, vn) end """ - context([context::AbstractContext,] values::NamedTuple) - context([context::AbstractContext]; values...) + condition([context::AbstractContext,] values::NamedTuple) + condition([context::AbstractContext]; values...) Return `ConditionContext` with `values` and `context` if `values` is non-empty, otherwise return `context` which is [`DefaultContext`](@ref) by default. @@ -371,6 +371,7 @@ condition(values::NamedTuple) = condition(DefaultContext(), values) condition(context::AbstractContext, values::NamedTuple{()}) = context condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context) condition(context::AbstractContext; values...) = condition(context, NamedTuple(values)) + """ decondition(context::AbstractContext, syms...) @@ -379,64 +380,6 @@ Return `context` but with `syms` no longer conditioned on. Note that this recursively traverses contexts, deconditioning all along the way. See also: [`condition`](@ref) - -# Examples -```jldoctest -julia> using DynamicPPL: AbstractContext, leafcontext, setleafcontext, childcontext, setchildcontext - -julia> struct ParentContext{C} <: AbstractContext - context::C - end - -julia> DynamicPPL.NodeTrait(::ParentContext) = DynamicPPL.IsParent() - -julia> DynamicPPL.childcontext(context::ParentContext) = context.context - -julia> DynamicPPL.setchildcontext(::ParentContext, child) = ParentContext(child) - -julia> Base.show(io::IO, c::ParentContext) = print(io, "ParentContext(", childcontext(c), ")") - -julia> ctx = DefaultContext() -DefaultContext() - -julia> decondition(ctx) === ctx # this is a no-op -true - -julia> ctx = condition(x = 1.0) # default "constructor" for `ConditionContext` -ConditionContext((x = 1.0,), DefaultContext()) - -julia> decondition(ctx) # `decondition` without arguments drops all conditioning -DefaultContext() - -julia> # Nested conditioning is supported. - ctx_nested = condition(ParentContext(condition(y=2.0)), x=1.0) -ConditionContext((x = 1.0,), ParentContext(ConditionContext((y = 2.0,), DefaultContext()))) - -julia> # We can also specify which variables to drop. - decondition(ctx_nested, :x) -ParentContext(ConditionContext((y = 2.0,), DefaultContext())) - -julia> # No matter the nested level. - decondition(ctx_nested, :y) -ConditionContext((x = 1.0,), ParentContext(DefaultContext())) - -julia> # Or specify multiple at in one call. - decondition(ctx_nested, :x, :y) -ParentContext(DefaultContext()) - -julia> decondition(ctx_nested) -ParentContext(DefaultContext()) - -julia> # `Val` is also supported. - decondition(ctx_nested, Val(:x)) -ParentContext(ConditionContext((y = 2.0,), DefaultContext())) - -julia> decondition(ctx_nested, Val(:y)) -ConditionContext((x = 1.0,), ParentContext(DefaultContext())) - -julia> decondition(ctx_nested, Val(:x), Val(:y)) -ParentContext(DefaultContext()) -``` """ decondition(::IsLeaf, context, args...) = context function decondition(::IsParent, context, args...) @@ -462,46 +405,12 @@ function decondition(context::ConditionContext, sym, syms...) end """ - conditioned(model::Model) conditioned(context::AbstractContext) -Return `NamedTuple` of values that are conditioned on under `model`/`context`. +Return `NamedTuple` of values that are conditioned on under context`. -# Examples -```jldoctest -julia> @model function demo() - m ~ Normal() - x ~ Normal(m, 1) - end -demo (generic function with 1 methods) - -julia> m = demo(); - -julia> # Returns all the variables we have conditioned on + their values. - conditioned(condition(m, x=100.0, m=1.0)) -(x = 100.0, m = 1.0) - -julia> # Nested ones also work (note that `PrefixContext` does nothing to the result). - cm = condition(contextualize(m, PrefixContext{:a}(condition(m=1.0))), x=100.0); - -julia> conditioned(cm) -(x = 100.0, m = 1.0) - -julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, - # `a.m` is treated as a random variable. - keys(VarInfo(cm)) -1-element Vector{VarName{Symbol("a.m"), Tuple{}}}: - a.m - -julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation. - cm = condition(contextualize(m, PrefixContext{:a}(condition(var"a.m"=1.0))), x=100.0); - -julia> conditioned(cm) -(x = 100.0, a.m = 1.0) - -julia> keys(VarInfo(cm)) # <= no variables are sampled -Any[] -``` +Note that this will recursively traverse the context stack and return +a merged version of the condition values. """ function conditioned(context::AbstractContext) return conditioned(NodeTrait(conditioned, context), context) diff --git a/src/model.jl b/src/model.jl index 5c3e555f4..b7f3c6804 100644 --- a/src/model.jl +++ b/src/model.jl @@ -83,13 +83,233 @@ function contextualize(model::Model, context::AbstractContext) return Model(model.name, model.f, model.args, model.defaults, context) end +""" + model | (x = 1.0, ...) + +Return a `Model` which now treats variables on the right-hand side as observations. + +See [`condition`](@ref) for more information and examples. +""" Base.:|(model::Model, values) = condition(model, values) +""" + condition(model::Model; values...) + condition(model::Model, values::NamedTuple) + +Return a `Model` which now treats the variables in `values` as observations. + +See also: [`decondition`](@ref), [`conditioned`](@ref) + +# Limitations + +This does currently _not_ work with variables that are +provided to the model as arguments, e.g. `@model function demo(x) ... end` +means that `condition` will not affect the variable `x`. + +Therefore if one wants to make use of `condition` and [`decondition`](@ref) +one should not be specifying any random variables as arguments. + +This is done for the sake of backwards compatibility. + +# Examples +## Simple univariate model +```jldoctest condition +julia> using Distributions; using StableRNGs; rng = StableRNG(42); # For reproducibility. + +julia> @model function demo() + m ~ Normal() + x ~ Normal(m, 1) + + return (; m, x) + end +demo (generic function with 1 method) + +julia> model = demo(); + +julia> model(rng) +(m = -0.6702516921145671, x = -0.22312984965118443) + +julia> # Create a new instance which treats `x` as observed + # with value `100.0`, and similarly for `m=1.0`. + conditioned_model = condition(model, x=100.0, m=1.0); + +julia> conditioned_model(rng) +(m = 1.0, x = 100.0) + +julia> # Let's only condition on `x = 100.0`. + conditioned_model = condition(model, x = 100.0); + +julia> conditioned_model(rng) +(m = 1.3736306979834252, x = 100.0) + +julia> # We can also use the nicer `|` syntax. + conditioned_model = model | (x = 100.0, ); + +julia> conditioned_model(rng) +(m = 1.3095394956381083, x = 100.0) +``` + +## Condition only a part of a multivariate variable + +Not only can be condition on multivariate random variables, but +we can also use the standard mechanism of setting something to `missing` +in the call to `condition` to only condition on a part of the variable. + +```jldoctest condition +julia> @model function demo_mv(::Type{TV}=Float64) where {TV} + m = Vector{TV}(undef, 2) + m[1] ~ Normal() + m[2] ~ Normal() + + return m + end +demo_mv (generic function with 2 methods) + +julia> model = demo_mv(); + +julia> conditioned_model = condition(model, m = [missing, 1.0]); + +julia> conditioned_model(rng) # (✓) `m[1]` sampled, `m[2]` is fixed +2-element Vector{Float64}: + 0.12607002180931043 + 1.0 +``` + +Intuitively one might also expect to be able to write `model | (x[1] = 1.0, )`. +Unfortunately this is not supported due to performance. + +```jldoctest condition +julia> condition(model, var"x[2]" = 1.0)(rng) # (×) `x[2]` is not set to 1.0. +2-element Vector{Float64}: + 0.683947930996541 + -1.019202452456547 +``` + +We will likely provide some syntactic sugar for this in the future. + +## Nested models + +`condition` of course also supports the use of nested models through +the use of [`@submodel`](@ref). + +```jldoctest condition +julia> @model demo_inner() = m ~ Normal() +demo_inner (generic function with 1 method) + +julia> @model function demo_outer() + m = @submodel demo_inner() + return m + end +demo_outer (generic function with 1 method) + +julia> model = demo_outer(); + +julia> model(rng) +0.683947930996541 + +julia> conditioned_model = model | (m = 1.0, ); + +julia> conditioned_model(rng) +1.0 +``` + +But one needs to be careful when prefixing variables in the nested models: + +```jldoctest condition +julia> @model function demo_outer_prefix() + m = @submodel inner demo_inner() + return m + end +demo_outer_prefix (generic function with 1 method) + +julia> # This doesn't work now! + conditioned_model = demo_outer_prefix() | (m = 1.0, ); + +julia> conditioned_model(rng) +-1.019202452456547 + +julia> # `m` in `demo_inner` is referred to as `inner.m` internally, so we do: + conditioned_model = demo_outer_prefix() | (var"inner.m" = 1.0, ); + +julia> conditioned_model(rng) +1.0 + +julia> # Note that the above `var"..."` is just standard Julia syntax: + typeof((var"inner.m" = 1.0, )) +NamedTuple{(Symbol("inner.m"),), Tuple{Float64}} +``` + +The difference is maybe more obvious once we look at how these different +in their trace/`VarInfo`: + +```jldoctest condition +julia> keys(VarInfo(demo_outer())) +1-element Vector{VarName{:m, Tuple{}}}: + m + +julia> keys(VarInfo(demo_outer_prefix())) +1-element Vector{VarName{Symbol("inner.m"), Tuple{}}}: + inner.m +``` + +From this we can tell what the correct way to condition `m` within `demo_inner` +is in the two different models. + +""" condition(model::Model; values...) = condition(model, NamedTuple(values)) function condition(model::Model, values) return contextualize(model, condition(model.context, values)) end +""" + decondition(model::Model) + decondition(model::Model, syms...) + +Return a `Model` for which `syms...` are _not_ considered observations. +If no `syms` are provided, then all variables currently considered observations +will no longer be. + +This is essentially the inverse of [`condition`](@ref). This also means that +it suffers from the same limitiations. + +# Examples +```jldoctest +julia> using Distributions; using StableRNGs; rng = StableRNG(42); # For reproducibility. + +julia> @model function demo() + m ~ Normal() + x ~ Normal(m, 1) + + return (; m, x) + end +demo (generic function with 1 method) + +julia> conditioned_model = condition(demo(), m = 1.0, x = 10.0); + +julia> conditioned_model(rng) +(m = 1.0, x = 10.0) + +julia> model = decondition(conditioned_model, :m); + +julia> model(rng) +(m = -0.6702516921145671, x = 10.0) + +julia> # `decondition` multiple at once: + decondition(model, :m, :x)(rng) +(m = 0.4471218424633827, x = 1.820752540446808) + +julia> # `decondition` without any symbols will `decondition` all variables. + decondition(model)(rng) +(m = 1.3095394956381083, x = 1.4356095174474188) + +julia> # Usage of `Val` to perform `decondition` at compile-time if possible + # is also supported. + model = decondition(conditioned_model, Val{:m}()); + +julia> model(rng) +(m = 0.683947930996541, x = 10.0) +``` +""" function decondition(model::Model, syms...) return contextualize(model, decondition(model.context, syms...)) end @@ -100,6 +320,48 @@ end Alias for [`conditioned`](@ref). """ observations(model::Model) = conditioned(model) + +""" + conditioned(model::Model) + +Return `NamedTuple` of values that are conditioned on under `model`. + +# Examples +```jldoctest +julia> @model function demo() + m ~ Normal() + x ~ Normal(m, 1) + end +demo (generic function with 1 methods) + +julia> m = demo(); + +julia> # Returns all the variables we have conditioned on + their values. + conditioned(condition(m, x=100.0, m=1.0)) +(x = 100.0, m = 1.0) + +julia> # Nested ones also work (note that `PrefixContext` does nothing to the result). + cm = condition(contextualize(m, PrefixContext{:a}(condition(m=1.0))), x=100.0); + +julia> conditioned(cm) +(x = 100.0, m = 1.0) + +julia> # Since we conditioned on `m`, not `a.m` as it will appear after prefixed, + # `a.m` is treated as a random variable. + keys(VarInfo(cm)) +1-element Vector{VarName{Symbol("a.m"), Tuple{}}}: + a.m + +julia> # If we instead condition on `a.m`, `m` in the model will be considered an observation. + cm = condition(contextualize(m, PrefixContext{:a}(condition(var"a.m"=1.0))), x=100.0); + +julia> conditioned(cm) +(x = 100.0, a.m = 1.0) + +julia> keys(VarInfo(cm)) # <= no variables are sampled +Any[] +``` +""" conditioned(model::Model) = conditioned(model.context) """ From 6e563cfdc9fb2bee7375a7b20ceb8cb0c6cbab88 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 12 Aug 2021 05:21:58 +0100 Subject: [PATCH 75/86] rephrased a comment --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 7880b9f1c..2b0b8fe30 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -64,7 +64,7 @@ end function contextual_isassumption(context::ConditionContext, vn) if hasvalue(context, vn) val = getvalue(context, vn) - # TODO: Do we even need the `>: Missing` to help the compiler? + # TODO: Do we even need the `>: Missing`, i.e. does it even help the compiler? if eltype(val) >: Missing && val === missing return true else From 801138182e46fb747ff88901f593583477f666c2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 12 Aug 2021 05:35:57 +0100 Subject: [PATCH 76/86] improvement to remove_prefix in tests --- test/contexts.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/contexts.jl b/test/contexts.jl index fa96a0ec4..da5abf53f 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -51,7 +51,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() Return `vn` but now with the prefix removed. """ -remove_prefix(vn::VarName) = VarName{Symbol(split(string(vn), ".")[end])}(vn.indexing) +remove_prefix(vn::VarName) = VarName{Symbol(split(string(vn), DynamicPPL.PREFIX_SEPARATOR)[end])}(vn.indexing) """ varnames(vn::VarName, val) From ac5c291d3ad7158f0b2dc2d7953bafd5ec2525b9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 12 Aug 2021 05:38:50 +0100 Subject: [PATCH 77/86] formatting --- test/contexts.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/contexts.jl b/test/contexts.jl index da5abf53f..5851d75ca 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -51,7 +51,9 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() Return `vn` but now with the prefix removed. """ -remove_prefix(vn::VarName) = VarName{Symbol(split(string(vn), DynamicPPL.PREFIX_SEPARATOR)[end])}(vn.indexing) +function remove_prefix(vn::VarName) + return VarName{Symbol(split(string(vn), DynamicPPL.PREFIX_SEPARATOR)[end])}(vn.indexing) +end """ varnames(vn::VarName, val) From 297bf929f34805cb2a51cc1af45424903e6b0219 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 12 Aug 2021 05:43:30 +0100 Subject: [PATCH 78/86] fixing some tests --- src/model.jl | 2 ++ test/contexts.jl | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index b7f3c6804..8f2fae998 100644 --- a/src/model.jl +++ b/src/model.jl @@ -328,6 +328,8 @@ Return `NamedTuple` of values that are conditioned on under `model`. # Examples ```jldoctest +julia> using DynamicPPL: conditioned + julia> @model function demo() m ~ Normal() x ~ Normal(m, 1) diff --git a/test/contexts.jl b/test/contexts.jl index 5851d75ca..ca08e4e97 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -52,7 +52,7 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() Return `vn` but now with the prefix removed. """ function remove_prefix(vn::VarName) - return VarName{Symbol(split(string(vn), DynamicPPL.PREFIX_SEPARATOR)[end])}(vn.indexing) + return VarName{Symbol(split(string(vn), string(DynamicPPL.PREFIX_SEPARATOR))[end])}(vn.indexing) end """ From bab9e19c99394dccdace9bb31da3c1902c591d03 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 12 Aug 2021 06:14:22 +0100 Subject: [PATCH 79/86] fixed bug in tilde_observe for prefix context --- src/context_implementations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index c278a8824..834680d39 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -192,8 +192,8 @@ function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) end # `PrefixContext` -function tilde_observe(context::PrefixContext, right, left, vname, vi) - return tilde_observe(context.context, right, left, prefix(context, vname), vi) +function tilde_observe(context::PrefixContext, right, left, vi) + return tilde_observe(context.context, right, left, vi) end """ From 2d1985bf0500799c872c0154c76e9a000104c7c2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 12 Aug 2021 06:14:40 +0100 Subject: [PATCH 80/86] attempted fix for doctests --- src/model.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/model.jl b/src/model.jl index 8f2fae998..53771bf56 100644 --- a/src/model.jl +++ b/src/model.jl @@ -119,7 +119,6 @@ julia> using Distributions; using StableRNGs; rng = StableRNG(42); # For reprodu julia> @model function demo() m ~ Normal() x ~ Normal(m, 1) - return (; m, x) end demo (generic function with 1 method) @@ -328,13 +327,15 @@ Return `NamedTuple` of values that are conditioned on under `model`. # Examples ```jldoctest -julia> using DynamicPPL: conditioned +julia> using Distributions + +julia> using DynamicPPL: conditioned, contextualize julia> @model function demo() m ~ Normal() x ~ Normal(m, 1) end -demo (generic function with 1 methods) +demo (generic function with 1 method) julia> m = demo(); From 67dd45acf68b796c5e6a4efe7863467e0e7ab01a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Aug 2021 21:12:48 +0100 Subject: [PATCH 81/86] fixed doctests --- src/model.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/model.jl b/src/model.jl index 53771bf56..9ac59e1ca 100644 --- a/src/model.jl +++ b/src/model.jl @@ -159,7 +159,6 @@ julia> @model function demo_mv(::Type{TV}=Float64) where {TV} m = Vector{TV}(undef, 2) m[1] ~ Normal() m[2] ~ Normal() - return m end demo_mv (generic function with 2 methods) @@ -204,7 +203,7 @@ demo_outer (generic function with 1 method) julia> model = demo_outer(); julia> model(rng) -0.683947930996541 +-0.7935128416361353 julia> conditioned_model = model | (m = 1.0, ); @@ -225,7 +224,7 @@ julia> # This doesn't work now! conditioned_model = demo_outer_prefix() | (m = 1.0, ); julia> conditioned_model(rng) --1.019202452456547 +1.7747246334368165 julia> # `m` in `demo_inner` is referred to as `inner.m` internally, so we do: conditioned_model = demo_outer_prefix() | (var"inner.m" = 1.0, ); @@ -278,7 +277,6 @@ julia> using Distributions; using StableRNGs; rng = StableRNG(42); # For reprodu julia> @model function demo() m ~ Normal() x ~ Normal(m, 1) - return (; m, x) end demo (generic function with 1 method) From 2735784c8c32c893a034a09154d84ceabb3d4376 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Aug 2021 21:15:14 +0100 Subject: [PATCH 82/86] formatting --- test/contexts.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/contexts.jl b/test/contexts.jl index ca08e4e97..80615cbd9 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -52,7 +52,9 @@ Base.IteratorEltype(::Type{<:AbstractContext}) = Base.EltypeUnknown() Return `vn` but now with the prefix removed. """ function remove_prefix(vn::VarName) - return VarName{Symbol(split(string(vn), string(DynamicPPL.PREFIX_SEPARATOR))[end])}(vn.indexing) + return VarName{Symbol(split(string(vn), string(DynamicPPL.PREFIX_SEPARATOR))[end])}( + vn.indexing + ) end """ From ed7aa8b9546d71a93127c7ce48f3094944153af2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Aug 2021 23:37:12 +0100 Subject: [PATCH 83/86] fixed doctests --- src/model.jl | 8 ++++---- test/runtests.jl | 10 +++++++++- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/model.jl b/src/model.jl index 9ac59e1ca..a081bf831 100644 --- a/src/model.jl +++ b/src/model.jl @@ -119,7 +119,7 @@ julia> using Distributions; using StableRNGs; rng = StableRNG(42); # For reprodu julia> @model function demo() m ~ Normal() x ~ Normal(m, 1) - return (; m, x) + return (; m=m, x=x) end demo (generic function with 1 method) @@ -233,8 +233,8 @@ julia> conditioned_model(rng) 1.0 julia> # Note that the above `var"..."` is just standard Julia syntax: - typeof((var"inner.m" = 1.0, )) -NamedTuple{(Symbol("inner.m"),), Tuple{Float64}} + keys((var"inner.m" = 1.0, )) +(Symbol("inner.m"),) ``` The difference is maybe more obvious once we look at how these different @@ -277,7 +277,7 @@ julia> using Distributions; using StableRNGs; rng = StableRNG(42); # For reprodu julia> @model function demo() m ~ Normal() x ~ Normal(m, 1) - return (; m, x) + return (; m=m, x=x) end demo (generic function with 1 method) diff --git a/test/runtests.jl b/test/runtests.jl index d83be0eea..bb2ae579c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -57,7 +57,15 @@ include("test_util.jl") DocMeta.setdocmeta!( DynamicPPL, :DocTestSetup, :(using DynamicPPL); recursive=true ) - doctest(DynamicPPL; manual=false) + doctestfilters = [ + # Older versions will show "0 element Array" instead of "Type[]". + r"(Any\[\]|0-element Array{.+,[0-9]+})", + # Older versions will show "Array{...,1}" instead of "Vector{...}". + r"(Array{.+,\s?1}|Vector{.+})", + # Older versions will show "Array{...,2}" instead of "Matrix{...}". + r"(Array{.+,\s?2}|Matrix{.+})", + ] + doctest(DynamicPPL; manual=false, doctestfilters=doctestfilters) end end From 333fec7bd85c23266323a3da98df37e0c7d1f97a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 13 Aug 2021 23:45:29 +0100 Subject: [PATCH 84/86] fixed doctests on 1.3 --- src/compiler.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 2b0b8fe30..4594808af 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -140,7 +140,7 @@ variables. # Example ```jldoctest; setup=:(using Distributions) -julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(1, 1.0), randn(1, 2), @varname(x)); string(vns[end]) +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal([1.0, 1.0], [1.0 0.0; 0.0 1.0]), randn(2, 2), @varname(x)); string(vns[end]) "x[:,2]" julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); string(vns[end]) From 73736efe303f8176cf14517f9e1547a926bc5806 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 14 Aug 2021 01:01:38 +0100 Subject: [PATCH 85/86] bump of minor version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c6866cc53..bb8faf6c0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.13.2" +version = "0.14.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 264e22b6e1289b374e1b4a0f215cb77266a11f24 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 14 Aug 2021 01:19:06 +0100 Subject: [PATCH 86/86] bump version of integration tests --- test/turing/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/turing/Project.toml b/test/turing/Project.toml index 8edbb9389..95212c061 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -DynamicPPL = "0.13" +DynamicPPL = "0.14" Turing = "0.17" julia = "1.3"