From a2e8f9bd1c7d678bff34c310e3a23e339067df5e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 05:37:54 +0100 Subject: [PATCH 01/15] added WrappedContext and others --- src/contexts.jl | 139 +++++++++++++++++++++++++++++------------------- 1 file changed, 83 insertions(+), 56 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 1ee43f2b2..119b2ce85 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,46 +1,53 @@ +abstract type PrimitiveContext <: AbstractContext end +struct EvaluationContext{S<:AbstractSampler} <: PrimitiveContext + # TODO: do we even need the sampler these days? + sampler::S +end +EvaluationContext() = EvaluationContext(SampleFromPrior()) + +struct SamplingContext{R<:Random.AbstractRNG,S<:AbstractSampler} <: PrimitiveContext + rng::R + sampler::S +end +SamplingContext(sampler=SampleFromPrior()) = SamplingContext(Random.GLOBAL_RNG, sampler) + +######################## +### Wrapped contexts ### +######################## +abstract type WrappedContext{LeafCtx<:PrimitiveContext} <: AbstractContext end + """ - unwrap_childcontext(context::AbstractContext) + childcontext(ctx) -Return a tuple of the child context of a `context`, or `nothing` if the context does -not wrap any other context, and a function `f(c::AbstractContext)` that constructs -an instance of `context` in which the child context is replaced with `c`. +Returns the child-context of `ctx`. -Falls back to `(nothing, _ -> context)`. +Returns `nothing` if `ctx` is not a `WrappedContext`. """ -function unwrap_childcontext(context::AbstractContext) - reconstruct_context(@nospecialize(x)) = context - return nothing, reconstruct_context -end +childcontext(ctx::WrappedContext) = ctx.ctx +childcontext(ctx::AbstractContext) = nothing """ - SamplingContext(rng, sampler, context) + unwrap(ctx::AbstractContext) -Create a context that allows you to sample parameters with the `sampler` when running the model. -The `context` determines how the returned log density is computed when running the model. +Returns the unwrapped context from `ctx`. +""" +unwrap(ctx::WrappedContext) = unwrap(ctx.ctx) +unwrap(ctx::AbstractContext) = ctx -See also: [`JointContext`](@ref), [`LoglikelihoodContext`](@ref), [`PriorContext`](@ref) """ -struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext - rng::R - sampler::S - context::C -end + unwrappedtype(ctx::AbstractContext) -function unwrap_childcontext(context::SamplingContext) - child = context.context - function reconstruct_samplingcontext(c::AbstractContext) - return SamplingContext(context.rng, context.sampler, c) - end - return child, reconstruct_samplingcontext -end +Returns the type of the unwrapped context from `ctx`. +""" +unwrappedtype(ctx::AbstractContext) = typeof(ctx) +unwrappedtype(ctx::WrappedContext{LeafCtx}) where {LeafCtx} = LeafCtx """ - struct DefaultContext <: AbstractContext end + rewrap(parent::WrappedContext, leaf::PrimitiveContext) -The `DefaultContext` is used by default to compute log the joint probability of the data -and parameters when running the model. +Rewraps `leaf` in `parent`. Supports nested `WrappedContext`. """ -struct DefaultContext <: AbstractContext end +rewrap(::AbstractContext, leaf::PrimitiveContext) = leaf """ struct PriorContext{Tvars} <: AbstractContext @@ -50,10 +57,18 @@ struct DefaultContext <: AbstractContext end The `PriorContext` enables the computation of the log prior of the parameters `vars` when running the model. """ -struct PriorContext{Tvars} <: AbstractContext +struct PriorContext{Tvars,Ctx,LeafCtx} <: WrappedContext{LeafCtx} vars::Tvars + ctx::Ctx + + PriorContext(vars, ctx) = new{typeof(vars),typeof(ctx),unwrappedtype(ctx)}(vars, ctx) +end +PriorContext(vars=nothing) = PriorContext(vars, EvaluationContext()) +PriorContext(ctx::AbstractContext) = PriorContext(nothing, ctx) + +function rewrap(parent::PriorContext, leaf::PrimitiveContext) + return PriorContext(parent.vars, rewrap(childcontext(parent), leaf)) end -PriorContext() = PriorContext(nothing) """ struct LikelihoodContext{Tvars} <: AbstractContext @@ -64,10 +79,20 @@ The `LikelihoodContext` enables the computation of the log likelihood of the par running the model. `vars` can be used to evaluate the log likelihood for specific values of the model's parameters. If `vars` is `nothing`, the parameter values inside the `VarInfo` will be used by default. """ -struct LikelihoodContext{Tvars} <: AbstractContext +struct LikelihoodContext{Tvars,Ctx,LeafCtx} <: WrappedContext{LeafCtx} vars::Tvars + ctx::Ctx + + function LikelihoodContext(vars, ctx) + return new{typeof(vars),typeof(ctx),unwrappedtype(ctx)}(vars, ctx) + end +end +LikelihoodContext(vars=nothing) = LikelihoodContext(vars, EvaluationContext()) +LikelihoodContext(ctx::AbstractContext) = LikelihoodContext(nothing, ctx) + +function rewrap(parent::LikelihoodContext, leaf::PrimitiveContext) + return LikelihoodContext(parent.vars, rewrap(childcontext(parent), leaf)) end -LikelihoodContext() = LikelihoodContext(nothing) """ struct MiniBatchContext{Tctx, T} <: AbstractContext @@ -81,20 +106,24 @@ The `MiniBatchContext` enables the computation of This is useful in batch-based stochastic gradient descent algorithms to be optimizing `log(prior) + log(likelihood of all the data points)` in the expectation. """ -struct MiniBatchContext{Tctx,T} <: AbstractContext - ctx::Tctx +struct MiniBatchContext{T,Ctx,LeafCtx} <: WrappedContext{LeafCtx} loglike_scalar::T + ctx::Ctx + + function MiniBatchContext(loglike_scalar, ctx::AbstractContext) + return new{typeof(loglike_scalar),typeof(ctx),unwrappedtype(ctx)}( + loglike_scalar, ctx + ) + end end -function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) - return MiniBatchContext(ctx, npoints / batch_size) + +MiniBatchContext(loglike_scalar) = MiniBatchContext(loglike_scalar, EvaluationContext()) +function MiniBatchContext(ctx::AbstractContext=EvaluationContext(); batch_size, npoints) + return MiniBatchContext(npoints / batch_size, ctx) end -function unwrap_childcontext(context::MiniBatchContext) - child = context.context - function reconstruct_minibatchcontext(c::AbstractContext) - return MiniBatchContext(c, context.loglike_scalar) - end - return child, reconstruct_minibatchcontext +function rewrap(parent::MiniBatchContext, leaf::PrimitiveContext) + return MiniBatchContext(parent.loglike_scalar, rewrap(childcontext(parent), leaf)) end """ @@ -108,11 +137,17 @@ unique. See also: [`@submodel`](@ref) """ -struct PrefixContext{Prefix,C} <: AbstractContext +struct PrefixContext{Prefix,C,LeafCtx} <: WrappedContext{LeafCtx} ctx::C + + function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} + return new{Prefix,typeof(ctx),unwrappedtype(ctx)}(ctx) + end end -function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(ctx)}(ctx) +PrefixContext{Prefix}() where {Prefix} = PrefixContext{Prefix}(EvaluationContext()) + +function rewrap(parent::PrefixContext{Prefix}, leaf::PrimitiveContext) where {Prefix} + return PrefixContext{Prefix}(rewrap(childcontext(parent), leaf)) end const PREFIX_SEPARATOR = Symbol(".") @@ -121,7 +156,7 @@ function PrefixContext{PrefixInner}( ctx::PrefixContext{PrefixOuter} ) where {PrefixInner,PrefixOuter} if @generated - :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, _prefix_seperator, PrefixInner)))}( + :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)))}( ctx.ctx )) else @@ -131,16 +166,8 @@ end function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} if @generated - return :(VarName{$(QuoteNode(Symbol(Prefix, _prefix_seperator, Sym)))}(vn.indexing)) + return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(vn.indexing)) else VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end - -function unwrap_childcontext(context::PrefixContext{P}) where {P} - child = context.context - function reconstruct_prefixcontext(c::AbstractContext) - return PrefixContext{P}(c) - end - return child, reconstruct_prefixcontext -end From c86a37282b4b047e7f2fa31e3d2c0f8451c3a7b1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 05:38:09 +0100 Subject: [PATCH 02/15] updated context implementations --- src/context_implementations.jl | 361 +++++---------------------------- 1 file changed, 50 insertions(+), 311 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 42a336479..5fd8d6e30 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -18,105 +18,29 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x # assume -""" - tilde_assume(context::SamplingContext, right, vn, inds, vi) - -Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), -accumulate the log probability, and return the sampled value with a context associated -with a sampler. - -Falls back to -```julia -tilde_assume(context.rng, context.ctx, context.sampler, right, vn, inds, vi) -``` -if the context `context.ctx` does not call any other context, as indicated by -[`unwrap_childcontext`](@ref). Otherwise, calls `tilde_assume(c, right, vn, inds, vi)` -where `c` is a context in which the order of the sampling context and its child are swapped. -""" -function tilde_assume(context::SamplingContext, right, vn, inds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - tilde_assume(context.rng, c, context.sampler, right, vn, inds, vi) - else - tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, vn, inds, vi) - end -end - -# Leaf contexts -tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, inds, vi) -function tilde_assume( - rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi -) - return assume(rng, sampler, right, vn, inds, vi) -end - -function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) - if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) - end - return tilde_assume(PriorContext(), right, vn, inds, vi) +function tilde_assume(ctx::SamplingContext, right, vn, inds, vi) + return assume(ctx.rng, ctx.sampler, right, vn, inds, vi) end -function tilde_assume( - rng::Random.AbstractRNG, - context::PriorContext{<:NamedTuple}, - sampler, - right, - vn, - inds, - vi, -) - if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) +tilde_assume(ctx::EvaluationContext, right, vn, inds, vi) = assume(right, vn, inds, vi) +function tilde_assume(ctx::PriorContext, right, vn, inds, vi) + if ctx.vars !== nothing + vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) + return tilde_assume(childcontext(ctx), right, vn, inds, vi) end -function tilde_assume(::PriorContext, right, vn, inds, vi) - return assume(right, vn, inds, vi) -end -function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) - return assume(rng, sampler, right, vn, inds, vi) -end - -function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) - if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) - settrans!(vi, false, vn) - end - return tilde_assume(LikelihoodContext(), right, vn, inds, vi) -end -function tilde_assume( - rng::Random.AbstractRNG, - context::LikelihoodContext{<:NamedTuple}, - sampler, - right, - vn, - inds, - vi, -) - if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) +function tilde_assume(ctx::LikelihoodContext, right, vn, inds, vi) + if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) -end -function tilde_assume(::LikelihoodContext, right, vn, inds, vi) - return assume(NoDist(right), vn, inds, vi) + return tilde_assume(childcontext(ctx), NoDist(right), vn, inds, vi) end -function tilde_assume( - rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, inds, vi -) - return assume(rng, sampler, NoDist(right), vn, inds, vi) +function tilde_assume(ctx::MiniBatchContext, right, left, inds, vi) + return tilde_assume(childcontext(ctx), right, left, inds, vi) end - -function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) - return tilde_assume(context.ctx, right, vn, inds, vi) -end - -function tilde_assume(context::PrefixContext, right, vn, inds, vi) - return tilde_assume(context.ctx, right, prefix(context, vn), inds, vi) +function tilde_assume(ctx::PrefixContext, right, vn, inds, vi) + return tilde_assume(childcontext(ctx), right, prefix(ctx, vn), inds, vi) end """ @@ -134,68 +58,18 @@ function tilde_assume!(ctx, right, vn, inds, vi) end # observe -""" - tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) - -Handle observed variables with a `context` associated with a sampler. -Falls back to `tilde_observe(context.ctx, right, left, vname, vinds, vi)` ignoring -the information about the sampler if the context `context.ctx` does not call any other -context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls -`tilde_observe(c, right, left, vname, vinds, vi)` where `c` is a context in -which the order of the sampling context and its child are swapped. -""" -function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - fallback_context = if child_of_c !== nothing - reconstruct_c(reconstruct_context(child_of_c)) - else - c - end - return tilde_observe(fallback_context, right, left, vname, vinds, vi) -end - -""" - tilde_observe(context::SamplingContext, right, left, vi) - -Handle observed constants with a `context` associated with a sampler. -Falls back to `tilde_observe(context.ctx, right, left, vi)` ignoring -the information about the sampler if the context `context.ctx` does not call any other -context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls -`tilde_observe(c, right, left, vi)` where `c` is a context in -which the order of the sampling context and its child are swapped. -""" -function tilde_observe(context::SamplingContext, right, left, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - fallback_context = if child_of_c !== nothing - reconstruct_c(reconstruct_context(child_of_c)) - else - c - end - return tilde_observe(fallback_context, right, left, vi) +function tilde_observe(ctx::Union{SamplingContext,EvaluationContext}, right, left, vi) + return observe(right, left, vi) end - -# Leaf contexts -tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) -tilde_observe(::PriorContext, right, left, vi) = 0 -tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) - -# `MiniBatchContext` -function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * tilde_observe(context.ctx, right, left, vi) -end -function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return context.loglike_scalar * - tilde_observe(context.ctx, right, left, vname, vinds, vi) +tilde_observe(ctx::PriorContext, right, left, vi) = 0 +function tilde_observe(ctx::LikelihoodContext, right, left, vi) + return tilde_observe(childcontext(ctx), right, left, vi) end - -# `PrefixContext` -function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) +function tilde_observe(ctx::MiniBatchContext, right, left, vi) + return ctx.loglike_scalar * tilde_observe(childcontext(ctx), right, left, vi) end -function tilde_observe(context::PrefixContext, right, left, vi) - return tilde_observe(context.ctx, right, left, vi) +function tilde_observe(ctx::PrefixContext, right, left, vi) + return tilde_observe(childcontext(ctx), right, left, vi) end """ @@ -274,142 +148,32 @@ end # .~ functions # assume -""" - dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) - -Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the -model inputs), accumulate the log probability, and return the sampled value for a context -associated with a sampler. - -Falls back to -```julia -dot_tilde_assume(context.rng, context.ctx, context.sampler, right, left, vn, inds, vi) -``` -if the context `context.ctx` does not call any other context, as indicated by -[`unwrap_childcontext`](@ref). Otherwise, calls `dot_tilde_assume(c, right, left, vn, inds, vi)` -where `c` is a context in which the order of the sampling context and its child are swapped. -""" -function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - dot_tilde_assume(context.rng, c, context.sampler, right, left, vn, inds, vi) - else - dot_tilde_assume( - reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi - ) - end +function dot_tilde_assume(ctx::SamplingContext, right, left, vns, _, vi) + return dot_assume(ctx.rng, ctx.sampler, right, vns, left, vi) end - -# `DefaultContext` -function dot_tilde_assume(ctx::DefaultContext, sampler, right, left, vns, inds, vi) - return dot_assume(right, vns, left, vi) -end - -function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, inds, vi) - return dot_assume(rng, sampler, right, vns, left, vi) +function dot_tilde_assume(ctx::EvaluationContext, right, left, vns, inds, vi) + return dot_assume(right, vns, left, inds, vi) end - -# `LikelihoodContext` -function dot_tilde_assume( - context::LikelihoodContext{<:NamedTuple}, right, left, vn, inds, vi -) - return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) - else - dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) - end -end -function dot_tilde_assume( - rng::Random.AbstractRNG, - context::LikelihoodContext{<:NamedTuple}, - sampler, - right, - left, - vn, - inds, - vi, -) - return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) - else - dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) +function dot_tilde_assume(ctx::LikelihoodContext, right, left, vns, inds, vi) + sym = getsym(vns) + if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) + var = _getindex(getfield(ctx.vars, sym), inds) + set_val!(vi, vns, right, var) + settrans!.(Ref(vi), false, vns) end + return dot_tilde_assume(childcontext(ctx), NoDist.(right), vns, left, vi) end -function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) - value, logp = dot_assume(NoDist.(right), left, vn, inds, vi) - acclogp!(vi, logp) - return value -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi -) - value, logp = dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) - acclogp!(vi, logp) - return value +function dot_tilde_assume(ctx::MiniBatchContext, right, left, vns, inds, vi) + return dot_tilde_assume(childcontext(ctx), right, left, vns, inds, vi) end - -# `PriorContext` -function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, inds, vi) - return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) - else - dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) - end -end -function dot_tilde_assume( - rng::Random.AbstractRNG, - context::PriorContext{<:NamedTuple}, - sampler, - right, - left, - vn, - inds, - vi, -) - return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) - _right, _left, _vns = unwrap_right_left_vns(right, var, vn) - set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) - else - dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) +function dot_tilde_assume(ctx::PriorContext, right, left, vns, inds, vi) + sym = getsym(vns) + if ctx.vars !== nothing + var = _getindex(getfield(ctx.vars, sym), inds) + set_val!(vi, vns, right, var) + settrans!.(Ref(vi), false, vns) end -end -function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) - value, logp = dot_assume(right, left, vn, inds, vi) - acclogp!(vi, logp) - return value -end -function dot_tilde_assume( - rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi -) - value, logp = dot_assume(rng, sampler, right, left, vn, inds, vi) - acclogp!(vi, logp) - return value -end - -# `MiniBatchContext` -function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.ctx, right, left, vn, inds, vi) -end - -# `PrefixContext` -function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.ctx, right, prefix.(Ref(context), vn), inds, vi) + return dot_tilde_assume(childcontext(ctx), right, vns, left, vi) end """ @@ -576,40 +340,15 @@ function set_val!( end # observe -""" - dot_tilde_observe(context::SamplingContext, right, left, vi) - -Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log -probability, and return the observed value for a context associated with a sampler. - -Falls back to `dot_tilde_observe(context.ctx, right, left, vi) ignoring the sampler. -""" -function dot_tilde_observe(context::SamplingContext, right, left, vi) - return dot_tilde_observe(context.ctx, right, left, vname, vinds, vi) -end - -# Leaf contexts -dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) -dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) +function dot_tilde_observe(ctx::Union{SamplingContext,EvaluationContext}, right, left, vi) return dot_observe(right, left, vi) end - -# `MiniBatchContext` -function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) -end -function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return ctx.loglike_scalar * - dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) -end - -# `PrefixContext` -function dot_tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return dot_tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) +dot_tilde_observe(ctx::PriorContext, right, left, vi) = 0 +function dot_tilde_observe(ctx::LikelihoodContext, right, left, vi) + return dot_observe(childcontext(ctx), right, left, vi) end -function dot_tilde_observe(context::PrefixContext, right, left, vi) - return dot_tilde_observe(context.ctx, right, left, vi) +function dot_tilde_observe(ctx::MiniBatchContext, right, left, vi) + return ctx.loglike_scalar * dot_tilde_observe(childcontext(ctx), right, left, vi) end """ From 0fb60901b4290fa86323683b847a58d929fb4112 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 05:38:22 +0100 Subject: [PATCH 03/15] updated compiler --- src/compiler.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index dc70ae267..8fbb77f16 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -391,11 +391,13 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. + @gensym leafctx evaluatordef[:body] = quote # in case someone accessed these - if __context__ isa $(DynamicPPL.SamplingContext) - __rng__ = __context__.rng - __sampler__ = __context__.sampler + $leafctx = DynamicPPL.unwrap(__context__) + if $leafctx isa $(DynamicPPL.SamplingContext) + __rng__ = $leafctx.rng + __sampler__ = $leafctx.sampler end $(modelinfo[:body]) From 74874a9001848d83908386f6f7fe7f9d9f73135e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 05:38:29 +0100 Subject: [PATCH 04/15] updated model constructor --- src/model.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/model.jl b/src/model.jl index 2d74949c1..dab8f21ca 100644 --- a/src/model.jl +++ b/src/model.jl @@ -86,9 +86,12 @@ function (model::Model)( rng::Random.AbstractRNG, varinfo::AbstractVarInfo=VarInfo(), sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + context::AbstractContext=SamplingContext(rng, sampler), ) - return model(varinfo, SamplingContext(rng, sampler, context)) + # In case `context` is a `WrapperContext` of sorts, we need to `rewrap` to ensure + # that context has a `SamplingContext` as the leaf context. + context_new = rewrap(context, SamplingContext(rng, sampler)) + return model(varinfo, context_new) end (model::Model)(context::AbstractContext) = model(VarInfo(), context) From 3bec3a53bf0e1eb5effd92cb765d276bf4583c26 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 05:42:59 +0100 Subject: [PATCH 05/15] better tilde_assume impls for PriorContext and LikelihoodContext --- src/context_implementations.jl | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 5fd8d6e30..afbd1facf 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -23,14 +23,20 @@ function tilde_assume(ctx::SamplingContext, right, vn, inds, vi) end tilde_assume(ctx::EvaluationContext, right, vn, inds, vi) = assume(right, vn, inds, vi) function tilde_assume(ctx::PriorContext, right, vn, inds, vi) - if ctx.vars !== nothing + return tilde_assume(childcontext(ctx), right, vn, inds, vi) +end +function tilde_assume(ctx::PriorContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(ctx.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end return tilde_assume(childcontext(ctx), right, vn, inds, vi) end function tilde_assume(ctx::LikelihoodContext, right, vn, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) + return tilde_assume(childcontext(ctx), NoDist(right), vn, inds, vi) +end +function tilde_assume(ctx::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(ctx.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end @@ -155,8 +161,11 @@ function dot_tilde_assume(ctx::EvaluationContext, right, left, vns, inds, vi) return dot_assume(right, vns, left, inds, vi) end function dot_tilde_assume(ctx::LikelihoodContext, right, left, vns, inds, vi) + return dot_tilde_assume(childcontext(ctx), NoDist.(right), vns, left, vi) +end +function dot_tilde_assume(ctx::LikelihoodContext{<:NamedTuple}, right, left, vns, inds, vi) sym = getsym(vns) - if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) + if haskeyctx.vars, sym) var = _getindex(getfield(ctx.vars, sym), inds) set_val!(vi, vns, right, var) settrans!.(Ref(vi), false, vns) @@ -167,8 +176,11 @@ function dot_tilde_assume(ctx::MiniBatchContext, right, left, vns, inds, vi) return dot_tilde_assume(childcontext(ctx), right, left, vns, inds, vi) end function dot_tilde_assume(ctx::PriorContext, right, left, vns, inds, vi) + return dot_tilde_assume(childcontext(ctx), right, vns, left, vi) +end +function dot_tilde_assume(ctx::PriorContext{<:NamedTuple}, right, left, vns, inds, vi) sym = getsym(vns) - if ctx.vars !== nothing + if haskey(ctx.vars, sym) var = _getindex(getfield(ctx.vars, sym), inds) set_val!(vi, vns, right, var) settrans!.(Ref(vi), false, vns) From 06fee8f8b3c381ee748b88d76dc11b04a15e26be Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 05:44:31 +0100 Subject: [PATCH 06/15] fixed typo --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index afbd1facf..b1ca55e21 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -165,7 +165,7 @@ function dot_tilde_assume(ctx::LikelihoodContext, right, left, vns, inds, vi) end function dot_tilde_assume(ctx::LikelihoodContext{<:NamedTuple}, right, left, vns, inds, vi) sym = getsym(vns) - if haskeyctx.vars, sym) + if haskey(ctx.vars, sym) var = _getindex(getfield(ctx.vars, sym), inds) set_val!(vi, vns, right, var) settrans!.(Ref(vi), false, vns) From c12b01006711c868bae10fa6d6eb8485b02ee84c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 05:47:26 +0100 Subject: [PATCH 07/15] updated VarInfo constructor --- src/varinfo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/varinfo.jl b/src/varinfo.jl index e5e71eed1..f40e83a1e 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -126,7 +126,7 @@ function VarInfo( rng::Random.AbstractRNG, model::Model, sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), + context::AbstractContext=SamplingContext(rng, sampler), ) varinfo = VarInfo() model(rng, varinfo, sampler, context) From 5dacb91b4176e15115d6a0c39547e522742c8add Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 05:48:31 +0100 Subject: [PATCH 08/15] export SamplingContext and EvaluationContext --- src/DynamicPPL.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index acdb98183..5c34ccc33 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -76,7 +76,8 @@ export AbstractVarInfo, SampleFromPrior, SampleFromUniform, # Contexts - DefaultContext, + SamplingContext, + EvaluationContext, LikelihoodContext, PriorContext, MiniBatchContext, From 1b23a92fe417604019b9c94eb9af250b95852513 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:06:30 +0100 Subject: [PATCH 09/15] unwrap sampler in matchingvalue since this is available in either PrimitiveContext --- src/model.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index dab8f21ca..e7d86fe2a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -161,7 +161,7 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] return quote - sampler = context isa $(SamplingContext) ? context.sampler : SampleFromPrior() + sampler = unwrap(context).sampler model.f(model, varinfo, context, $(unwrap_args...)) end end From f4011cc68667b0ce81ea237e7615c5abac403255 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:11:15 +0100 Subject: [PATCH 10/15] fixed impls of dot_tilde_assuem for LikelihoodContext and PriorContext --- src/context_implementations.jl | 39 ++++++++++++++++++++++------------ 1 file changed, 25 insertions(+), 14 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index b1ca55e21..5d1357216 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -163,14 +163,23 @@ end function dot_tilde_assume(ctx::LikelihoodContext, right, left, vns, inds, vi) return dot_tilde_assume(childcontext(ctx), NoDist.(right), vns, left, vi) end -function dot_tilde_assume(ctx::LikelihoodContext{<:NamedTuple}, right, left, vns, inds, vi) - sym = getsym(vns) - if haskey(ctx.vars, sym) - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) +function dot_tilde_assume( + ctx::LikelihoodContext{<:NamedTuple}, + right, + left, + vn, + inds, + vi, +) + return if haskey(ctx.vars, getsym(vn)) + var = _getindex(getfield(ctx.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(childcontext(ctx), NoDist.(_right), _left, _vns, inds, vi) + else + dot_tilde_assume(childcontext(ctx), NoDist.(right), left, vn, inds, vi) end - return dot_tilde_assume(childcontext(ctx), NoDist.(right), vns, left, vi) end function dot_tilde_assume(ctx::MiniBatchContext, right, left, vns, inds, vi) return dot_tilde_assume(childcontext(ctx), right, left, vns, inds, vi) @@ -178,14 +187,16 @@ end function dot_tilde_assume(ctx::PriorContext, right, left, vns, inds, vi) return dot_tilde_assume(childcontext(ctx), right, vns, left, vi) end -function dot_tilde_assume(ctx::PriorContext{<:NamedTuple}, right, left, vns, inds, vi) - sym = getsym(vns) - if haskey(ctx.vars, sym) - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) +function dot_tilde_assume(ctx::PriorContext{<:NamedTuple}, right, left, vn, inds, vi) + return if haskey(ctx.vars, getsym(vn)) + var = _getindex(getfield(ctx.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(childcontext(ctx), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(childcontext(ctx), right, left, vn, inds, vi) end - return dot_tilde_assume(childcontext(ctx), right, vns, left, vi) end """ From f70fba1932f2f8241ea4220a9a3164248c8345e7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:12:07 +0100 Subject: [PATCH 11/15] formatting --- src/context_implementations.jl | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 5d1357216..d52cd4cd1 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -163,14 +163,7 @@ end function dot_tilde_assume(ctx::LikelihoodContext, right, left, vns, inds, vi) return dot_tilde_assume(childcontext(ctx), NoDist.(right), vns, left, vi) end -function dot_tilde_assume( - ctx::LikelihoodContext{<:NamedTuple}, - right, - left, - vn, - inds, - vi, -) +function dot_tilde_assume(ctx::LikelihoodContext{<:NamedTuple}, right, left, vn, inds, vi) return if haskey(ctx.vars, getsym(vn)) var = _getindex(getfield(ctx.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) From 036073bff4094c84ca94390816a74115a36957a2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:28:36 +0100 Subject: [PATCH 12/15] improved context implementations --- src/context_implementations.jl | 105 +++++++++++++++++++++++++-------- 1 file changed, 79 insertions(+), 26 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d52cd4cd1..65f38ad64 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -18,20 +18,27 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x # assume +# Leaf contexts function tilde_assume(ctx::SamplingContext, right, vn, inds, vi) return assume(ctx.rng, ctx.sampler, right, vn, inds, vi) end tilde_assume(ctx::EvaluationContext, right, vn, inds, vi) = assume(right, vn, inds, vi) -function tilde_assume(ctx::PriorContext, right, vn, inds, vi) - return tilde_assume(childcontext(ctx), right, vn, inds, vi) + +# Default for `WrappedContext` +function tilde_assume(ctx::WrappedContext, right, left, inds, vi) + return tilde_assume(childcontext(ctx), right, left, inds, vi) end + +# `PriorContext` function tilde_assume(ctx::PriorContext{<:NamedTuple}, right, vn, inds, vi) if haskey(ctx.vars, getsym(vn)) vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return tilde_assume(childcontext(ctx), right, vn, inds, vi) + return tilde_assume(PriorContext(nothing, childcontext(ctx)), right, vn, inds, vi) end + +# `LikelihoodContext` function tilde_assume(ctx::LikelihoodContext, right, vn, inds, vi) return tilde_assume(childcontext(ctx), NoDist(right), vn, inds, vi) end @@ -40,11 +47,10 @@ function tilde_assume(ctx::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return tilde_assume(childcontext(ctx), NoDist(right), vn, inds, vi) -end -function tilde_assume(ctx::MiniBatchContext, right, left, inds, vi) - return tilde_assume(childcontext(ctx), right, left, inds, vi) + return tilde_assume(LikelihoodContext(nothing, childcontext(ctx)), right, vn, inds, vi) end + +# `PrefixContext` function tilde_assume(ctx::PrefixContext, right, vn, inds, vi) return tilde_assume(childcontext(ctx), right, prefix(ctx, vn), inds, vi) end @@ -64,19 +70,23 @@ function tilde_assume!(ctx, right, vn, inds, vi) end # observe +# Leaf contexts function tilde_observe(ctx::Union{SamplingContext,EvaluationContext}, right, left, vi) return observe(right, left, vi) end -tilde_observe(ctx::PriorContext, right, left, vi) = 0 -function tilde_observe(ctx::LikelihoodContext, right, left, vi) + +# Default for `WrappedContext` +function tilde_observe(ctx::WrappedContext, right, left, vi) return tilde_observe(childcontext(ctx), right, left, vi) end + +# `PriorContext` +tilde_observe(ctx::PriorContext, right, left, vi) = 0 + +# `MiniBatchContext` function tilde_observe(ctx::MiniBatchContext, right, left, vi) return ctx.loglike_scalar * tilde_observe(childcontext(ctx), right, left, vi) end -function tilde_observe(ctx::PrefixContext, right, left, vi) - return tilde_observe(childcontext(ctx), right, left, vi) -end """ tilde_observe!(ctx, right, left, vname, vinds, vi) @@ -154,12 +164,20 @@ end # .~ functions # assume +# Leaf contexts function dot_tilde_assume(ctx::SamplingContext, right, left, vns, _, vi) return dot_assume(ctx.rng, ctx.sampler, right, vns, left, vi) end function dot_tilde_assume(ctx::EvaluationContext, right, left, vns, inds, vi) return dot_assume(right, vns, left, inds, vi) end + +# Default for `WrappedContext` +function dot_tilde_assume(ctx::WrappedContext, right, left, vns, inds, vi) + return dot_tilde_assume(childcontext(ctx), right, vns, left, vi) +end + +# `LikelihoodContext` function dot_tilde_assume(ctx::LikelihoodContext, right, left, vns, inds, vi) return dot_tilde_assume(childcontext(ctx), NoDist.(right), vns, left, vi) end @@ -169,29 +187,38 @@ function dot_tilde_assume(ctx::LikelihoodContext{<:NamedTuple}, right, left, vn, _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(childcontext(ctx), NoDist.(_right), _left, _vns, inds, vi) + dot_tilde_assume( + LikelihoodContext(nothing, childcontext(ctx)), _right, _left, _vns, inds, vi + ) else - dot_tilde_assume(childcontext(ctx), NoDist.(right), left, vn, inds, vi) + dot_tilde_assume( + LikelihoodContext(nothing, childcontext(ctx)), right, left, vn, inds, vi + ) end end -function dot_tilde_assume(ctx::MiniBatchContext, right, left, vns, inds, vi) - return dot_tilde_assume(childcontext(ctx), right, left, vns, inds, vi) -end -function dot_tilde_assume(ctx::PriorContext, right, left, vns, inds, vi) - return dot_tilde_assume(childcontext(ctx), right, vns, left, vi) -end + +# `PriorContext` function dot_tilde_assume(ctx::PriorContext{<:NamedTuple}, right, left, vn, inds, vi) return if haskey(ctx.vars, getsym(vn)) var = _getindex(getfield(ctx.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(childcontext(ctx), _right, _left, _vns, inds, vi) + dot_tilde_assume( + PriorContext(nothing, childcontext(ctx)), _right, _left, _vns, inds, vi + ) else - dot_tilde_assume(childcontext(ctx), right, left, vn, inds, vi) + dot_tilde_assume( + PriorContext(nothing, childcontext(ctx)), right, left, vn, inds, vi + ) end end +# `PrefixContext` +function dot_tilde_assume(ctx::PrefixContext, right, left, vn, inds, vi) + return dot_tilde_assume(childcontext(ctx), right, prefix.(Ref(ctx), vn), inds, vi) +end + """ dot_tilde_assume!(ctx, right, left, vn, inds, vi) @@ -359,12 +386,38 @@ end function dot_tilde_observe(ctx::Union{SamplingContext,EvaluationContext}, right, left, vi) return dot_observe(right, left, vi) end +function dot_tilde_observe( + ctx::Union{SamplingContext,EvaluationContext}, right, left, vname, vinds, vi +) + return dot_observe(right, left, vi) +end +# Default for `WrappedContext` +function dot_tilde_observe(ctx::WrappedContext, right, left, vi) + return dot_tilde_observe(childcontext(ctx), right, left, vi) +end +function dot_tilde_observe(ctx::WrappedContext, right, left, vname, vinds, vi) + return dot_tilde_observe(childcontext(ctx), right, left, vname, vinds, vi) +end + +# `PriorContext` dot_tilde_observe(ctx::PriorContext, right, left, vi) = 0 -function dot_tilde_observe(ctx::LikelihoodContext, right, left, vi) - return dot_observe(childcontext(ctx), right, left, vi) +dot_tilde_observe(ctx::PriorContext, right, left, vname, vinds, vi) = 0 + +# `MiniBatchContext` +function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) + return ctx.loglike_scalar * + dot_tilde_observe(childcontext(ctx), sampler, right, left, vi) +end +function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vname, vinds, vi) + return ctx.loglike_scalar * + dot_tilde_observe(childcontext(ctx), sampler, right, left, vname, vinds, vi) end -function dot_tilde_observe(ctx::MiniBatchContext, right, left, vi) - return ctx.loglike_scalar * dot_tilde_observe(childcontext(ctx), right, left, vi) + +# `PrefixContext` +function dot_tilde_observe(ctx::PrefixContext, right, left, vname, vinds, vi) + return dot_tilde_observe( + childcontext(ctx), right, left, prefix(context, vname), vinds, vi + ) end """ From dc91169d7f70d52701a94835b97df00f28f22268 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:33:15 +0100 Subject: [PATCH 13/15] simplfied tilde_assume for LikelihoodContext and PriorContext --- src/context_implementations.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 65f38ad64..f68073f62 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -35,7 +35,7 @@ function tilde_assume(ctx::PriorContext{<:NamedTuple}, right, vn, inds, vi) vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return tilde_assume(PriorContext(nothing, childcontext(ctx)), right, vn, inds, vi) + return tilde_assume(PriorContext(childcontext(ctx)), right, vn, inds, vi) end # `LikelihoodContext` @@ -47,7 +47,7 @@ function tilde_assume(ctx::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return tilde_assume(LikelihoodContext(nothing, childcontext(ctx)), right, vn, inds, vi) + return tilde_assume(LikelihoodContext(childcontext(ctx)), right, vn, inds, vi) end # `PrefixContext` @@ -188,11 +188,11 @@ function dot_tilde_assume(ctx::LikelihoodContext{<:NamedTuple}, right, left, vn, set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) dot_tilde_assume( - LikelihoodContext(nothing, childcontext(ctx)), _right, _left, _vns, inds, vi + LikelihoodContext(childcontext(ctx)), _right, _left, _vns, inds, vi ) else dot_tilde_assume( - LikelihoodContext(nothing, childcontext(ctx)), right, left, vn, inds, vi + LikelihoodContext(childcontext(ctx)), right, left, vn, inds, vi ) end end @@ -205,11 +205,11 @@ function dot_tilde_assume(ctx::PriorContext{<:NamedTuple}, right, left, vn, inds set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) dot_tilde_assume( - PriorContext(nothing, childcontext(ctx)), _right, _left, _vns, inds, vi + PriorContext(childcontext(ctx)), _right, _left, _vns, inds, vi ) else dot_tilde_assume( - PriorContext(nothing, childcontext(ctx)), right, left, vn, inds, vi + PriorContext(childcontext(ctx)), right, left, vn, inds, vi ) end end From 57d4854d1a1f07b656990eb034638bdbd364e20d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:35:45 +0100 Subject: [PATCH 14/15] formatting --- src/context_implementations.jl | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f68073f62..1a647a904 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -191,9 +191,7 @@ function dot_tilde_assume(ctx::LikelihoodContext{<:NamedTuple}, right, left, vn, LikelihoodContext(childcontext(ctx)), _right, _left, _vns, inds, vi ) else - dot_tilde_assume( - LikelihoodContext(childcontext(ctx)), right, left, vn, inds, vi - ) + dot_tilde_assume(LikelihoodContext(childcontext(ctx)), right, left, vn, inds, vi) end end @@ -204,13 +202,9 @@ function dot_tilde_assume(ctx::PriorContext{<:NamedTuple}, right, left, vn, inds _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume( - PriorContext(childcontext(ctx)), _right, _left, _vns, inds, vi - ) + dot_tilde_assume(PriorContext(childcontext(ctx)), _right, _left, _vns, inds, vi) else - dot_tilde_assume( - PriorContext(childcontext(ctx)), right, left, vn, inds, vi - ) + dot_tilde_assume(PriorContext(childcontext(ctx)), right, left, vn, inds, vi) end end From 77d0602cbbd801204dc450edf43bfeb7e3837340 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:48:47 +0100 Subject: [PATCH 15/15] formatting --- src/contexts.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/contexts.jl b/src/contexts.jl index 85d956b07..ef4d600ea 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -61,7 +61,9 @@ struct PriorContext{Tvars,Ctx,LeafCtx} <: WrappedContext{LeafCtx} vars::Tvars context::Ctx - PriorContext(vars, context) = new{typeof(vars),typeof(context),unwrappedtype(context)}(vars, context) + function PriorContext(vars, context) + return new{typeof(vars),typeof(context),unwrappedtype(context)}(vars, context) + end end PriorContext(vars=nothing) = PriorContext(vars, EvaluationContext()) PriorContext(context::AbstractContext) = PriorContext(nothing, context)