From d70e1be46058e912121b41dd0e2f0724c57474c1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:21:20 +0100 Subject: [PATCH 01/46] added sampling context and unwrap_childcontext --- src/contexts.jl | 63 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/src/contexts.jl b/src/contexts.jl index 4d4f30bdc..1ee43f2b2 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,3 +1,39 @@ +""" + unwrap_childcontext(context::AbstractContext) + +Return a tuple of the child context of a `context`, or `nothing` if the context does +not wrap any other context, and a function `f(c::AbstractContext)` that constructs +an instance of `context` in which the child context is replaced with `c`. + +Falls back to `(nothing, _ -> context)`. +""" +function unwrap_childcontext(context::AbstractContext) + reconstruct_context(@nospecialize(x)) = context + return nothing, reconstruct_context +end + +""" + SamplingContext(rng, sampler, context) + +Create a context that allows you to sample parameters with the `sampler` when running the model. +The `context` determines how the returned log density is computed when running the model. + +See also: [`JointContext`](@ref), [`LoglikelihoodContext`](@ref), [`PriorContext`](@ref) +""" +struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext + rng::R + sampler::S + context::C +end + +function unwrap_childcontext(context::SamplingContext) + child = context.context + function reconstruct_samplingcontext(c::AbstractContext) + return SamplingContext(context.rng, context.sampler, c) + end + return child, reconstruct_samplingcontext +end + """ struct DefaultContext <: AbstractContext end @@ -53,6 +89,25 @@ function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) return MiniBatchContext(ctx, npoints / batch_size) end +function unwrap_childcontext(context::MiniBatchContext) + child = context.context + function reconstruct_minibatchcontext(c::AbstractContext) + return MiniBatchContext(c, context.loglike_scalar) + end + return child, reconstruct_minibatchcontext +end + +""" + PrefixContext{Prefix}(context) + +Create a context that allows you to use the wrapped `context` when running the model and +adds the `Prefix` to all parameters. + +This context is useful in nested models to ensure that the names of the parameters are +unique. + +See also: [`@submodel`](@ref) +""" struct PrefixContext{Prefix,C} <: AbstractContext ctx::C end @@ -81,3 +136,11 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end + +function unwrap_childcontext(context::PrefixContext{P}) where {P} + child = context.context + function reconstruct_prefixcontext(c::AbstractContext) + return PrefixContext{P}(c) + end + return child, reconstruct_prefixcontext +end From f74399031eca9b5eaab824c16416a825c489f59c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:23:31 +0100 Subject: [PATCH 02/46] updated tilde methods --- src/context_implementations.jl | 451 +++++++++++++++++++++++++-------- 1 file changed, 352 insertions(+), 99 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 0698b6cdf..8aa8ddfca 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -18,28 +18,103 @@ _getindex(x, inds::Tuple) = _getindex(x[first(inds)...], Base.tail(inds)) _getindex(x, inds::Tuple{}) = x # assume -function tilde_assume(rng, ctx::DefaultContext, sampler, right, vn::VarName, _, vi) - return assume(rng, sampler, right, vn, vi) +""" + tilde_assume(context::SamplingContext, right, vn, inds, vi) + +Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), +accumulate the log probability, and return the sampled value with a context associated +with a sampler. + +Falls back to +```julia +tilde_assume(context.rng, context.ctx, context.sampler, right, vn, inds, vi) +``` +if the context `context.ctx` does not call any other context, as indicated by +[`unwrap_childcontext`](@ref). Otherwise, calls `tilde_assume(c, right, vn, inds, vi)` +where `c` is a context in which the order of the sampling context and its child are swapped. +""" +function tilde_assume(context::SamplingContext, right, vn, inds, vi) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + return if child_of_c === nothing + tilde_assume(context.rng, c, context.sampler, right, vn, inds, vi) + else + tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, vn, inds, vi) + end +end + +# Leaf contexts +tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, inds, vi) +function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi) + return assume(rng, sampler, right, vn, inds, vi) +end + +function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + settrans!(vi, false, vn) + end + return tilde_assume(PriorContext(), right, vn, inds, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, + context::PriorContext{<:NamedTuple}, + sampler, + right, + vn, + inds, + vi, +) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + settrans!(vi, false, vn) + end + return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) end -function tilde_assume(rng, ctx::PriorContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars !== nothing - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) +function tilde_assume(::PriorContext, right, vn, inds, vi) + return assume(right, vn, inds, vi) +end +function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) + return assume(rng, sampler, right, vn, inds, vi) +end + +function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return assume(rng, sampler, right, vn, vi) + return tilde_assume(LikelihoodContext(), right, vn, inds, vi) end -function tilde_assume(rng, ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi) - if ctx.vars isa NamedTuple && haskey(ctx.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds)) +function tilde_assume( + rng::Random.AbstractRNG, + context::LikelihoodContext{<:NamedTuple}, + sampler, + right, + vn, + inds, + vi, +) + if haskey(context.vars, getsym(vn)) + vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) settrans!(vi, false, vn) end - return assume(rng, sampler, NoDist(right), vn, vi) + return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) +end +function tilde_assume(::LikelihoodContext, right, vn, inds, vi) + return assume(NoDist(right), vn, inds, vi) +end +function tilde_assume( + rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, inds, vi +) + return assume(rng, sampler, NoDist(right), vn, inds, vi) end -function tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, left, inds, vi) + +function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) + return tilde_assume(context.ctx, right, vn, inds, vi) end -function tilde_assume(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi) + +function tilde_assume(context::PrefixContext, right, vn, inds, vi) + return tilde_assume(context.ctx, right, prefix(context, vn), inds, vi) end """ @@ -50,27 +125,76 @@ accumulate the log probability, and return the sampled value. Falls back to `tilde_assume!(rng, ctx, sampler, right, vn, inds, vi)`. """ -function tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) - value, logp = tilde_assume(rng, ctx, sampler, right, vn, inds, vi) +function tilde_assume!(ctx, sampler, right, vn, inds, vi) + value, logp = tilde_assume(ctx, sampler, right, vn, inds, vi) acclogp!(vi, logp) return value end # observe -function tilde_observe(ctx::DefaultContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +""" + tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + +Handle observed variables with a `context` associated with a sampler. +Falls back to `tilde_observe(context.ctx, right, left, vname, vinds, vi)` ignoring +the information about the sampler if the context `context.ctx` does not call any other +context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls +`tilde_observe(c, right, left, vname, vinds, vi)` where `c` is a context in +which the order of the sampling context and its child are swapped. +""" +function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + fallback_context = if child_of_c !== nothing + reconstruct_c(reconstruct_context(child_of_c)) + else + c + end + return tilde_observe(fallback_context, right, left, vname, vinds, vi) end -function tilde_observe(ctx::PriorContext, sampler, right, left, vi) - return 0 + +""" + tilde_observe(context::SamplingContext, right, left, vi) + +Handle observed constants with a `context` associated with a sampler. +Falls back to `tilde_observe(context.ctx, right, left, vi)` ignoring +the information about the sampler if the context `context.ctx` does not call any other +context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls +`tilde_observe(c, right, left, vi)` where `c` is a context in +which the order of the sampling context and its child are swapped. +""" +function tilde_observe(context::SamplingContext, right, left, vi) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + fallback_context = if child_of_c !== nothing + reconstruct_c(reconstruct_context(child_of_c)) + else + c + end + return tilde_observe(fallback_context, right, left, vi) +end + +# Leaf contexts +tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) +tilde_observe(::PriorContext, right, left, vi) = 0 +tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) + +# `MiniBatchContext` +function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) + return context.loglike_scalar * tilde_observe(context.ctx, right, left, vi) end -function tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return observe(sampler, right, left, vi) +function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) + return context.loglike_scalar * tilde_observe(context.ctx, right, left, vname, vinds, vi) end -function tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * tilde_observe(ctx.ctx, sampler, right, left, vi) + +# `PrefixContext` +function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) + return tilde_observe( + context.ctx, right, left, prefix(context, vname), vinds, vi + ) end -function tilde_observe(ctx::PrefixContext, sampler, right, left, vi) - return tilde_observe(ctx.ctx, sampler, right, left, vi) +function tilde_observe(context::PrefixContext, right, left, vi) + return tilde_observe(context.ctx, right, left, vi) end """ @@ -112,77 +236,179 @@ function observe(spl::Sampler, weight) return error("DynamicPPL.observe: unmanaged inference algorithm: $(typeof(spl))") end +# fallback without sampler +function assume(dist::Distribution, vn::VarName, inds, vi) + if !haskey(vi, vn) + error("variable $vn does not exist") + end + r = vi[vn] + return r, Bijectors.logpdf_with_trans(dist, vi[vn], istrans(vi, vn)) +end + +# SampleFromPrior and SampleFromUniform function assume( - rng, spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, vi + rng::Random.AbstractRNG, + sampler::Union{SampleFromPrior,SampleFromUniform}, + dist::Distribution, + vn::VarName, + inds, + vi, ) + # Always overwrite the parameters with new ones. + r = init(rng, dist, sampler) if haskey(vi, vn) - # Always overwrite the parameters with new ones for `SampleFromUniform`. - if spl isa SampleFromUniform || is_flagged(vi, vn, "del") - unset_flag!(vi, vn, "del") - r = init(rng, dist, spl) - vi[vn] = vectorize(dist, r) - settrans!(vi, false, vn) - setorder!(vi, vn, get_num_produce(vi)) - else - r = vi[vn] - end + vi[vn] = vectorize(dist, r) + setorder!(vi, vn, get_num_produce(vi)) else - r = init(rng, dist, spl) - push!(vi, vn, r, dist, spl) - settrans!(vi, false, vn) + push!(vi, vn, r, dist, sampler) end + settrans!(vi, false, vn) return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end -function observe( - spl::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, value, vi -) +# default fallback (used e.g. by `SampleFromPrior` and `SampleUniform`) +function observe(right::Distribution, left, vi) increment_num_produce!(vi) - return Distributions.loglikelihood(dist, value) + return Distributions.loglikelihood(right, left) end # .~ functions # assume -function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, _, vi) +""" + dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + +Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the +model inputs), accumulate the log probability, and return the sampled value for a context +associated with a sampler. + +Falls back to +```julia +dot_tilde_assume(context.rng, context.ctx, context.sampler, right, left, vn, inds, vi) +``` +if the context `context.ctx` does not call any other context, as indicated by +[`unwrap_childcontext`](@ref). Otherwise, calls `dot_tilde_assume(c, right, left, vn, inds, vi)` +where `c` is a context in which the order of the sampling context and its child are swapped. +""" +function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + c, reconstruct_context = unwrap_childcontext(context) + child_of_c, reconstruct_c = unwrap_childcontext(c) + return if child_of_c === nothing + dot_tilde_assume(context.rng, c, context.sampler, right, left, vn, inds, vi) + else + dot_tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi) + end +end + +# `DefaultContext` +function dot_tilde_assume(ctx::DefaultContext, sampler, right, left, vns, inds, vi) + return dot_assume(right, vns, left, vi) +end + +function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(rng, sampler, right, vns, left, vi) end + +# `LikelihoodContext` function dot_tilde_assume( - rng, - ctx::LikelihoodContext, + context::LikelihoodContext{<:NamedTuple}, right, left, vn, inds, vi +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) + end +end +function dot_tilde_assume( + rng::Random.AbstractRNG, + context::LikelihoodContext{<:NamedTuple}, sampler, right, left, - vns::AbstractArray{<:VarName{sym}}, + vn, inds, vi, -) where {sym} - if ctx.vars isa NamedTuple && haskey(ctx.vars, sym) - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) + else + dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) end - return dot_assume(rng, sampler, NoDist.(right), vns, left, vi) end -function dot_tilde_assume(rng, ctx::MiniBatchContext, sampler, right, left, vns, inds, vi) - return dot_tilde_assume(rng, ctx.ctx, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) + value, logp = dot_assume(NoDist.(right), left, vn, inds, vi) + acclogp!(vi, logp) + return value end function dot_tilde_assume( - rng, - ctx::PriorContext, + rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi +) + value, logp = dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) + acclogp!(vi, logp) + return value +end + +# `PriorContext` +function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, inds, vi) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) + else + dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) + end +end +function dot_tilde_assume( + rng::Random.AbstractRNG, + context::PriorContext{<:NamedTuple}, sampler, right, left, - vns::AbstractArray{<:VarName{sym}}, + vn, inds, vi, -) where {sym} - if ctx.vars !== nothing - var = _getindex(getfield(ctx.vars, sym), inds) - set_val!(vi, vns, right, var) - settrans!.(Ref(vi), false, vns) +) + return if haskey(context.vars, getsym(vn)) + var = _getindex(getfield(context.vars, getsym(vn)), inds) + _right, _left, _vns = unwrap_right_left_vns(right, var, vn) + set_val!(vi, _vns, _right, _left) + settrans!.(Ref(vi), false, _vns) + dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) + else + dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) end - return dot_assume(rng, sampler, right, vns, left, vi) +end +function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) + value, logp = dot_assume(right, left, vn, inds, vi) + acclogp!(vi, logp) + return value +end +function dot_tilde_assume( + rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi +) + value, logp = dot_assume(rng, sampler, right, left, vn, inds, vi) + acclogp!(vi, logp) + return value +end + +# `MiniBatchContext` +function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.ctx, right, left, vn, inds, vi) +end + +# `PrefixContext` +function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.ctx, right, prefix.(Ref(context), vn), inds, vi) end """ @@ -193,13 +419,26 @@ model inputs), accumulate the log probability, and return the sampled value. Falls back to `dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume!(ctx, sampler, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi) acclogp!(vi, logp) return value end -# Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics +# `dot_assume` +function dot_assume( + dist::MultivariateDistribution, + var::AbstractMatrix, + vns::AbstractVector{<:VarName}, + inds, + vi, +) + @assert length(dist) == size(var, 1) + lp = sum(zip(vns, eachcol(var))) do vn, ri + return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) + end + return var, lp +end function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -214,6 +453,19 @@ function dot_assume( var .= r return var, lp end + +function dot_assume( + dists::Union{Distribution,AbstractArray{<:Distribution}}, + var::AbstractArray, + vns::AbstractArray{<:VarName}, + inds, + vi, +) + # Make sure `var` is not a matrix for multivariate distributions + lp = sum(Bijectors.logpdf_with_trans.(dists, var, istrans(vi, vns[1]))) + return var, lp +end + function dot_assume( rng, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -323,18 +575,38 @@ function set_val!( end # observe -function dot_tilde_observe(ctx::DefaultContext, sampler, right, left, vi) - return dot_observe(sampler, right, left, vi) -end -function dot_tilde_observe(ctx::PriorContext, sampler, right, left, vi) - return 0 -end -function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) - return dot_observe(sampler, right, left, vi) +""" + dot_tilde_observe(context::SamplingContext, right, left, vi) + +Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log +probability, and return the observed value for a context associated with a sampler. + +Falls back to `dot_tilde_observe(context.ctx, right, left, vi) ignoring the sampler. +""" +function dot_tilde_observe(context::SamplingContext, right, left, vi) + return dot_tilde_observe(context.ctx, right, left, vname, vinds, vi) end + +# Leaf contexts +dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) +dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) = dot_observe(right, left, vi) + +# `MiniBatchContext` function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) end +function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vname, vinds, vi) + return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) +end + +# `PrefixContext` +function dot_tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) + return dot_tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) +end +function dot_tilde_observe(context::PrefixContext, right, left, vi) + return dot_tilde_observe(context.ctx, right, left, vi) +end """ dot_tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) @@ -366,41 +638,22 @@ function dot_tilde_observe!(ctx, sampler, right, left, vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, - dist::MultivariateDistribution, - value::AbstractMatrix, - vi, -) +function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) increment_num_produce!(vi) @debug "dist = $dist" @debug "value = $value" return Distributions.loglikelihood(dist, value) end -function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::Distribution, - value::AbstractArray, - vi, -) +function dot_observe(dists::Distribution, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return Distributions.loglikelihood(dists, value) end -function dot_observe( - spl::Union{SampleFromPrior,SampleFromUniform}, - dists::AbstractArray{<:Distribution}, - value::AbstractArray, - vi, -) +function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return sum(Distributions.loglikelihood.(dists, value)) end -function dot_observe(spl::Sampler, ::Any, ::Any, ::Any) - return error( - "[DynamicPPL] $(alg_str(spl)) doesn't support vectorizing observe statement" - ) -end + From 3d2e7e2b4dfb2462e6732eb3f0b3ac5494d1696f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:23:48 +0100 Subject: [PATCH 03/46] updated model call signature --- src/model.jl | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/model.jl b/src/model.jl index 7189b590e..250b89721 100644 --- a/src/model.jl +++ b/src/model.jl @@ -88,12 +88,18 @@ function (model::Model)( sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) + return model(SamplingContext(rng, sampler, context), varinfo) +end + +(model::Model)(context::AbstractContext) = model(VarInfo(), context) +function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 - return evaluate_threadunsafe(rng, model, varinfo, sampler, context) + return evaluate_threadunsafe(model, varinfo, sampler, context) else - return evaluate_threadsafe(rng, model, varinfo, sampler, context) + return evaluate_threadsafe(model, varinfo, sampler, context) end end + function (model::Model)(args...) return model(Random.GLOBAL_RNG, args...) end @@ -109,7 +115,7 @@ function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) end """ - evaluate_threadunsafe(rng, model, varinfo, sampler, context) + evaluate_threadunsafe(model, varinfo, sampler, context) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. @@ -118,13 +124,13 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ -function evaluate_threadunsafe(rng, model, varinfo, sampler, context) +function evaluate_threadunsafe(model, varinfo, sampler, context) resetlogp!(varinfo) - return _evaluate(rng, model, varinfo, sampler, context) + return _evaluate(model, varinfo, sampler, context) end """ - evaluate_threadsafe(rng, model, varinfo, sampler, context) + evaluate_threadsafe(model, varinfo, sampler, context) Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. @@ -134,24 +140,24 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ -function evaluate_threadsafe(rng, model, varinfo, sampler, context) +function evaluate_threadsafe(model, varinfo, sampler, context) resetlogp!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) - result = _evaluate(rng, model, wrapper, sampler, context) + result = _evaluate(model, wrapper, sampler, context) setlogp!(varinfo, getlogp(wrapper)) return result end """ - _evaluate(rng, model::Model, varinfo, sampler, context) + _evaluate(model::Model, varinfo, sampler, context) Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. """ @generated function _evaluate( - rng, model::Model{_F,argnames}, varinfo, sampler, context + model::Model{_F,argnames}, varinfo, sampler, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(rng, model, varinfo, sampler, context, $(unwrap_args...))) + return :(model.f(model, varinfo, sampler, context, $(unwrap_args...))) end """ @@ -183,7 +189,7 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), DefaultContext()) + model(varinfo, DefaultContext()) return getlogp(varinfo) end @@ -195,7 +201,7 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), PriorContext()) + model(varinfo, PriorContext()) return getlogp(varinfo) end @@ -207,7 +213,7 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). """ function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - model(varinfo, SampleFromPrior(), LikelihoodContext()) + model(varinfo, LikelihoodContext()) return getlogp(varinfo) end From 4f1d39694083bae41ac7e94cbb19640639512879 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:24:04 +0100 Subject: [PATCH 04/46] updated compiler --- src/compiler.jl | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 20d8bf8ef..bc906f58c 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -394,10 +394,8 @@ function build_output(modelinfo, linenumbernode) # Add the internal arguments to the user-specified arguments (positional + keywords). evaluatordef[:args] = vcat( [ - :(__rng__::$(Random.AbstractRNG)), :(__model__::$(DynamicPPL.Model)), :(__varinfo__::$(DynamicPPL.AbstractVarInfo)), - :(__sampler__::$(DynamicPPL.AbstractSampler)), :(__context__::$(DynamicPPL.AbstractContext)), ], modelinfo[:allargs_exprs], @@ -407,7 +405,15 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - evaluatordef[:body] = modelinfo[:body] + evaluatordef[:body] = quote + # in case someone accessed these + if __context__ isa $(DynamicPPL.SamplingContext) + __rng__ = __context__.rng + __sampler__ = __context__.sampler + end + + $(modelinfo[:body]) + end ## Build the model function. From b187d74efcfb5b9482f39022252477b5a0bc2cb9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:27:37 +0100 Subject: [PATCH 05/46] formatting --- src/context_implementations.jl | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 8aa8ddfca..a0bf0381a 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -45,7 +45,9 @@ end # Leaf contexts tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, inds, vi) -function tilde_assume(rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi) +function tilde_assume( + rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi +) return assume(rng, sampler, right, vn, inds, vi) end @@ -184,14 +186,13 @@ function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) return context.loglike_scalar * tilde_observe(context.ctx, right, left, vi) end function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return context.loglike_scalar * tilde_observe(context.ctx, right, left, vname, vinds, vi) + return context.loglike_scalar * + tilde_observe(context.ctx, right, left, vname, vinds, vi) end # `PrefixContext` function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return tilde_observe( - context.ctx, right, left, prefix(context, vname), vinds, vi - ) + return tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) end function tilde_observe(context::PrefixContext, right, left, vi) return tilde_observe(context.ctx, right, left, vi) @@ -296,7 +297,9 @@ function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) return if child_of_c === nothing dot_tilde_assume(context.rng, c, context.sampler, right, left, vn, inds, vi) else - dot_tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi) + dot_tilde_assume( + reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi + ) end end @@ -590,14 +593,17 @@ end # Leaf contexts dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) = dot_observe(right, left, vi) +function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) + return dot_observe(right, left, vi) +end # `MiniBatchContext` function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) end function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) + return ctx.loglike_scalar * + dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) end # `PrefixContext` @@ -656,4 +662,3 @@ function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, @debug "value = $value" return sum(Distributions.loglikelihood.(dists, value)) end - From ee99f8ce5676c5cb571417e6dd4b3d9570922bf5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:30:54 +0100 Subject: [PATCH 06/46] added getsym for vectors --- src/varname.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/varname.jl b/src/varname.jl index bb936a4ce..40c5c25e9 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -39,3 +39,7 @@ Possibly existing indices of `varname` are neglected. ) where {s,missings,_F,_a,_T} return s in missings end + + +# HACK: Type-piracy. Is this really the way to go? +AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym From c4845d08b34b58bc30762b4be944c8924c9794e2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:35:12 +0100 Subject: [PATCH 07/46] Update src/varname.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/varname.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/varname.jl b/src/varname.jl index 40c5c25e9..343bb0da8 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -40,6 +40,5 @@ Possibly existing indices of `varname` are neglected. return s in missings end - # HACK: Type-piracy. Is this really the way to go? AbstractPPL.getsym(::AbstractVector{<:VarName{sym}}) where {sym} = sym From a0c05f39315c93d9ed43cefe227255310172f339 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:43:42 +0100 Subject: [PATCH 08/46] fixed some signatures for Model --- src/model.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/model.jl b/src/model.jl index 250b89721..8d353d2de 100644 --- a/src/model.jl +++ b/src/model.jl @@ -88,7 +88,7 @@ function (model::Model)( sampler::AbstractSampler=SampleFromPrior(), context::AbstractContext=DefaultContext(), ) - return model(SamplingContext(rng, sampler, context), varinfo) + return model(varinfo, SamplingContext(rng, sampler, context)) end (model::Model)(context::AbstractContext) = model(VarInfo(), context) @@ -115,7 +115,7 @@ function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) end """ - evaluate_threadunsafe(model, varinfo, sampler, context) + evaluate_threadunsafe(model, varinfo, context) Evaluate the `model` without wrapping `varinfo` inside a `ThreadSafeVarInfo`. @@ -124,13 +124,13 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ -function evaluate_threadunsafe(model, varinfo, sampler, context) +function evaluate_threadunsafe(model, varinfo, context) resetlogp!(varinfo) - return _evaluate(model, varinfo, sampler, context) + return _evaluate(model, varinfo, context) end """ - evaluate_threadsafe(model, varinfo, sampler, context) + evaluate_threadsafe(model, varinfo, context) Evaluate the `model` with `varinfo` wrapped inside a `ThreadSafeVarInfo`. @@ -140,24 +140,24 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ -function evaluate_threadsafe(model, varinfo, sampler, context) +function evaluate_threadsafe(model, varinfo, context) resetlogp!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) - result = _evaluate(model, wrapper, sampler, context) + result = _evaluate(model, wrapper, context) setlogp!(varinfo, getlogp(wrapper)) return result end """ - _evaluate(model::Model, varinfo, sampler, context) + _evaluate(model::Model, varinfo, context) Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. """ @generated function _evaluate( - model::Model{_F,argnames}, varinfo, sampler, context + model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(model, varinfo, sampler, context, $(unwrap_args...))) + return :(model.f(model, varinfo, context, $(unwrap_args...))) end """ From 307cd7e1f3a1a02bd79284ccfc640848961e2bd9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 00:49:44 +0100 Subject: [PATCH 09/46] fixed a method call --- src/model.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/model.jl b/src/model.jl index 8d353d2de..3a01f9bf3 100644 --- a/src/model.jl +++ b/src/model.jl @@ -94,9 +94,9 @@ end (model::Model)(context::AbstractContext) = model(VarInfo(), context) function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 - return evaluate_threadunsafe(model, varinfo, sampler, context) + return evaluate_threadunsafe(model, varinfo, context) else - return evaluate_threadsafe(model, varinfo, sampler, context) + return evaluate_threadsafe(model, varinfo, context) end end @@ -151,7 +151,7 @@ end """ _evaluate(model::Model, varinfo, context) -Evaluate the `model` with the arguments matching the given `sampler` and `varinfo` object. +Evaluate the `model` with the arguments matching the given `context` and `varinfo` object. """ @generated function _evaluate( model::Model{_F,argnames}, varinfo, context From 597277119922a24d0783b78134a8f541eb05dc0e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 01:00:49 +0100 Subject: [PATCH 10/46] fixed method signatures --- src/compiler.jl | 8 ------ src/context_implementations.jl | 48 +++++++++++++++++----------------- 2 files changed, 24 insertions(+), 32 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index bc906f58c..8201a82f4 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -283,7 +283,6 @@ function generate_tilde(left, right) return quote $(DynamicPPL.tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__, @@ -300,9 +299,7 @@ function generate_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left = $(DynamicPPL.tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_vn)( $(DynamicPPL.check_tilde_rhs)($right), $vn )..., @@ -312,7 +309,6 @@ function generate_tilde(left, right) else $(DynamicPPL.tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, @@ -334,7 +330,6 @@ function generate_dot_tilde(left, right) return quote $(DynamicPPL.dot_tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__, @@ -351,9 +346,7 @@ function generate_dot_tilde(left, right) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption $left .= $(DynamicPPL.dot_tilde_assume!)( - __rng__, __context__, - __sampler__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $left, $vn )..., @@ -363,7 +356,6 @@ function generate_dot_tilde(left, right) else $(DynamicPPL.dot_tilde_observe!)( __context__, - __sampler__, $(DynamicPPL.check_tilde_rhs)($right), $left, $vn, diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a0bf0381a..d6ff3b5bd 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -120,15 +120,15 @@ function tilde_assume(context::PrefixContext, right, vn, inds, vi) end """ - tilde_assume!(rng, ctx, sampler, right, vn, inds, vi) + tilde_assume!(ctx, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde_assume!(rng, ctx, sampler, right, vn, inds, vi)`. +Falls back to `tilde_assume!(ctx, right, vn, inds, vi)`. """ -function tilde_assume!(ctx, sampler, right, vn, inds, vi) - value, logp = tilde_assume(ctx, sampler, right, vn, inds, vi) +function tilde_assume!(ctx, right, vn, inds, vi) + value, logp = tilde_assume(ctx, right, vn, inds, vi) acclogp!(vi, logp) return value end @@ -199,30 +199,30 @@ function tilde_observe(context::PrefixContext, right, left, vi) end """ - tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + tilde_observe!(ctx, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe(ctx, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, right, left, vname, vinds, vi) + logp = tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end """ - tilde_observe(ctx, sampler, right, left, vi) + tilde_observe(ctx, right, left, vi) Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, sampler, right, left, vi)`. +Falls back to `tilde(ctx, right, left, vi)`. """ -function tilde_observe!(ctx, sampler, right, left, vi) - logp = tilde_observe(ctx, sampler, right, left, vi) +function tilde_observe!(ctx, right, left, vi) + logp = tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end @@ -415,15 +415,15 @@ function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) end """ - dot_tilde_assume!(rng, ctx, sampler, right, left, vn, inds, vi) + dot_tilde_assume!(ctx, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde_assume(rng, ctx, sampler, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(ctx, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(ctx, sampler, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume!(ctx, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(ctx, right, left, vn, inds, vi) acclogp!(vi, logp) return value end @@ -615,30 +615,30 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) end """ - dot_tilde_observe!(ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe!(ctx, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe(ctx, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(ctx, sampler, right, left, vn, inds, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, right, left, vn, inds, vi) + logp = dot_tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end """ - dot_tilde_observe!(ctx, sampler, right, left, vi) + dot_tilde_observe!(ctx, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, sampler, right, left, vi)`. +Falls back to `dot_tilde_observe(ctx, right, left, vi)`. """ -function dot_tilde_observe!(ctx, sampler, right, left, vi) - logp = dot_tilde_observe(ctx, sampler, right, left, vi) +function dot_tilde_observe!(ctx, right, left, vi) + logp = dot_tilde_observe(ctx, right, left, vi) acclogp!(vi, logp) return left end From c4ecd0e676ab58aa13ef9995289d4368868d212c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 01:08:56 +0100 Subject: [PATCH 11/46] sort of fixed the matchingvalue functionality for model --- src/model.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index 3a01f9bf3..2d74949c1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -157,7 +157,10 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return :(model.f(model, varinfo, context, $(unwrap_args...))) + return quote + sampler = context isa $(SamplingContext) ? context.sampler : SampleFromPrior() + model.f(model, varinfo, context, $(unwrap_args...)) + end end """ From a34b51cd60fffe8ff45903153550ff3486680f91 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 03:36:55 +0100 Subject: [PATCH 12/46] formatting --- src/compiler.jl | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 8201a82f4..dc70ae267 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -282,10 +282,7 @@ function generate_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.tilde_observe!)( - __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end @@ -329,10 +326,7 @@ function generate_dot_tilde(left, right) if !(left isa Symbol || left isa Expr) return quote $(DynamicPPL.dot_tilde_observe!)( - __context__, - $(DynamicPPL.check_tilde_rhs)($right), - $left, - __varinfo__, + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end end From e4a2cf81154e44b23af16e473d1361cf11225fc4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:02:41 +0100 Subject: [PATCH 13/46] removed left-over acclogp! that should not be here anymore --- src/context_implementations.jl | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 42a336479..b088577f5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -345,16 +345,12 @@ function dot_tilde_assume( end end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) - value, logp = dot_assume(NoDist.(right), left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_assume(NoDist.(right), left, vn, inds, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi ) - value, logp = dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) end # `PriorContext` @@ -390,16 +386,12 @@ function dot_tilde_assume( end end function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) - value, logp = dot_assume(right, left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_assume(right, left, vn, inds, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi ) - value, logp = dot_assume(rng, sampler, right, left, vn, inds, vi) - acclogp!(vi, logp) - return value + return dot_assume(rng, sampler, right, left, vn, inds, vi) end # `MiniBatchContext` From 7605785fff5407d558dd920720180efc0e41d885 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:04:29 +0100 Subject: [PATCH 14/46] export SamplingContext --- src/DynamicPPL.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index acdb98183..3ad30972c 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -76,6 +76,7 @@ export AbstractVarInfo, SampleFromPrior, SampleFromUniform, # Contexts + SamplingContext, DefaultContext, LikelihoodContext, PriorContext, From 354ac52b0d2115b2df7d437407b98140a208ba5d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:38:25 +0100 Subject: [PATCH 15/46] use context instead of ctx to refer to contexts --- src/context_implementations.jl | 64 +++++++++++++++++----------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index b088577f5..f859b4619 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -120,15 +120,15 @@ function tilde_assume(context::PrefixContext, right, vn, inds, vi) end """ - tilde_assume!(ctx, right, vn, inds, vi) + tilde_assume!(context, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde_assume!(ctx, right, vn, inds, vi)`. +Falls back to `tilde_assume!(context, right, vn, inds, vi)`. """ -function tilde_assume!(ctx, right, vn, inds, vi) - value, logp = tilde_assume(ctx, right, vn, inds, vi) +function tilde_assume!(context, right, vn, inds, vi) + value, logp = tilde_assume(context, right, vn, inds, vi) acclogp!(vi, logp) return value end @@ -199,30 +199,30 @@ function tilde_observe(context::PrefixContext, right, left, vi) end """ - tilde_observe!(ctx, right, left, vname, vinds, vi) + tilde_observe!(context, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe(ctx, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(ctx, right, left, vname, vinds, vi) - logp = tilde_observe(ctx, right, left, vi) +function tilde_observe!(context, right, left, vname, vinds, vi) + logp = tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end """ - tilde_observe(ctx, right, left, vi) + tilde_observe(context, right, left, vi) Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(ctx, right, left, vi)`. +Falls back to `tilde(context, right, left, vi)`. """ -function tilde_observe!(ctx, right, left, vi) - logp = tilde_observe(ctx, right, left, vi) +function tilde_observe!(context, right, left, vi) + logp = tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end @@ -302,11 +302,11 @@ function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) end # `DefaultContext` -function dot_tilde_assume(ctx::DefaultContext, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(right, vns, left, vi) end -function dot_tilde_assume(rng, ctx::DefaultContext, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) return dot_assume(rng, sampler, right, vns, left, vi) end @@ -405,15 +405,15 @@ function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) end """ - dot_tilde_assume!(ctx, right, left, vn, inds, vi) + dot_tilde_assume!(context, right, left, vn, inds, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde_assume(ctx, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(ctx, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(ctx, right, left, vn, inds, vi) +function dot_tilde_assume!(context, right, left, vn, inds, vi) + value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) acclogp!(vi, logp) return value end @@ -583,17 +583,17 @@ end # Leaf contexts dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -function dot_tilde_observe(ctx::LikelihoodContext, sampler, right, left, vi) +function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) return dot_observe(right, left, vi) end # `MiniBatchContext` -function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vi) - return ctx.loglike_scalar * dot_tilde_observe(ctx.ctx, sampler, right, left, vi) +function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) + return context.loglike_scalar * dot_tilde_observe(context.ctx, sampler, right, left, vi) end -function dot_tilde_observe(ctx::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return ctx.loglike_scalar * - dot_tilde_observe(ctx.ctx, sampler, right, left, vname, vinds, vi) +function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) + return context.loglike_scalar * + dot_tilde_observe(context.ctx, sampler, right, left, vname, vinds, vi) end # `PrefixContext` @@ -605,30 +605,30 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) end """ - dot_tilde_observe!(ctx, right, left, vname, vinds, vi) + dot_tilde_observe!(context, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(ctx, right, left, vn, inds, vi) - logp = dot_tilde_observe(ctx, right, left, vi) +function dot_tilde_observe!(context, right, left, vn, inds, vi) + logp = dot_tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end """ - dot_tilde_observe!(ctx, right, left, vi) + dot_tilde_observe!(context, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(ctx, right, left, vi)`. +Falls back to `dot_tilde_observe(context, right, left, vi)`. """ -function dot_tilde_observe!(ctx, right, left, vi) - logp = dot_tilde_observe(ctx, right, left, vi) +function dot_tilde_observe!(context, right, left, vi) + logp = dot_tilde_observe(context, right, left, vi) acclogp!(vi, logp) return left end From b7a2b3b5b5483eb11741a9a2c7b3135abaecbcad Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:38:46 +0100 Subject: [PATCH 16/46] formatting --- src/context_implementations.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index f859b4619..a8f279804 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -591,7 +591,9 @@ end function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) return context.loglike_scalar * dot_tilde_observe(context.ctx, sampler, right, left, vi) end -function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) +function dot_tilde_observe( + context::MiniBatchContext, sampler, right, left, vname, vinds, vi +) return context.loglike_scalar * dot_tilde_observe(context.ctx, sampler, right, left, vname, vinds, vi) end From 9e0fc9a9eecb6f74d54e0b4a1fad9cf94c0b41eb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:39:41 +0100 Subject: [PATCH 17/46] use context instead of ctx for variables --- src/contexts.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/contexts.jl b/src/contexts.jl index 1ee43f2b2..8598fb633 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -71,7 +71,7 @@ LikelihoodContext() = LikelihoodContext(nothing) """ struct MiniBatchContext{Tctx, T} <: AbstractContext - ctx::Tctx + context::Tctx loglike_scalar::T end @@ -82,11 +82,11 @@ This is useful in batch-based stochastic gradient descent algorithms to be optim `log(prior) + log(likelihood of all the data points)` in the expectation. """ struct MiniBatchContext{Tctx,T} <: AbstractContext - ctx::Tctx + context::Tctx loglike_scalar::T end -function MiniBatchContext(ctx=DefaultContext(); batch_size, npoints) - return MiniBatchContext(ctx, npoints / batch_size) +function MiniBatchContext(context=DefaultContext(); batch_size, npoints) + return MiniBatchContext(context, npoints / batch_size) end function unwrap_childcontext(context::MiniBatchContext) @@ -109,23 +109,23 @@ unique. See also: [`@submodel`](@ref) """ struct PrefixContext{Prefix,C} <: AbstractContext - ctx::C + context::C end -function PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} - return PrefixContext{Prefix,typeof(ctx)}(ctx) +function PrefixContext{Prefix}(context::AbstractContext) where {Prefix} + return PrefixContext{Prefix,typeof(context)}(context) end const PREFIX_SEPARATOR = Symbol(".") function PrefixContext{PrefixInner}( - ctx::PrefixContext{PrefixOuter} + context::PrefixContext{PrefixOuter} ) where {PrefixInner,PrefixOuter} if @generated :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, _prefix_seperator, PrefixInner)))}( - ctx.ctx + context.context )) else - PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx) + PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(context.context) end end From 7a4a1a38ca6895c401e366e9b2707117b5ce36e5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 1 Jun 2021 06:40:18 +0100 Subject: [PATCH 18/46] use context instead of ctx to refer to contexts --- src/context_implementations.jl | 47 ++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a8f279804..e66501aee 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -27,9 +27,9 @@ with a sampler. Falls back to ```julia -tilde_assume(context.rng, context.ctx, context.sampler, right, vn, inds, vi) +tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) ``` -if the context `context.ctx` does not call any other context, as indicated by +if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls `tilde_assume(c, right, vn, inds, vi)` where `c` is a context in which the order of the sampling context and its child are swapped. """ @@ -112,11 +112,11 @@ function tilde_assume( end function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) - return tilde_assume(context.ctx, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) end function tilde_assume(context::PrefixContext, right, vn, inds, vi) - return tilde_assume(context.ctx, right, prefix(context, vn), inds, vi) + return tilde_assume(context.context, right, prefix(context, vn), inds, vi) end """ @@ -138,8 +138,8 @@ end tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) Handle observed variables with a `context` associated with a sampler. -Falls back to `tilde_observe(context.ctx, right, left, vname, vinds, vi)` ignoring -the information about the sampler if the context `context.ctx` does not call any other +Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)` ignoring +the information about the sampler if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls `tilde_observe(c, right, left, vname, vinds, vi)` where `c` is a context in which the order of the sampling context and its child are swapped. @@ -159,8 +159,8 @@ end tilde_observe(context::SamplingContext, right, left, vi) Handle observed constants with a `context` associated with a sampler. -Falls back to `tilde_observe(context.ctx, right, left, vi)` ignoring -the information about the sampler if the context `context.ctx` does not call any other +Falls back to `tilde_observe(context.context, right, left, vi)` ignoring +the information about the sampler if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls `tilde_observe(c, right, left, vi)` where `c` is a context in which the order of the sampling context and its child are swapped. @@ -183,19 +183,19 @@ tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) # `MiniBatchContext` function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * tilde_observe(context.ctx, right, left, vi) + return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) return context.loglike_scalar * - tilde_observe(context.ctx, right, left, vname, vinds, vi) + tilde_observe(context.context, right, left, vname, vinds, vi) end # `PrefixContext` function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) + return tilde_observe(context.context, right, left, prefix(context, vname), vinds, vi) end function tilde_observe(context::PrefixContext, right, left, vi) - return tilde_observe(context.ctx, right, left, vi) + return tilde_observe(context.context, right, left, vi) end """ @@ -283,9 +283,9 @@ associated with a sampler. Falls back to ```julia -dot_tilde_assume(context.rng, context.ctx, context.sampler, right, left, vn, inds, vi) +dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, inds, vi) ``` -if the context `context.ctx` does not call any other context, as indicated by +if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls `dot_tilde_assume(c, right, left, vn, inds, vi)` where `c` is a context in which the order of the sampling context and its child are swapped. """ @@ -396,12 +396,12 @@ end # `MiniBatchContext` function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.ctx, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) end # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.ctx, right, prefix.(Ref(context), vn), inds, vi) + return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), inds, vi) end """ @@ -574,10 +574,10 @@ end Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value for a context associated with a sampler. -Falls back to `dot_tilde_observe(context.ctx, right, left, vi) ignoring the sampler. +Falls back to `dot_tilde_observe(context.context, right, left, vi) ignoring the sampler. """ function dot_tilde_observe(context::SamplingContext, right, left, vi) - return dot_tilde_observe(context.ctx, right, left, vname, vinds, vi) + return dot_tilde_observe(context.context, right, left, vname, vinds, vi) end # Leaf contexts @@ -589,21 +589,24 @@ end # `MiniBatchContext` function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * dot_tilde_observe(context.ctx, sampler, right, left, vi) + return context.loglike_scalar * + dot_tilde_observe(context.context, sampler, right, left, vi) end function dot_tilde_observe( context::MiniBatchContext, sampler, right, left, vname, vinds, vi ) return context.loglike_scalar * - dot_tilde_observe(context.ctx, sampler, right, left, vname, vinds, vi) + dot_tilde_observe(context.context, sampler, right, left, vname, vinds, vi) end # `PrefixContext` function dot_tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return dot_tilde_observe(context.ctx, right, left, prefix(context, vname), vinds, vi) + return dot_tilde_observe( + context.context, right, left, prefix(context, vname), vinds, vi + ) end function dot_tilde_observe(context::PrefixContext, right, left, vi) - return dot_tilde_observe(context.ctx, right, left, vi) + return dot_tilde_observe(context.context, right, left, vi) end """ From 7899473512e85089f0a6497823e804c0a2a7e12c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 01:20:27 +0100 Subject: [PATCH 19/46] Update src/compiler.jl Co-authored-by: David Widmann --- src/compiler.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index dc70ae267..8734b72ed 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -392,12 +392,6 @@ function build_output(modelinfo, linenumbernode) # Replace the user-provided function body with the version created by DynamicPPL. evaluatordef[:body] = quote - # in case someone accessed these - if __context__ isa $(DynamicPPL.SamplingContext) - __rng__ = __context__.rng - __sampler__ = __context__.sampler - end - $(modelinfo[:body]) end From 1630476c742f82e3f45096923e39a4b2da6150a2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 01:20:41 +0100 Subject: [PATCH 20/46] Update src/context_implementations.jl Co-authored-by: David Widmann --- src/context_implementations.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e66501aee..4fd787c86 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -138,6 +138,7 @@ end tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) Handle observed variables with a `context` associated with a sampler. + Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)` ignoring the information about the sampler if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls From 6892d2b1276ef92f6765c3272673ac7a58682465 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 2 Jun 2021 01:37:22 +0100 Subject: [PATCH 21/46] Apply suggestions from code review Co-authored-by: David Widmann --- src/context_implementations.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 4fd787c86..5647cd5fc 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -160,6 +160,7 @@ end tilde_observe(context::SamplingContext, right, left, vi) Handle observed constants with a `context` associated with a sampler. + Falls back to `tilde_observe(context.context, right, left, vi)` ignoring the information about the sampler if the context `context.context` does not call any other context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls From 1015f0e3a248aacb3039dd4adee670504de4412f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 13:56:32 +0100 Subject: [PATCH 22/46] added impl of matchingvalue for contexts --- src/compiler.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index e647df99c..352d46418 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -435,8 +435,12 @@ end """ matchingvalue(sampler, vi, value) + matchingvalue(context::AbstractContext, vi, value) -Convert the `value` to the correct type for the `sampler` and the `vi` object. +Convert the `value` to the correct type for the `sampler` or `context` and the `vi` object. + +For a `context` that is _not_ a `SamplingContext`, we fall back to +`matchingvalue(SampleFromPrior(), vi, value)`. """ function matchingvalue(sampler, vi, value) T = typeof(value) @@ -453,6 +457,13 @@ function matchingvalue(sampler, vi, value) end matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value) +function matchingvalue(context::AbstractContext, vi, value) + return matchingvalue(SampleFromPrior(), vi, value) +end +function matchingvalue(context::SamplingContext, vi, value) + return matchingvalue(context.sampler, vi, value) +end + """ get_matching_type(spl::AbstractSampler, vi, ::Type{T}) where {T} From 23c86a76d27248f3b2c71eb6f72ea85439dd62b7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 14:32:06 +0100 Subject: [PATCH 23/46] reverted the change that makes assume always resample --- src/context_implementations.jl | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d1fd7b0ba..e5e89ca55 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -255,15 +255,23 @@ function assume( inds, vi, ) - # Always overwrite the parameters with new ones. - r = init(rng, dist, sampler) if haskey(vi, vn) - vi[vn] = vectorize(dist, r) - setorder!(vi, vn, get_num_produce(vi)) + # Always overwrite the parameters with new ones for `SampleFromUniform`. + if sampler isa SampleFromUniform || is_flagged(vi, vn, "del") + unset_flag!(vi, vn, "del") + r = init(rng, dist, sampler) + vi[vn] = vectorize(dist, r) + settrans!(vi, false, vn) + setorder!(vi, vn, get_num_produce(vi)) + else + r = vi[vn] + end else + r = init(rng, dist, sampler) push!(vi, vn, r, dist, sampler) + settrans!(vi, false, vn) end - settrans!(vi, false, vn) + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) end From 17f5abe88edefca1a1307987300197d19210da34 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 14:39:20 +0100 Subject: [PATCH 24/46] removed the inds arguments from assume and dot_assume to stay non-breaking --- src/context_implementations.jl | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index e5e89ca55..50e3a8e4b 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -44,11 +44,11 @@ function tilde_assume(context::SamplingContext, right, vn, inds, vi) end # Leaf contexts -tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, inds, vi) +tilde_assume(::DefaultContext, right, vn, inds, vi) = assume(right, vn, vi) function tilde_assume( rng::Random.AbstractRNG, ::DefaultContext, sampler, right, vn, inds, vi ) - return assume(rng, sampler, right, vn, inds, vi) + return assume(rng, sampler, right, vn, vi) end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) @@ -74,10 +74,10 @@ function tilde_assume( return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) end function tilde_assume(::PriorContext, right, vn, inds, vi) - return assume(right, vn, inds, vi) + return assume(right, vn, vi) end function tilde_assume(rng::Random.AbstractRNG, ::PriorContext, sampler, right, vn, inds, vi) - return assume(rng, sampler, right, vn, inds, vi) + return assume(rng, sampler, right, vn, vi) end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) @@ -103,12 +103,12 @@ function tilde_assume( return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) end function tilde_assume(::LikelihoodContext, right, vn, inds, vi) - return assume(NoDist(right), vn, inds, vi) + return assume(NoDist(right), vn, vi) end function tilde_assume( rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, inds, vi ) - return assume(rng, sampler, NoDist(right), vn, inds, vi) + return assume(rng, sampler, NoDist(right), vn, vi) end function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) @@ -238,7 +238,7 @@ function observe(spl::Sampler, weight) end # fallback without sampler -function assume(dist::Distribution, vn::VarName, inds, vi) +function assume(dist::Distribution, vn::VarName, vi) if !haskey(vi, vn) error("variable $vn does not exist") end @@ -252,7 +252,6 @@ function assume( sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, - inds, vi, ) if haskey(vi, vn) @@ -355,12 +354,12 @@ function dot_tilde_assume( end end function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) - return dot_assume(NoDist.(right), left, vn, inds, vi) + return dot_assume(NoDist.(right), left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi ) - return dot_assume(rng, sampler, NoDist.(right), left, vn, inds, vi) + return dot_assume(rng, sampler, NoDist.(right), left, vn, vi) end # `PriorContext` @@ -396,12 +395,12 @@ function dot_tilde_assume( end end function dot_tilde_assume(context::PriorContext, right, left, vn, inds, vi) - return dot_assume(right, left, vn, inds, vi) + return dot_assume(right, left, vn, vi) end function dot_tilde_assume( rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi ) - return dot_assume(rng, sampler, right, left, vn, inds, vi) + return dot_assume(rng, sampler, right, left, vn, vi) end # `MiniBatchContext` @@ -433,7 +432,6 @@ function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, - inds, vi, ) @assert length(dist) == size(var, 1) @@ -460,7 +458,6 @@ function dot_assume( dists::Union{Distribution,AbstractArray{<:Distribution}}, var::AbstractArray, vns::AbstractArray{<:VarName}, - inds, vi, ) # Make sure `var` is not a matrix for multivariate distributions From dbd61f04d1e5719633e36c525fc04b9190659b81 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Mon, 7 Jun 2021 15:32:41 +0100 Subject: [PATCH 25/46] Update src/context_implementations.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/context_implementations.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 50e3a8e4b..16704a814 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -429,10 +429,7 @@ end # `dot_assume` function dot_assume( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi, + dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi ) @assert length(dist) == size(var, 1) lp = sum(zip(vns, eachcol(var))) do vn, ri From b10ba3f17f9f8f493893f8728e5f9aaf160dfd73 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:04:28 +0100 Subject: [PATCH 26/46] added missing sampler arg to tilde_observe --- src/context_implementations.jl | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 50e3a8e4b..a65336670 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -170,18 +170,19 @@ which the order of the sampling context and its child are swapped. function tilde_observe(context::SamplingContext, right, left, vi) c, reconstruct_context = unwrap_childcontext(context) child_of_c, reconstruct_c = unwrap_childcontext(c) - fallback_context = if child_of_c !== nothing - reconstruct_c(reconstruct_context(child_of_c)) + return if child_of_c === nothing + tilde_observe(c, context.sampler, right, left, vi) else - c + tilde_observe( + reconstruct_c(reconstruct_context(child_of_c)), right, left, vi + ) end - return tilde_observe(fallback_context, right, left, vi) end # Leaf contexts -tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) -tilde_observe(::PriorContext, right, left, vi) = 0 -tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) +tilde_observe(::DefaultContext, sampler, right, left, vi) = observe(right, left, vi) +tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +tilde_observe(::LikelihoodContext, sampler, right, left, vi) = observe(right, left, vi) # `MiniBatchContext` function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) From bc5029f45c20a8dfa777b8e370c7af16797c890c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:06:28 +0100 Subject: [PATCH 27/46] added missing sampler argument in dot_tilde_observe --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index a65336670..04ddac2c5 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -583,7 +583,7 @@ probability, and return the observed value for a context associated with a sampl Falls back to `dot_tilde_observe(context.context, right, left, vi) ignoring the sampler. """ function dot_tilde_observe(context::SamplingContext, right, left, vi) - return dot_tilde_observe(context.context, right, left, vname, vinds, vi) + return dot_tilde_observe(context.context, context.sampler, right, left, vi) end # Leaf contexts From 7eac33dc6072023a0ac4ab59cb26f6e8b52efb32 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:06:49 +0100 Subject: [PATCH 28/46] fixed order of arguments in some dot_assume calls --- src/context_implementations.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 04ddac2c5..77a7475a3 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -360,7 +360,7 @@ end function dot_tilde_assume( rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi ) - return dot_assume(rng, sampler, NoDist.(right), left, vn, vi) + return dot_assume(rng, sampler, NoDist.(right), vn, left, vi) end # `PriorContext` @@ -401,7 +401,7 @@ end function dot_tilde_assume( rng::Random.AbstractRNG, context::PriorContext, sampler, right, left, vn, inds, vi ) - return dot_assume(rng, sampler, right, left, vn, vi) + return dot_assume(rng, sampler, right, vn, left, vi) end # `MiniBatchContext` From 85994813e350ea3e132e6f2dc0fc3d4896b4e641 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:07:01 +0100 Subject: [PATCH 29/46] formatting --- src/context_implementations.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 77a7475a3..d7d51c72a 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -430,10 +430,7 @@ end # `dot_assume` function dot_assume( - dist::MultivariateDistribution, - var::AbstractMatrix, - vns::AbstractVector{<:VarName}, - vi, + dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi ) @assert length(dist) == size(var, 1) lp = sum(zip(vns, eachcol(var))) do vn, ri From 90a8c4562e10153ef9323200c59d417411ec009b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:07:22 +0100 Subject: [PATCH 30/46] formatting --- src/context_implementations.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index d7d51c72a..431af5a5e 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -173,9 +173,7 @@ function tilde_observe(context::SamplingContext, right, left, vi) return if child_of_c === nothing tilde_observe(c, context.sampler, right, left, vi) else - tilde_observe( - reconstruct_c(reconstruct_context(child_of_c)), right, left, vi - ) + tilde_observe(reconstruct_c(reconstruct_context(child_of_c)), right, left, vi) end end From f9d4ff83d2ed6b4e67426e636c7c1006a66cafa5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:07:36 +0100 Subject: [PATCH 31/46] added missing sampler argument in tilde_observe for SamplingContext --- src/context_implementations.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 431af5a5e..5dce075a2 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -148,12 +148,13 @@ which the order of the sampling context and its child are swapped. function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) c, reconstruct_context = unwrap_childcontext(context) child_of_c, reconstruct_c = unwrap_childcontext(c) - fallback_context = if child_of_c !== nothing - reconstruct_c(reconstruct_context(child_of_c)) + return if child_of_c === nothing + tilde_observe(c, context.sampler, right, left, vname, vinds, vi) else - c + tilde_observe( + reconstruct_c(reconstruct_context(child_of_c)), right, left, vname, vinds, vi + ) end - return tilde_observe(fallback_context, right, left, vname, vinds, vi) end """ From e424fe7b45821382b7ceb35d5294689730081334 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 7 Jun 2021 16:08:15 +0100 Subject: [PATCH 32/46] added missing word in a docstring --- src/context_implementations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 5dce075a2..b67833d62 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -614,7 +614,7 @@ end """ dot_tilde_observe!(context, right, left, vname, vinds, vi) -Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur the model inputs), +Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. Falls back to `dot_tilde_observe(context, right, left, vi)` ignoring the information about variable From 70957d27c66a3a1ce7b3ac751e68709531b5e548 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 08:03:50 +0100 Subject: [PATCH 33/46] updated submodel macro --- src/submodel_macro.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 92584ae8b..070a5aa4c 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,10 +1,8 @@ macro submodel(expr) return quote _evaluate( - $(esc(:__rng__)), $(esc(expr)), $(esc(:__varinfo__)), - $(esc(:__sampler__)), $(esc(:__context__)), ) end @@ -13,10 +11,8 @@ end macro submodel(prefix, expr) return quote _evaluate( - $(esc(:__rng__)), $(esc(expr)), $(esc(:__varinfo__)), - $(esc(:__sampler__)), PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__))), ) end From d00cdcfb0ceed552a42abaadac2a8d4413e4bd4d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 13:58:46 +0100 Subject: [PATCH 34/46] removed unwrap_childcontext and related since its not needed for this PR --- src/context_implementations.jl | 130 +++++++++++++++------------------ src/contexts.jl | 38 ---------- 2 files changed, 58 insertions(+), 110 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index b67833d62..ae4d631ae 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -29,18 +29,9 @@ Falls back to ```julia tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) ``` -if the context `context.context` does not call any other context, as indicated by -[`unwrap_childcontext`](@ref). Otherwise, calls `tilde_assume(c, right, vn, inds, vi)` -where `c` is a context in which the order of the sampling context and its child are swapped. """ function tilde_assume(context::SamplingContext, right, vn, inds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - tilde_assume(context.rng, c, context.sampler, right, vn, inds, vi) - else - tilde_assume(reconstruct_c(reconstruct_context(child_of_c)), right, vn, inds, vi) - end + return tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) end # Leaf contexts @@ -115,10 +106,18 @@ function tilde_assume(context::MiniBatchContext, right, vn, inds, vi) return tilde_assume(context.context, right, vn, inds, vi) end +function tilde_assume(rng, context::MiniBatchContext, sampler, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, vn, inds, vi) +end + function tilde_assume(context::PrefixContext, right, vn, inds, vi) return tilde_assume(context.context, right, prefix(context, vn), inds, vi) end +function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) + return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) +end + """ tilde_assume!(context, right, vn, inds, vi) @@ -139,22 +138,12 @@ end Handle observed variables with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)` ignoring -the information about the sampler if the context `context.context` does not call any other -context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls -`tilde_observe(c, right, left, vname, vinds, vi)` where `c` is a context in -which the order of the sampling context and its child are swapped. +Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)`. """ function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - tilde_observe(c, context.sampler, right, left, vname, vinds, vi) - else - tilde_observe( - reconstruct_c(reconstruct_context(child_of_c)), right, left, vname, vinds, vi - ) - end + return tilde_observe( + context.rng, context.context, context.sampler, right, left, vname, vinds, vi + ) end """ @@ -162,39 +151,31 @@ end Handle observed constants with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, right, left, vi)` ignoring -the information about the sampler if the context `context.context` does not call any other -context, as indicated by [`unwrap_childcontext`](@ref). Otherwise, calls -`tilde_observe(c, right, left, vi)` where `c` is a context in -which the order of the sampling context and its child are swapped. +Falls back to `tilde_observe(context.context, right, left, vi)`. """ function tilde_observe(context::SamplingContext, right, left, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - tilde_observe(c, context.sampler, right, left, vi) - else - tilde_observe(reconstruct_c(reconstruct_context(child_of_c)), right, left, vi) - end + return tilde_observe(context.context, context.sampler, right, left, vi) end # Leaf contexts +tilde_observe(::DefaultContext, right, left, vi) = observe(right, left, vi) tilde_observe(::DefaultContext, sampler, right, left, vi) = observe(right, left, vi) +tilde_observe(::PriorContext, right, left, vi) = 0 tilde_observe(::PriorContext, sampler, right, left, vi) = 0 +tilde_observe(::LikelihoodContext, right, left, vi) = observe(right, left, vi) tilde_observe(::LikelihoodContext, sampler, right, left, vi) = observe(right, left, vi) # `MiniBatchContext` -function tilde_observe(context::MiniBatchContext, sampler, right, left, vi) +function tilde_observe(context::MiniBatchContext, right, left, vi) return context.loglike_scalar * tilde_observe(context.context, right, left, vi) end -function tilde_observe(context::MiniBatchContext, sampler, right, left, vname, vinds, vi) - return context.loglike_scalar * - tilde_observe(context.context, right, left, vname, vinds, vi) +function tilde_observe(context::MiniBatchContext, right, left, vname, vi) + return context.loglike_scalar * tilde_observe(context.context, right, left, vname, vi) end # `PrefixContext` -function tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return tilde_observe(context.context, right, left, prefix(context, vname), vinds, vi) +function tilde_observe(context::PrefixContext, right, left, vname, vi) + return tilde_observe(context.context, right, left, prefix(context, vname), vi) end function tilde_observe(context::PrefixContext, right, left, vi) return tilde_observe(context.context, right, left, vi) @@ -294,25 +275,16 @@ Falls back to ```julia dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, inds, vi) ``` -if the context `context.context` does not call any other context, as indicated by -[`unwrap_childcontext`](@ref). Otherwise, calls `dot_tilde_assume(c, right, left, vn, inds, vi)` -where `c` is a context in which the order of the sampling context and its child are swapped. """ function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) - c, reconstruct_context = unwrap_childcontext(context) - child_of_c, reconstruct_c = unwrap_childcontext(c) - return if child_of_c === nothing - dot_tilde_assume(context.rng, c, context.sampler, right, left, vn, inds, vi) - else - dot_tilde_assume( - reconstruct_c(reconstruct_context(child_of_c)), right, left, vn, inds, vi - ) - end + return dot_tilde_assume( + context.rng, context.context, context.sampler, right, left, vn, inds, vi + ) end # `DefaultContext` -function dot_tilde_assume(::DefaultContext, sampler, right, left, vns, inds, vi) - return dot_assume(right, vns, left, vi) +function dot_tilde_assume(::DefaultContext, right, left, vns, inds, vi) + return dot_assume(right, left, vns, vi) end function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) @@ -408,11 +380,23 @@ function dot_tilde_assume(context::MiniBatchContext, right, left, vn, inds, vi) return dot_tilde_assume(context.context, right, left, vn, inds, vi) end +function dot_tilde_assume( + rng, context::MiniBatchContext, sampler, right, left, vn, inds, vi +) + return dot_tilde_assume(rng, context.context, sampler, right, left, vn, inds, vi) +end + # `PrefixContext` function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), inds, vi) end +function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, inds, vi) + return dot_tilde_assume( + rng, context.context, sampler, right, prefix.(Ref(context), vn), inds, vi + ) +end + """ dot_tilde_assume!(context, right, left, vn, inds, vi) @@ -583,30 +567,23 @@ function dot_tilde_observe(context::SamplingContext, right, left, vi) end # Leaf contexts -dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(right, left, vi) +dot_tilde_observe(::DefaultContext, right, left, vi) = dot_observe(right, left, vi) +dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(sampler, right, left, vi) +dot_tilde_observe(::PriorContext, right, left, vi) = 0 dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 -function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) +function dot_tilde_observe(context::LikelihoodContext, right, left, vi) return dot_observe(right, left, vi) end +function dot_tilde_observe(context::LikelihoodContext, sampler, right, left, vi) + return dot_observe(sampler, right, left, vi) +end # `MiniBatchContext` -function dot_tilde_observe(context::MiniBatchContext, sampler, right, left, vi) - return context.loglike_scalar * - dot_tilde_observe(context.context, sampler, right, left, vi) -end -function dot_tilde_observe( - context::MiniBatchContext, sampler, right, left, vname, vinds, vi -) - return context.loglike_scalar * - dot_tilde_observe(context.context, sampler, right, left, vname, vinds, vi) +function dot_tilde_observe(context::MiniBatchContext, right, left, vi) + return context.loglike_scalar * dot_tilde_observe(context.context, right, left, vi) end # `PrefixContext` -function dot_tilde_observe(context::PrefixContext, right, left, vname, vinds, vi) - return dot_tilde_observe( - context.context, right, left, prefix(context, vname), vinds, vi - ) -end function dot_tilde_observe(context::PrefixContext, right, left, vi) return dot_tilde_observe(context.context, right, left, vi) end @@ -641,18 +618,27 @@ function dot_tilde_observe!(context, right, left, vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics +function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, value::AbstractMatrix, vi) + return dot_observe(dist, value, vi) +end function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) increment_num_produce!(vi) @debug "dist = $dist" @debug "value = $value" return Distributions.loglikelihood(dist, value) end +function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dists::Distribution, value::AbstractArray, vi) + return dot_observe(dists, value, vi) +end function dot_observe(dists::Distribution, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" @debug "value = $value" return Distributions.loglikelihood(dists, value) end +function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) + return dot_observe(dists, value, vi) +end function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) increment_num_produce!(vi) @debug "dists = $dists" diff --git a/src/contexts.jl b/src/contexts.jl index 6daa18776..8093c88f3 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -1,17 +1,3 @@ -""" - unwrap_childcontext(context::AbstractContext) - -Return a tuple of the child context of a `context`, or `nothing` if the context does -not wrap any other context, and a function `f(c::AbstractContext)` that constructs -an instance of `context` in which the child context is replaced with `c`. - -Falls back to `(nothing, _ -> context)`. -""" -function unwrap_childcontext(context::AbstractContext) - reconstruct_context(@nospecialize(x)) = context - return nothing, reconstruct_context -end - """ SamplingContext(rng, sampler, context) @@ -26,14 +12,6 @@ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractConte context::C end -function unwrap_childcontext(context::SamplingContext) - child = context.context - function reconstruct_samplingcontext(c::AbstractContext) - return SamplingContext(context.rng, context.sampler, c) - end - return child, reconstruct_samplingcontext -end - """ struct DefaultContext <: AbstractContext end @@ -89,14 +67,6 @@ function MiniBatchContext(context=DefaultContext(); batch_size, npoints) return MiniBatchContext(context, npoints / batch_size) end -function unwrap_childcontext(context::MiniBatchContext) - child = context.context - function reconstruct_minibatchcontext(c::AbstractContext) - return MiniBatchContext(c, context.loglike_scalar) - end - return child, reconstruct_minibatchcontext -end - """ PrefixContext{Prefix}(context) @@ -136,11 +106,3 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) end end - -function unwrap_childcontext(context::PrefixContext{P}) where {P} - child = context.context - function reconstruct_prefixcontext(c::AbstractContext) - return PrefixContext{P}(c) - end - return child, reconstruct_prefixcontext -end From 639fd6ebe36bee13fc2372915e2480858012612c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 13:59:05 +0100 Subject: [PATCH 35/46] updated submodel macro --- src/submodel_macro.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 070a5aa4c..1d574e286 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,10 +1,6 @@ macro submodel(expr) return quote - _evaluate( - $(esc(expr)), - $(esc(:__varinfo__)), - $(esc(:__context__)), - ) + _evaluate($(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__))) end end From c9a06fb46c8a9d4f184e2b64cfd9c119d68c5cc3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 13:59:15 +0100 Subject: [PATCH 36/46] fixed evaluation implementations of dot_assume --- src/context_implementations.jl | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index ae4d631ae..6befa1f6d 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -416,10 +416,17 @@ function dot_assume( dist::MultivariateDistribution, var::AbstractMatrix, vns::AbstractVector{<:VarName}, vi ) @assert length(dist) == size(var, 1) - lp = sum(zip(vns, eachcol(var))) do vn, ri + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dist, SampleFromPrior()) + lp = sum(zip(vns, eachcol(r))) do vn, ri return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end - return var, lp + return r, lp end function dot_assume( rng, @@ -441,9 +448,15 @@ function dot_assume( vns::AbstractArray{<:VarName}, vi, ) - # Make sure `var` is not a matrix for multivariate distributions - lp = sum(Bijectors.logpdf_with_trans.(dists, var, istrans(vi, vns[1]))) - return var, lp + # NOTE: We cannot work with `var` here because we might have a model of the form + # + # m = Vector{Float64}(undef, n) + # m .~ Normal() + # + # in which case `var` will have `undef` elements, even if `m` is present in `vi`. + r = get_and_set_val!(Random.GLOBAL_RNG, vi, vns, dists, SampleFromPrior()) + lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) + return r, lp end function dot_assume( From 2fe5f4016cb66ff5da3af57415dd096843363a1b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 13:59:45 +0100 Subject: [PATCH 37/46] updated pointwise_loglikelihoods and related --- src/loglikelihoods.jl | 99 ++++++++++++++++++++++++++----------------- 1 file changed, 60 insertions(+), 39 deletions(-) diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 89672127a..6fca717c6 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -1,80 +1,102 @@ # Context version struct PointwiseLikelihoodContext{A,Ctx} <: AbstractContext loglikelihoods::A - ctx::Ctx + context::Ctx end function PointwiseLikelihoodContext( - likelihoods=Dict{VarName,Vector{Float64}}(), ctx::AbstractContext=LikelihoodContext() + likelihoods=Dict{VarName,Vector{Float64}}(), + context::AbstractContext=LikelihoodContext(), ) - return PointwiseLikelihoodContext{typeof(likelihoods),typeof(ctx)}(likelihoods, ctx) + return PointwiseLikelihoodContext{typeof(likelihoods),typeof(context)}( + likelihoods, context + ) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{VarName,Vector{Float64}}}, + vn::VarName, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, vn, Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{VarName,Float64}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{VarName,Float64}}, vn::VarName, logp::Real ) - return ctx.loglikelihoods[vn] = logp + return context.loglikelihoods[vn] = logp end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, + vn::VarName, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, string(vn), Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::VarName, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::VarName, logp::Real ) - return ctx.loglikelihoods[string(vn)] = logp + return context.loglikelihoods[string(vn)] = logp end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, vn::String, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Vector{Float64}}}, + vn::String, + logp::Real, ) - lookup = ctx.loglikelihoods + lookup = context.loglikelihoods ℓ = get!(lookup, vn, Float64[]) return push!(ℓ, logp) end function Base.push!( - ctx::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::String, logp::Real + context::PointwiseLikelihoodContext{Dict{String,Float64}}, vn::String, logp::Real ) - return ctx.loglikelihoods[vn] = logp + return context.loglikelihoods[vn] = logp end -function tilde_assume(rng, ctx::PointwiseLikelihoodContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, ctx.ctx, sampler, right, vn, inds, vi) +function tilde_assume(context::PointwiseLikelihoodContext, right, vn, inds, vi) + return tilde_assume(context.context, right, vn, inds, vi) end -function dot_tilde_assume( - rng, ctx::PointwiseLikelihoodContext, sampler, right, left, vn, inds, vi -) - value, logp = dot_tilde(rng, ctx.ctx, sampler, right, left, vn, inds, vi) +function dot_tilde_assume(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) + return dot_tilde_assume(context.context, right, left, vn, inds, vi) +end + +function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) + # Defer literal `observe` to child-context. + return tilde_observe!(context.context, right, left, vi) +end +function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `tilde_observe!`. + logp = tilde_observe(context.context, right, left, vi) acclogp!(vi, logp) - return value + + # Track loglikelihood value. + push!(context, vn, logp) + + return left end -function tilde_observe( - ctx::PointwiseLikelihoodContext, sampler, right, left, vname, vinds, vi -) - # This is slightly unfortunate since it is not completely generic... - # Ideally we would call `tilde_observe` recursively but then we don't get the - # loglikelihood value. - logp = tilde(ctx.ctx, sampler, right, left, vi) +function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) + # Defer literal `observe` to child-context. + return dot_tilde_observe(context.context, right, left, vi) +end +function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) + # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. + # we have to intercept the call to `dot_tilde_observe!`. + logp = dot_tilde_observe(context.context, right, left, vi) acclogp!(vi, logp) - # track loglikelihood value - push!(ctx, vname, logp) + # Track loglikelihood value. + push!(context, vn, logp) return left end @@ -150,30 +172,29 @@ Dict{VarName,Array{Float64,2}} with 4 entries: """ function pointwise_loglikelihoods(model::Model, chain, keytype::Type{T}=String) where {T} # Get the data by executing the model once - spl = SampleFromPrior() vi = VarInfo(model) - ctx = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}()) + context = PointwiseLikelihoodContext(Dict{T,Vector{Float64}}()) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) for (sample_idx, chain_idx) in iters # Update the values - setval_and_resample!(vi, chain, sample_idx, chain_idx) + setval!(vi, chain, sample_idx, chain_idx) # Execute model - model(vi, spl, ctx) + model(vi, context) end niters = size(chain, 1) nchains = size(chain, 3) loglikelihoods = Dict( varname => reshape(logliks, niters, nchains) for - (varname, logliks) in ctx.loglikelihoods + (varname, logliks) in context.loglikelihoods ) return loglikelihoods end function pointwise_loglikelihoods(model::Model, varinfo::AbstractVarInfo) - ctx = PointwiseLikelihoodContext(Dict{VarName,Float64}()) - model(varinfo, SampleFromPrior(), ctx) - return ctx.loglikelihoods + context = PointwiseLikelihoodContext(Dict{VarName,Vector{Float64}}()) + model(varinfo, context) + return context.loglikelihoods end From b532ca690c254b12732dcdb5f7c9a67577a763cf Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:00:08 +0100 Subject: [PATCH 38/46] added proper tests for pointwise_loglikelihoods --- test/loglikelihoods.jl | 122 +++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 2 + 2 files changed, 124 insertions(+) create mode 100644 test/loglikelihoods.jl diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl new file mode 100644 index 000000000..5c1fdc082 --- /dev/null +++ b/test/loglikelihoods.jl @@ -0,0 +1,122 @@ +# A collection of models for which the mean-of-means for the posterior should +# be same. +@model function gdemo1(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + # `dot_assume` and `observe` + m = TV(undef, length(x)) + m .~ Normal() + x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo2(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + # `assume` with indexing and `observe` + m = TV(undef, length(x)) + for i in eachindex(m) + m[i] ~ Normal() + end + x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo3(x = 10 * ones(2)) + # Multivariate `assume` and `observe` + m ~ MvNormal(length(x), 1.0) + x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo4(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + # `dot_assume` and `observe` with indexing + m = TV(undef, length(x)) + m .~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m[i], 0.5) + end +end + +# Using vector of `length` 1 here so the posterior of `m` is the same +# as the others. +@model function gdemo5(x = 10 * ones(1)) + # `assume` and `dot_observe` + m ~ Normal() + x .~ Normal(m, 0.5) +end + +# @model function gdemo6(::Type{TV} = Vector{Float64}) where {TV} +# # `assume` and literal `observe` +# m ~ MvNormal(length(x), 1.0) +# [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) +# end + +@model function gdemo7(::Type{TV} = Vector{Float64}) where {TV} + # `dot_assume` and literal `observe` with indexing + m = TV(undef, 2) + m .~ Normal() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +# @model function gdemo8(::Type{TV} = Vector{Float64}) where {TV} +# # `assume` and literal `dot_observe` +# m ~ Normal() +# [10.0, ] .~ Normal(m, 0.5) +# end + +@model function _prior_dot_assume(::Type{TV} = Vector{Float64}) where {TV} + m = TV(undef, 2) + m .~ Normal() + + return m +end + +@model function gdemo9() + # Submodel prior + m = @submodel _prior_dot_assume() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +@model function _likelihood_dot_observe(m, x) + x ~ MvNormal(m, 0.5 * ones(length(m))) +end + +@model function gdemo10(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + m = TV(undef, length(x)) + m .~ Normal() + + # Submodel likelihood + @submodel _likelihood_dot_observe(m, x) +end + +const mean_of_mean_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10()) + + +@testset "loglikelihoods.jl" begin + for m in mean_of_mean_models + vi = VarInfo(m) + + vns = vi.metadata.m.vns + if length(vns) == 1 && length(vi[vns[1]]) == 1 + # Only have one latent variable. + DynamicPPL.setval!(vi, [1.0, ], ["m", ]) + else + DynamicPPL.setval!(vi, [1.0, 1.0], ["m[1]", "m[2]"]) + end + + lls = pointwise_loglikelihoods(m, vi) + + if isempty(lls) + # One of the models with literal observations, so we just skip. + continue + end + + loglikelihood = if length(keys(lls)) == 1 && length(m.args.x) == 1 + # Only have one observation, so we need to double it + # for comparison with other models. + 2 * sum(lls[first(keys(lls))]) + else + sum(sum, values(lls)) + end + + @test loglikelihood ≈ -324.45158270528947 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 2b3d5d55c..d83be0eea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -45,6 +45,8 @@ include("test_util.jl") include("threadsafe.jl") include("serialization.jl") + + include("loglikelihoods.jl") end @testset "compat" begin From 4e2274e7abffc1058f370d32f4cc93e21df1d1f7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:00:24 +0100 Subject: [PATCH 39/46] updated DPPL tests to reflect recent changes --- test/compiler.jl | 11 +++++------ test/threadsafe.jl | 8 ++++---- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/test/compiler.jl b/test/compiler.jl index 78b472563..d219f91ea 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -172,10 +172,10 @@ end @model function testmodel_missing3(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __sampler__ + global sampler_ = __context__.sampler global model_ = __model__ global context_ = __context__ - global rng_ = __rng__ + global rng_ = __context__.rng global lp = getlogp(__varinfo__) return x end @@ -184,18 +184,17 @@ end @test getlogp(varinfo) == lp @test varinfo_ isa AbstractVarInfo @test model_ === model - @test sampler_ === SampleFromPrior() - @test context_ === DefaultContext() + @test context_ isa SamplingContext @test rng_ isa Random.AbstractRNG # disable warnings @model function testmodel_missing4(x) x[1] ~ Bernoulli(0.5) global varinfo_ = __varinfo__ - global sampler_ = __sampler__ + global sampler_ = __context__.sampler global model_ = __model__ global context_ = __context__ - global rng_ = __rng__ + global rng_ = __context__.rng global lp = getlogp(__varinfo__) return x end false diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 746d6a5f8..7a2bdd039 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -61,14 +61,14 @@ # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. DynamicPPL.evaluate_threadsafe( - Random.GLOBAL_RNG, wthreads(x), vi, SampleFromPrior(), DefaultContext() + wthreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa DynamicPPL.ThreadSafeVarInfo println(" evaluate_threadsafe:") @time DynamicPPL.evaluate_threadsafe( - Random.GLOBAL_RNG, wthreads(x), vi, SampleFromPrior(), DefaultContext() + wthreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) ) @model function wothreads(x) @@ -96,14 +96,14 @@ # Ensure that we use `VarInfo`. DynamicPPL.evaluate_threadunsafe( - Random.GLOBAL_RNG, wothreads(x), vi, SampleFromPrior(), DefaultContext() + wothreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa VarInfo println(" evaluate_threadunsafe:") @time DynamicPPL.evaluate_threadunsafe( - Random.GLOBAL_RNG, wothreads(x), vi, SampleFromPrior(), DefaultContext() + wothreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) ) end end From 10899f370c5335a53b0212146cddaf69ad43e62c Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:03:03 +0100 Subject: [PATCH 40/46] formatting --- src/context_implementations.jl | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 6befa1f6d..77cbc0fb2 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -581,7 +581,9 @@ end # Leaf contexts dot_tilde_observe(::DefaultContext, right, left, vi) = dot_observe(right, left, vi) -dot_tilde_observe(::DefaultContext, sampler, right, left, vi) = dot_observe(sampler, right, left, vi) +function dot_tilde_observe(::DefaultContext, sampler, right, left, vi) + return dot_observe(sampler, right, left, vi) +end dot_tilde_observe(::PriorContext, right, left, vi) = 0 dot_tilde_observe(::PriorContext, sampler, right, left, vi) = 0 function dot_tilde_observe(context::LikelihoodContext, right, left, vi) @@ -631,7 +633,12 @@ function dot_tilde_observe!(context, right, left, vi) end # Ambiguity error when not sure to use Distributions convention or Julia broadcasting semantics -function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dist::MultivariateDistribution, value::AbstractMatrix, vi) +function dot_observe( + ::Union{SampleFromPrior,SampleFromUniform}, + dist::MultivariateDistribution, + value::AbstractMatrix, + vi, +) return dot_observe(dist, value, vi) end function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) @@ -640,7 +647,12 @@ function dot_observe(dist::MultivariateDistribution, value::AbstractMatrix, vi) @debug "value = $value" return Distributions.loglikelihood(dist, value) end -function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dists::Distribution, value::AbstractArray, vi) +function dot_observe( + ::Union{SampleFromPrior,SampleFromUniform}, + dists::Distribution, + value::AbstractArray, + vi, +) return dot_observe(dists, value, vi) end function dot_observe(dists::Distribution, value::AbstractArray, vi) @@ -649,7 +661,12 @@ function dot_observe(dists::Distribution, value::AbstractArray, vi) @debug "value = $value" return Distributions.loglikelihood(dists, value) end -function dot_observe(::Union{SampleFromPrior,SampleFromUniform}, dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) +function dot_observe( + ::Union{SampleFromPrior,SampleFromUniform}, + dists::AbstractArray{<:Distribution}, + value::AbstractArray, + vi, +) return dot_observe(dists, value, vi) end function dot_observe(dists::AbstractArray{<:Distribution}, value::AbstractArray, vi) From 1f21ce4158f5ed256b0435eafdc08acf37b8aece Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:03:47 +0100 Subject: [PATCH 41/46] formatting --- test/loglikelihoods.jl | 33 +++++++++++++++++---------------- test/threadsafe.jl | 16 ++++++++++++---- 2 files changed, 29 insertions(+), 20 deletions(-) diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 5c1fdc082..4cc7325b8 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -1,28 +1,28 @@ # A collection of models for which the mean-of-means for the posterior should # be same. -@model function gdemo1(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} +@model function gdemo1(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and `observe` m = TV(undef, length(x)) m .~ Normal() - x ~ MvNormal(m, 0.5 * ones(length(x))) + return x ~ MvNormal(m, 0.5 * ones(length(x))) end -@model function gdemo2(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} +@model function gdemo2(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} # `assume` with indexing and `observe` m = TV(undef, length(x)) for i in eachindex(m) m[i] ~ Normal() end - x ~ MvNormal(m, 0.5 * ones(length(x))) + return x ~ MvNormal(m, 0.5 * ones(length(x))) end -@model function gdemo3(x = 10 * ones(2)) +@model function gdemo3(x=10 * ones(2)) # Multivariate `assume` and `observe` m ~ MvNormal(length(x), 1.0) - x ~ MvNormal(m, 0.5 * ones(length(x))) + return x ~ MvNormal(m, 0.5 * ones(length(x))) end -@model function gdemo4(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} +@model function gdemo4(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and `observe` with indexing m = TV(undef, length(x)) m .~ Normal() @@ -33,10 +33,10 @@ end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. -@model function gdemo5(x = 10 * ones(1)) +@model function gdemo5(x=10 * ones(1)) # `assume` and `dot_observe` m ~ Normal() - x .~ Normal(m, 0.5) + return x .~ Normal(m, 0.5) end # @model function gdemo6(::Type{TV} = Vector{Float64}) where {TV} @@ -45,7 +45,7 @@ end # [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) # end -@model function gdemo7(::Type{TV} = Vector{Float64}) where {TV} +@model function gdemo7(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing m = TV(undef, 2) m .~ Normal() @@ -60,7 +60,7 @@ end # [10.0, ] .~ Normal(m, 0.5) # end -@model function _prior_dot_assume(::Type{TV} = Vector{Float64}) where {TV} +@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} m = TV(undef, 2) m .~ Normal() @@ -76,10 +76,10 @@ end end @model function _likelihood_dot_observe(m, x) - x ~ MvNormal(m, 0.5 * ones(length(m))) + return x ~ MvNormal(m, 0.5 * ones(length(m))) end -@model function gdemo10(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} +@model function gdemo10(x=10 * ones(2), ::Type{TV}=Vector{Float64}) where {TV} m = TV(undef, length(x)) m .~ Normal() @@ -87,8 +87,9 @@ end @submodel _likelihood_dot_observe(m, x) end -const mean_of_mean_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10()) - +const mean_of_mean_models = ( + gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10() +) @testset "loglikelihoods.jl" begin for m in mean_of_mean_models @@ -97,7 +98,7 @@ const mean_of_mean_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), g vns = vi.metadata.m.vns if length(vns) == 1 && length(vi[vns[1]]) == 1 # Only have one latent variable. - DynamicPPL.setval!(vi, [1.0, ], ["m", ]) + DynamicPPL.setval!(vi, [1.0], ["m"]) else DynamicPPL.setval!(vi, [1.0, 1.0], ["m[1]", "m[2]"]) end diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 7a2bdd039..83c53ccd6 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -61,14 +61,18 @@ # Ensure that we use `ThreadSafeVarInfo` to handle multithreaded observe statements. DynamicPPL.evaluate_threadsafe( - wthreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) + wthreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa DynamicPPL.ThreadSafeVarInfo println(" evaluate_threadsafe:") @time DynamicPPL.evaluate_threadsafe( - wthreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) + wthreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @model function wothreads(x) @@ -96,14 +100,18 @@ # Ensure that we use `VarInfo`. DynamicPPL.evaluate_threadunsafe( - wothreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) + wothreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) @test getlogp(vi) ≈ lp_w_threads @test vi_ isa VarInfo println(" evaluate_threadunsafe:") @time DynamicPPL.evaluate_threadunsafe( - wothreads(x), vi, SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()) + wothreads(x), + vi, + SamplingContext(Random.GLOBAL_RNG, SampleFromPrior(), DefaultContext()), ) end end From 70045061b0730bac5e76ed3c6b981511734b21f0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:25:56 +0100 Subject: [PATCH 42/46] renamed mean_of_mean_models used in tests --- test/loglikelihoods.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index 4cc7325b8..74fb88d70 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -87,12 +87,12 @@ end @submodel _likelihood_dot_observe(m, x) end -const mean_of_mean_models = ( +const gdemo_models = ( gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10() ) @testset "loglikelihoods.jl" begin - for m in mean_of_mean_models + for m in gdemo_models vi = VarInfo(m) vns = vi.metadata.m.vns From fa6c4d6aed0312b69d054c3b46f7f438b69810c3 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 08:05:29 +0100 Subject: [PATCH 43/46] bumped dppl version in integration tests --- test/turing/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/turing/Project.toml b/test/turing/Project.toml index a4f68621d..67b8d5645 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -DynamicPPL = "0.11" +DynamicPPL = "0.12" Turing = "0.15, 0.16" julia = "1.3" From 684d829b437e546a02d588d2204082f0be7df0ae Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 09:23:07 +0100 Subject: [PATCH 44/46] Apply suggestions from code review Co-authored-by: David Widmann --- src/compiler.jl | 4 +--- src/context_implementations.jl | 13 +++++++++---- src/contexts.jl | 2 +- src/loglikelihoods.jl | 2 +- src/model.jl | 7 ++----- 5 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 352d46418..2fa94bcd9 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -395,9 +395,7 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - evaluatordef[:body] = quote - $(modelinfo[:body]) - end + evaluatordef[:body] = modelinfo[:body] ## Build the model function. diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 77cbc0fb2..259a7c1c3 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -124,7 +124,8 @@ end Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `tilde_assume!(context, right, vn, inds, vi)`. +By default, calls `tilde_assume(context, right, vn, inds, vi)` and accumulates the log +probability of `vi` with the returned value. """ function tilde_assume!(context, right, vn, inds, vi) value, logp = tilde_assume(context, right, vn, inds, vi) @@ -138,7 +139,10 @@ end Handle observed variables with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, right, left, vname, vinds, vi)`. +Falls back to +```julia +tilde_observe(context.rng, context.context, context.sampler, right, left, vname, vinds, vi) +``` """ function tilde_observe(context::SamplingContext, right, left, vname, vinds, vi) return tilde_observe( @@ -151,7 +155,7 @@ end Handle observed constants with a `context` associated with a sampler. -Falls back to `tilde_observe(context.context, right, left, vi)`. +Falls back to `tilde_observe(context.context, context.sampler, right, left, vi)`. """ function tilde_observe(context::SamplingContext, right, left, vi) return tilde_observe(context.context, context.sampler, right, left, vi) @@ -202,7 +206,8 @@ end Handle observed constants, e.g., `1.0 ~ Normal()`, accumulate the log probability, and return the observed value. -Falls back to `tilde(context, right, left, vi)`. +By default, calls `tilde_observe(context, right, left, vi)` and accumulates the log +probability of `vi` with the returned value. """ function tilde_observe!(context, right, left, vi) logp = tilde_observe(context, right, left, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 8093c88f3..05ad8df0d 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -4,7 +4,7 @@ Create a context that allows you to sample parameters with the `sampler` when running the model. The `context` determines how the returned log density is computed when running the model. -See also: [`JointContext`](@ref), [`LoglikelihoodContext`](@ref), [`PriorContext`](@ref) +See also: [`DefaultContext`](@ref), [`LikelihoodContext`](@ref), [`PriorContext`](@ref) """ struct SamplingContext{S<:AbstractSampler,C<:AbstractContext,R} <: AbstractContext rng::R diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 6fca717c6..6c66e4ec4 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -87,7 +87,7 @@ end function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) # Defer literal `observe` to child-context. - return dot_tilde_observe(context.context, right, left, vi) + return dot_tilde_observe!(context.context, right, left, vi) end function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. diff --git a/src/model.jl b/src/model.jl index 2d74949c1..9ec047a44 100644 --- a/src/model.jl +++ b/src/model.jl @@ -156,11 +156,8 @@ Evaluate the `model` with the arguments matching the given `context` and `varinf @generated function _evaluate( model::Model{_F,argnames}, varinfo, context ) where {_F,argnames} - unwrap_args = [:($matchingvalue(sampler, varinfo, model.args.$var)) for var in argnames] - return quote - sampler = context isa $(SamplingContext) ? context.sampler : SampleFromPrior() - model.f(model, varinfo, context, $(unwrap_args...)) - end + unwrap_args = [:($matchingvalue(context, varinfo, model.args.$var)) for var in argnames] + return :(model.f(model, varinfo, context, $(unwrap_args...))) end """ From 07bb28416adccab4f4783d0d708c9645f49be327 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 12:29:54 +0100 Subject: [PATCH 45/46] Apply suggestions from code review Co-authored-by: David Widmann --- src/context_implementations.jl | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 259a7c1c3..6833a7856 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -191,13 +191,11 @@ end Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe(context, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ function tilde_observe!(context, right, left, vname, vinds, vi) - logp = tilde_observe(context, right, left, vi) - acclogp!(vi, logp) - return left + return tilde_observe!(context, right, left, vi) end """ @@ -578,7 +576,7 @@ end Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value for a context associated with a sampler. -Falls back to `dot_tilde_observe(context.context, right, left, vi) ignoring the sampler. +Falls back to `dot_tilde_observe(context.context, context.sampler, right, left, vi)`. """ function dot_tilde_observe(context::SamplingContext, right, left, vi) return dot_tilde_observe(context.context, context.sampler, right, left, vi) @@ -614,13 +612,11 @@ end Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe(context, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ function dot_tilde_observe!(context, right, left, vn, inds, vi) - logp = dot_tilde_observe(context, right, left, vi) - acclogp!(vi, logp) - return left + return dot_tilde_observe!(context, right, left, vi) end """ From c7c6a3c066e1c1d229edbe18fd8a1f7a4e8a9fc6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 9 Jun 2021 12:56:57 +0100 Subject: [PATCH 46/46] fixed ambiguity error --- src/compiler.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/compiler.jl b/src/compiler.jl index 352d46418..3924eae95 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -455,7 +455,9 @@ function matchingvalue(sampler, vi, value) return value end end -matchingvalue(sampler, vi, value::FloatOrArrayType) = get_matching_type(sampler, vi, value) +function matchingvalue(sampler::AbstractSampler, vi, value::FloatOrArrayType) + return get_matching_type(sampler, vi, value) +end function matchingvalue(context::AbstractContext, vi, value) return matchingvalue(SampleFromPrior(), vi, value)