diff --git a/Project.toml b/Project.toml index b5ed82a53..cdff210e4 100644 --- a/Project.toml +++ b/Project.toml @@ -11,6 +11,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] @@ -21,5 +22,6 @@ Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9" ChainRulesCore = "0.9.7, 0.10" Distributions = "0.23.8, 0.24, 0.25" MacroTools = "0.5.6" +Setfield = "0.7" ZygoteRules = "0.2" julia = "1.3" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index ac2734b47..9609fcbb5 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -10,6 +10,7 @@ using ChainRulesCore: ChainRulesCore using MacroTools: MacroTools using ZygoteRules: ZygoteRules using BangBang: BangBang +using Setfield: Setfield using Random: Random @@ -28,15 +29,22 @@ import Base: keys, haskey +import BangBang: push!!, empty!!, setindex!! 
+ # VarInfo export AbstractVarInfo, VarInfo, UntypedVarInfo, TypedVarInfo, + push!!, + empty!!, getlogp, setlogp!, acclogp!, resetlogp!, + setlogp!!, + acclogp!!, + resetlogp!!, get_num_produce, set_num_produce!, reset_num_produce!, @@ -51,6 +59,8 @@ export AbstractVarInfo, istrans, link!, invlink!, + link!!, + invlink!!, tonamedtuple, # VarName (reexport from AbstractPPL) VarName, @@ -134,4 +144,26 @@ include("compat/ad.jl") include("loglikelihoods.jl") include("submodel_macro.jl") +# Deprecations +@deprecate empty!(vi::VarInfo) empty!!(vi::VarInfo) +@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution +) +@deprecate push!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler +) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, sampler::AbstractSampler) +@deprecate push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) push!!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector +) +@deprecate push!( + vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector} +) push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Set{Selector}) + +@deprecate setlogp!(vi, logp) setlogp!!(vi, logp) +@deprecate acclogp!(vi, logp) acclogp!!(vi, logp) +@deprecate resetlogp!(vi) resetlogp!!(vi) + +@deprecate link!(vi, spl) link!!(vi, spl) +@deprecate invlink!(vi, spl) invlink!!(vi, spl) + end # module diff --git a/src/compat/ad.jl b/src/compat/ad.jl index 47a627506..4fd2830b3 100644 --- a/src/compat/ad.jl +++ b/src/compat/ad.jl @@ -1,5 +1,5 @@ # See https://github.com/TuringLang/Turing.jl/issues/1199 -ChainRulesCore.@non_differentiable push!( +ChainRulesCore.@non_differentiable push!!( vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) diff --git a/src/compiler.jl b/src/compiler.jl index 4e32b7d5a..83cf8b767 100644 --- a/src/compiler.jl +++ b/src/compiler.jl 
@@ -369,7 +369,7 @@ function generate_tilde(left, right) # If the LHS is a literal, it is always an observation if isliteral(left) return quote - $(DynamicPPL.tilde_observe!)( + _, __varinfo__ = $(DynamicPPL.tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end @@ -383,7 +383,7 @@ function generate_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left = $(DynamicPPL.tilde_assume!)( + $left, __varinfo__ = $(DynamicPPL.tilde_assume!!)( __context__, $(DynamicPPL.unwrap_right_vn)( $(DynamicPPL.check_tilde_rhs)($right), $vn @@ -397,7 +397,7 @@ function generate_tilde(left, right) $left = $(DynamicPPL.getvalue_nested)(__context__, $vn) end - $(DynamicPPL.tilde_observe!)( + _, __varinfo__ = $(DynamicPPL.tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), @@ -418,7 +418,7 @@ function generate_dot_tilde(left, right) # If the LHS is a literal, it is always an observation if isliteral(left) return quote - $(DynamicPPL.dot_tilde_observe!)( + _, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ ) end @@ -432,7 +432,7 @@ function generate_dot_tilde(left, right) $inds = $(vinds(left)) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left .= $(DynamicPPL.dot_tilde_assume!)( + _, __varinfo__ = $(DynamicPPL.dot_tilde_assume!!)( __context__, $(DynamicPPL.unwrap_right_left_vns)( $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn @@ -446,7 +446,7 @@ function generate_dot_tilde(left, right) $left .= $(DynamicPPL.getvalue_nested)(__context__, $vn) end - $(DynamicPPL.dot_tilde_observe!)( + _, __varinfo__ = $(DynamicPPL.dot_tilde_observe!!)( __context__, $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), @@ -458,6 +458,108 @@ function generate_dot_tilde(left, right) end end +""" + isfuncdef(expr) + +Return `true` if `expr` is any form of 
function definition, and `false` otherwise. +""" +function isfuncdef(e::Expr) + return if Meta.isexpr(e, :function) + # Classic `function f(...)` + true + elseif Meta.isexpr(e, :->) + # Anonymous functions/lambdas, e.g. `do` blocks or `->` defs. + true + elseif Meta.isexpr(e, :(=)) && Meta.isexpr(e.args[1], :call) + # Short function defs, e.g. `f(args...) = ...`. + true + else + false + end +end + +""" + replace_returns(expr) + +Return `Expr` with all `return ...` statements replaced with +`return DynamicPPL.return_values(..., __varinfo__)`. + +Note that this method will _not_ replace `return` statements within function +definitions. This is checked using [`isfuncdef`](@ref). +""" +replace_returns(e) = e +replace_returns(e::Symbol) = e +function replace_returns(e::Expr) + if isfuncdef(e) + return e + end + + if Meta.isexpr(e, :return) + # NOTE: `return` always has an argument. In the case of + # a bare `return`, the parsed expression will be `return nothing`. + # Hence we don't need any special handling for empty returns. + retval_expr = if length(e.args) > 1 + Expr(:tuple, e.args...) + else + e.args[1] + end + + return :(return $(DynamicPPL.return_values)($retval_expr, __varinfo__)) + end + + return Expr(e.head, map(replace_returns, e.args)...) +end + +""" + return_values(retval, varinfo) + +Return `(retval, varinfo)` if `retval` is not a `Tuple` with second +component being an `AbstractVarInfo`. + +Used together with [`replace_returns`](@ref), it handles the following case. + +# Example + +Suppose the following is the return-value: + +```julia +return x ~ Normal() +``` + +Without `return_values`, once expanded in [`generated_mainbody!`](@ref), this would be + +```julia +return (x, __varinfo__ = tilde_assume!!(...)), __varinfo__ +``` + +i.e. the return-value of the model would end up `(x, __varinfo__), __varinfo__` +which in turn would lead to a `(::Model)(args...)` call returning `(x, __varinfo__)`, +breaking with the expectation of the user. 
+ +In such a scenario `return_values` effectively results in the following + +```julia +return x, __varinfo__ = tilde_assume!!(...) +``` + +preserving user expectation, as desired. +""" +return_values(retval, varinfo::AbstractVarInfo) = (retval, varinfo) +return_values(retval::Tuple{<:Any,<:AbstractVarInfo}, ::AbstractVarInfo) = retval + +# If it's just a symbol, e.g. `f(x) = 1`, then we make it `f(x) = return 1`. +make_returns_explicit!(body) = Expr(:return, body) +function make_returns_explicit!(body::Expr) + # If the last statement is a return-statement, we don't do anything. + if Meta.isexpr(body.args[end], :return) + return body + end + + # Otherwise we replace the last statement with a `return` statement. + body.args[end] = Expr(:return, body.args[end]) + return body +end + const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}} hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing(TA) hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true @@ -489,7 +591,11 @@ function build_output(modelinfo, linenumbernode) evaluatordef[:kwargs] = [] # Replace the user-provided function body with the version created by DynamicPPL. - evaluatordef[:body] = modelinfo[:body] + # NOTE: We need to replace statements of the form `return ...` with + # `return DynamicPPL.return_values(..., __varinfo__)` to ensure that the second + # element in the returned value is always the most up-to-date `__varinfo__`. + # See the docstrings of `replace_returns` and `return_values` for more info. + evaluatordef[:body] = replace_returns(make_returns_explicit!(modelinfo[:body])) ## Build the model function. 
diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 498e83492..6f21ca1b2 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -68,7 +68,9 @@ end function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + vi = setindex!!( + vi, vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)), vn + ) settrans!(vi, false, vn) end return tilde_assume(PriorContext(), right, vn, inds, vi) @@ -83,7 +85,9 @@ function tilde_assume( vi, ) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + vi = setindex!!( + vi, vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)), vn + ) settrans!(vi, false, vn) end return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) @@ -91,7 +95,9 @@ end function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + vi = setindex!!( + vi, vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)), vn + ) settrans!(vi, false, vn) end return tilde_assume(LikelihoodContext(), right, vn, inds, vi) @@ -106,7 +112,9 @@ function tilde_assume( vi, ) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + vi = setindex!!( + vi, vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)), vn + ) settrans!(vi, false, vn) end return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) @@ -128,7 +136,7 @@ function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) end """ - tilde_assume!(context, right, vn, inds, vi) + tilde_assume!!(context, right, vn, inds, vi) Handle assumed variables, e.g., `x ~ 
Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. @@ -136,10 +144,9 @@ accumulate the log probability, and return the sampled value. By default, calls `tilde_assume(context, right, vn, inds, vi)` and accumulates the log probability of `vi` with the returned value. """ -function tilde_assume!(context, right, vn, inds, vi) - value, logp = tilde_assume(context, right, vn, inds, vi) - acclogp!(vi, logp) - return value +function tilde_assume!!(context, right, vn, inds, vi) + value, logp, vi = tilde_assume(context, right, vn, inds, vi) + return value, acclogp!!(vi, logp) end # observe @@ -180,16 +187,16 @@ function tilde_observe(context::PrefixContext, right, left, vi) end """ - tilde_observe!(context, right, left, vname, vinds, vi) + tilde_observe!!(context, right, left, vname, vinds, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `tilde_observe!(context, right, left, vi)` ignoring the information about variable name +Falls back to `tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(context, right, left, vname, vinds, vi) - return tilde_observe!(context, right, left, vi) +function tilde_observe!!(context, right, left, vname, vinds, vi) + return tilde_observe!!(context, right, left, vi) end """ @@ -201,10 +208,9 @@ return the observed value. By default, calls `tilde_observe(context, right, left, vi)` and accumulates the log probability of `vi` with the returned value. 
""" -function tilde_observe!(context, right, left, vi) +function tilde_observe!!(context, right, left, vi) logp = tilde_observe(context, right, left, vi) - acclogp!(vi, logp) - return left + return left, acclogp!!(vi, logp) end function assume(rng, spl::Sampler, dist) @@ -218,7 +224,7 @@ end # fallback without sampler function assume(dist::Distribution, vn::VarName, vi) r = vi[vn] - return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi end # SampleFromPrior and SampleFromUniform @@ -227,7 +233,7 @@ function assume( sampler::Union{SampleFromPrior,SampleFromUniform}, dist::Distribution, vn::VarName, - vi, + vi::VarInfo, ) if haskey(vi, vn) # Always overwrite the parameters with new ones for `SampleFromUniform`. @@ -242,11 +248,11 @@ function assume( end else r = init(rng, dist, sampler) - push!(vi, vn, r, dist, sampler) + push!!(vi, vn, r, dist, sampler) settrans!(vi, false, vn) end - return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)) + return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)), vi end # default fallback (used e.g. 
by `SampleFromPrior` and `SampleUniform`) @@ -313,7 +319,7 @@ function dot_tilde_assume( var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!(vi, false, _vns) dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) else dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) @@ -333,7 +339,7 @@ function dot_tilde_assume( var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!(vi, false, _vns) dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) else dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) @@ -354,7 +360,7 @@ function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!(vi, false, _vns) dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) else dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) @@ -374,7 +380,7 @@ function dot_tilde_assume( var = _getindex(getfield(context.vars, getsym(vn)), inds) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) - settrans!.(Ref(vi), false, _vns) + settrans!(vi, false, _vns) dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) else dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) @@ -393,17 +399,18 @@ function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, end """ - dot_tilde_assume!(context, right, left, vn, inds, vi) + dot_tilde_assume!!(context, right, left, vn, inds, vi) Handle broadcasted assumed 
variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. """ -function dot_tilde_assume!(context, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) - acclogp!(vi, logp) - return value +function dot_tilde_assume!!(context, right, left, vn, inds, vi) + value, logp, vi = dot_tilde_assume(context, right, left, vn, inds, vi) + # Mutation of `value` no longer occurs in main body, so we do it here. + left .= value + return value, acclogp!!(vi, logp) end # `dot_assume` @@ -421,7 +428,7 @@ function dot_assume( lp = sum(zip(vns, eachcol(r))) do vn, ri return Bijectors.logpdf_with_trans(dist, ri, istrans(vi, vn)) end - return r, lp + return r, lp, vi end function dot_assume( @@ -435,7 +442,7 @@ function dot_assume( @assert length(dist) == size(var, 1) r = get_and_set_val!(rng, vi, vns, dist, spl) lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1]))) - return r, lp + return r, lp, vi end function dot_assume( @@ -452,7 +459,7 @@ function dot_assume( # in which case `var` will have `undef` elements, even if `m` is present in `vi`. 
r = reshape(vi[vec(vns)], size(vns)) lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) - return r, lp + return r, lp, vi end function dot_assume( @@ -461,12 +468,12 @@ function dot_assume( dists::Union{Distribution,AbstractArray{<:Distribution}}, vns::AbstractArray{<:VarName}, var::AbstractArray, - vi, + vi::VarInfo, ) r = get_and_set_val!(rng, vi, vns, dists, spl) # Make sure `r` is not a matrix for multivariate distributions lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1]))) - return r, lp + return r, lp, vi end function dot_assume(rng, spl::Sampler, ::Any, ::AbstractArray{<:VarName}, ::Any, ::Any) return error( @@ -476,7 +483,7 @@ end function get_and_set_val!( rng, - vi, + vi::VarInfo, vns::AbstractVector{<:VarName}, dist::MultivariateDistribution, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -500,7 +507,7 @@ function get_and_set_val!( r = init(rng, dist, spl, n) for i in 1:n vn = vns[i] - push!(vi, vn, r[:, i], dist, spl) + push!!(vi, vn, r[:, i], dist, spl) settrans!(vi, false, vn) end end @@ -509,7 +516,7 @@ end function get_and_set_val!( rng, - vi, + vi::VarInfo, vns::AbstractArray{<:VarName}, dists::Union{Distribution,AbstractArray{<:Distribution}}, spl::Union{SampleFromPrior,SampleFromUniform}, @@ -533,14 +540,19 @@ function get_and_set_val!( else f = (vn, dist) -> init(rng, dist, spl) r = f.(vns, dists) - push!.(Ref(vi), vns, r, dists, Ref(spl)) - settrans!.(Ref(vi), false, vns) + # TODO: This unnecessarily allocates a potentially large array of references + # to `vi`. Address this. 
+ push!!.(Ref(vi), vns, r, dists, Ref(spl)) + settrans!(vi, false, vns) end return r end function set_val!( - vi, vns::AbstractVector{<:VarName}, dist::MultivariateDistribution, val::AbstractMatrix + vi::VarInfo, + vns::AbstractVector{<:VarName}, + dist::MultivariateDistribution, + val::AbstractMatrix, ) @assert size(val, 2) == length(vns) foreach(enumerate(vns)) do (i, vn) @@ -549,7 +561,7 @@ function set_val!( return val end function set_val!( - vi, + vi::VarInfo, vns::AbstractArray{<:VarName}, dists::Union{Distribution,AbstractArray{<:Distribution}}, val::AbstractArray, @@ -598,30 +610,29 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) end """ - dot_tilde_observe!(context, right, left, vname, vinds, vi) + dot_tilde_observe!!(context, right, left, vname, vinds, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. -Falls back to `dot_tilde_observe!(context, right, left, vi)` ignoring the information about variable +Falls back to `dot_tilde_observe!!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(context, right, left, vn, inds, vi) - return dot_tilde_observe!(context, right, left, vi) +function dot_tilde_observe!!(context, right, left, vn, inds, vi) + return dot_tilde_observe!!(context, right, left, vi) end """ - dot_tilde_observe!(context, right, left, vi) + dot_tilde_observe!!(context, right, left, vi) Handle broadcasted observed constants, e.g., `[1.0] .~ MvNormal()`, accumulate the log probability, and return the observed value. Falls back to `dot_tilde_observe(context, right, left, vi)`. 
""" -function dot_tilde_observe!(context, right, left, vi) +function dot_tilde_observe!!(context, right, left, vi) logp = dot_tilde_observe(context, right, left, vi) - acclogp!(vi, logp) - return left + return left, acclogp!!(vi, logp) end # Falls back to non-sampler definition. diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 0cac29219..89b3d9e6d 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -67,27 +67,18 @@ function Base.push!( return context.loglikelihoods[vn] = logp end -function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) - # Defer literal `observe` to child-context. - return tilde_observe!(context.context, right, left, vi) -end -function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) - # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. - # we have to intercept the call to `tilde_observe!`. +function tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) + # Need the `logp` value, so we cannot defer `acclogp!!` to child-context, i.e. + # we have to intercept the call to `tilde_observe!!`. logp = tilde_observe(context.context, right, left, vi) - acclogp!(vi, logp) # Track loglikelihood value. push!(context, vn, logp) - return left + return left, acclogp!!(vi, logp) end -function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) - # Defer literal `observe` to child-context. - return dot_tilde_observe!(context.context, right, left, vi) -end -function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) +function dot_tilde_observe!!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `dot_tilde_observe!`. 
@@ -95,7 +86,6 @@ function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn # hence we need the `logp` for each of them. Broadcasting the univariate # `tilde_obseve` does exactly this. logps = _pointwise_tilde_observe(context.context, right, left, vi) - acclogp!(vi, sum(logps)) # Need to unwrap the `vn`, i.e. get one `VarName` for each entry in `left`. _, _, vns = unwrap_right_left_vns(right, left, vn) @@ -104,7 +94,7 @@ function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn push!(context, vn, logp) end - return left + return left, acclogp!!(vi, sum(logps)) end # FIXME: This is really not a good approach since it needs to stay in sync with diff --git a/src/model.jl b/src/model.jl index 646437370..23ff90762 100644 --- a/src/model.jl +++ b/src/model.jl @@ -374,17 +374,20 @@ Sample from the `model` using the `sampler` with random number generator `rng` a The method resets the log joint probability of `varinfo` and increases the evaluation number of `sampler`. """ -function (model::Model)( - rng::Random.AbstractRNG, - varinfo::AbstractVarInfo=VarInfo(), - sampler::AbstractSampler=SampleFromPrior(), - context::AbstractContext=DefaultContext(), -) - return model(varinfo, SamplingContext(rng, sampler, context)) -end +(model::Model)(args...) = (first ∘ evaluate)(model, args...) + +""" + evaluate(model::Model[, rng, varinfo, sampler, context]) + +Sample from the `model` using the `sampler` with random number generator `rng` and the +`context`, and store the sample and log joint probability in `varinfo`. + +Returns both the return-value of the original model, and the resulting varinfo. -(model::Model)(context::AbstractContext) = model(VarInfo(), context) -function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) +The method resets the log joint probability of `varinfo` and increases the evaluation +number of `sampler`. 
+""" +function evaluate(model::Model, varinfo::AbstractVarInfo, context::AbstractContext) if Threads.nthreads() == 1 return evaluate_threadunsafe(model, varinfo, context) else @@ -392,18 +395,30 @@ function (model::Model)(varinfo::AbstractVarInfo, context::AbstractContext) end end -function (model::Model)(args...) - return model(Random.GLOBAL_RNG, args...) +function evaluate( + model::Model, + rng::Random.AbstractRNG, + varinfo::AbstractVarInfo=VarInfo(), + sampler::AbstractSampler=SampleFromPrior(), + context::AbstractContext=DefaultContext(), +) + return evaluate(model, varinfo, SamplingContext(rng, sampler, context)) +end + +evaluate(model::Model, context::AbstractContext) = evaluate(model, VarInfo(), context) + +function evaluate(model::Model, args...) + return evaluate(model, Random.GLOBAL_RNG, args...) end # without VarInfo -function (model::Model)(rng::Random.AbstractRNG, sampler::AbstractSampler, args...) - return model(rng, VarInfo(), sampler, args...) +function evaluate(model::Model, rng::Random.AbstractRNG, sampler::AbstractSampler, args...) + return evaluate(model, rng, VarInfo(), sampler, args...) 
end # without VarInfo and without AbstractSampler -function (model::Model)(rng::Random.AbstractRNG, context::AbstractContext) - return model(rng, VarInfo(), SampleFromPrior(), context) +function evaluate(model::Model, rng::Random.AbstractRNG, context::AbstractContext) + return evaluate(model, rng, VarInfo(), SampleFromPrior(), context) end """ @@ -417,7 +432,7 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadsafe`](@ref) """ function evaluate_threadunsafe(model, varinfo, context) - resetlogp!(varinfo) + resetlogp!!(varinfo) return _evaluate(model, varinfo, context) end @@ -433,10 +448,10 @@ This method is not exposed and supposed to be used only internally in DynamicPPL See also: [`evaluate_threadunsafe`](@ref) """ function evaluate_threadsafe(model, varinfo, context) - resetlogp!(varinfo) + resetlogp!!(varinfo) wrapper = ThreadSafeVarInfo(varinfo) result = _evaluate(model, wrapper, context) - setlogp!(varinfo, getlogp(wrapper)) + setlogp!!(varinfo, getlogp(wrapper)) return result end @@ -495,8 +510,8 @@ Return the log joint probability of variables `varinfo` for the probabilistic `m See [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logjoint(model::Model, varinfo::AbstractVarInfo) - model(varinfo, DefaultContext()) - return getlogp(varinfo) + _, varinfo_new = evaluate(model, varinfo, DefaultContext()) + return getlogp(varinfo_new) end """ @@ -507,8 +522,8 @@ Return the log prior probability of variables `varinfo` for the probabilistic `m See also [`logjoint`](@ref) and [`loglikelihood`](@ref). """ function logprior(model::Model, varinfo::AbstractVarInfo) - model(varinfo, PriorContext()) - return getlogp(varinfo) + _, varinfo_new = evaluate(model, varinfo, PriorContext()) + return getlogp(varinfo_new) end """ @@ -519,8 +534,8 @@ Return the log likelihood of variables `varinfo` for the probabilistic `model`. See also [`logjoint`](@ref) and [`logprior`](@ref). 
""" function Distributions.loglikelihood(model::Model, varinfo::AbstractVarInfo) - model(varinfo, LikelihoodContext()) - return getlogp(varinfo) + _, varinfo_new = evaluate(model, varinfo, LikelihoodContext()) + return getlogp(varinfo_new) end """ diff --git a/src/prob_macro.jl b/src/prob_macro.jl index d761e9fdc..84497aef0 100644 --- a/src/prob_macro.jl +++ b/src/prob_macro.jl @@ -146,7 +146,7 @@ function logprior( foreach(keys(vi.metadata)) do n @assert n in keys(left) "Variable $n is not defined." end - model(vi, SampleFromPrior(), PriorContext(left)) + _, vi = DynamicPPL.evaluate(model, vi, SampleFromPrior(), PriorContext(left)) return getlogp(vi) end diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl index 1d574e286..23b6245ec 100644 --- a/src/submodel_macro.jl +++ b/src/submodel_macro.jl @@ -1,15 +1,39 @@ +""" + @submodel x = model(args...) + @submodel prefix x = model(args...) + +Treats `model` as a distribution, where `x` is the return-value of `model`. + +If `prefix` is specified, then variables sampled within `model` will be +prefixed by `prefix`. This is useful if you have variables of same names in +several models used together. +""" macro submodel(expr) - return quote - _evaluate($(esc(expr)), $(esc(:__varinfo__)), $(esc(:__context__))) - end + return submodel(expr) end macro submodel(prefix, expr) - return quote - _evaluate( - $(esc(expr)), - $(esc(:__varinfo__)), - PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__))), - ) + ctx = :(PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__)))) + return submodel(expr, ctx) +end + +function submodel(expr, ctx=esc(:__context__)) + args_assign = getargs_assignment(expr) + return if args_assign === nothing + # In this case we only want to get the `__varinfo__`. + quote + $(esc(:_)), $(esc(:__varinfo__)) = _evaluate( + $(esc(expr)), $(esc(:__varinfo__)), $(ctx) + ) + end + else + # Here we also want the return-variable. + # TODO: Should we prefix by `L` by default? 
+ L, R = args_assign + quote + $(esc(L)), $(esc(:__varinfo__)) = _evaluate( + $(esc(R)), $(esc(:__varinfo__)), $(ctx) + ) + end end end diff --git a/src/threadsafe.jl b/src/threadsafe.jl index c940f9e3f..9c59fa507 100644 --- a/src/threadsafe.jl +++ b/src/threadsafe.jl @@ -15,7 +15,7 @@ ThreadSafeVarInfo(vi::ThreadSafeVarInfo) = vi # Instead of updating the log probability of the underlying variables we # just update the array of log probabilities. -function acclogp!(vi::ThreadSafeVarInfo, logp) +function acclogp!!(vi::ThreadSafeVarInfo, logp) vi.logps[Threads.threadid()][] += logp return vi end @@ -26,17 +26,17 @@ getlogp(vi::ThreadSafeVarInfo) = getlogp(vi.varinfo) + sum(getindex, vi.logps) # TODO: Make remaining methods thread-safe. -function resetlogp!(vi::ThreadSafeVarInfo) +function resetlogp!!(vi::ThreadSafeVarInfo) for x in vi.logps x[] = zero(x[]) end - return resetlogp!(vi.varinfo) + return resetlogp!!(vi.varinfo) end -function setlogp!(vi::ThreadSafeVarInfo, logp) +function setlogp!!(vi::ThreadSafeVarInfo, logp) for x in vi.logps x[] = zero(x[]) end - return setlogp!(vi.varinfo, logp) + return setlogp!!(vi.varinfo, logp) end get_num_produce(vi::ThreadSafeVarInfo) = get_num_produce(vi.varinfo) @@ -46,8 +46,8 @@ set_num_produce!(vi::ThreadSafeVarInfo, n::Int) = set_num_produce!(vi.varinfo, n syms(vi::ThreadSafeVarInfo) = syms(vi.varinfo) -function setgid!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName) - return setgid!(vi.varinfo, gid, vn) +function setgid!!(vi::ThreadSafeVarInfo, gid::Selector, vn::VarName) + return setgid!!(vi.varinfo, gid, vn) end setorder!(vi::ThreadSafeVarInfo, vn::VarName, index::Int) = setorder!(vi.varinfo, vn, index) setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) @@ -55,8 +55,8 @@ setval!(vi::ThreadSafeVarInfo, val, vn::VarName) = setval!(vi.varinfo, val, vn) keys(vi::ThreadSafeVarInfo) = keys(vi.varinfo) haskey(vi::ThreadSafeVarInfo, vn::VarName) = haskey(vi.varinfo, vn) 
-link!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!(vi.varinfo, spl) -invlink!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!(vi.varinfo, spl) +link!!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = link!!(vi.varinfo, spl) +invlink!!(vi::ThreadSafeVarInfo, spl::AbstractSampler) = invlink!!(vi.varinfo, spl) islinked(vi::ThreadSafeVarInfo, spl::AbstractSampler) = islinked(vi.varinfo, spl) getindex(vi::ThreadSafeVarInfo, spl::AbstractSampler) = getindex(vi.varinfo, spl) @@ -80,20 +80,20 @@ function set_retained_vns_del_by_spl!(vi::ThreadSafeVarInfo, spl::Sampler) end isempty(vi::ThreadSafeVarInfo) = isempty(vi.varinfo) -function empty!(vi::ThreadSafeVarInfo) - empty!(vi.varinfo) +function empty!!(vi::ThreadSafeVarInfo) + empty!!(vi.varinfo) fill!(vi.logps, zero(getlogp(vi))) return vi end -function push!( +function push!!( vi::ThreadSafeVarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector} ) - return push!(vi.varinfo, vn, r, dist, gidset) + return push!!(vi.varinfo, vn, r, dist, gidset) end -function unset_flag!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) - return unset_flag!(vi.varinfo, vn, flag) +function unset_flag!!(vi::ThreadSafeVarInfo, vn::VarName, flag::String) + return unset_flag!!(vi.varinfo, vn, flag) end function is_flagged(vi::ThreadSafeVarInfo, vn::VarName, flag::String) return is_flagged(vi.varinfo, vn, flag) diff --git a/src/utils.jl b/src/utils.jl index db7faabbd..537fcc90e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -9,7 +9,7 @@ Add the result of the evaluation of `ex` to the joint log probability. """ macro addlogprob!(ex) return quote - acclogp!($(esc(:(__varinfo__))), $(esc(ex))) + $(esc(:(__varinfo__))) = acclogp!!($(esc(:(__varinfo__))), $(esc(ex))) end end @@ -44,6 +44,20 @@ function getargs_tilde(expr::Expr) end end +""" + getargs_assignment(x) + +Return the arguments `L` and `R`, if `x` is an expression of the form `L = R`, or `nothing` +otherwise. 
+""" +getargs_assignment(x) = nothing +function getargs_assignment(expr::Expr) + return MacroTools.@match expr begin + (L_ = R_) => (L, R) + x_ => nothing + end +end + function to_namedtuple_expr(syms, vals=syms) length(syms) == 0 && return :(NamedTuple()) @@ -66,11 +80,10 @@ vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r)) # otherwise we will have error for MatrixDistribution. # Note this is not the case for MultivariateDistribution so I guess this might be lack of # support for some types related to matrices (like PDMat). -reconstruct(d::UnivariateDistribution, val::AbstractVector) = val[1] -reconstruct(d::MultivariateDistribution, val::AbstractVector) = copy(val) -function reconstruct(d::MatrixDistribution, val::AbstractVector) - return reshape(copy(val), size(d)) -end +reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val) +reconstruct(::Tuple{}, val::AbstractVector) = val[1] +reconstruct(s::NTuple{1}, val::AbstractVector) = copy(val) +reconstruct(s::NTuple{2}, val::AbstractVector) = reshape(copy(val), s) function reconstruct!(r, d::Distribution, val::AbstractVector) return reconstruct!(r, d, val) end @@ -79,17 +92,17 @@ function reconstruct!(r, d::MultivariateDistribution, val::AbstractVector) return r end function reconstruct(d::Distribution, val::AbstractVector, n::Int) - return reconstruct(d, val, n) + return reconstruct(size(d), val, n) end -function reconstruct(d::UnivariateDistribution, val::AbstractVector, n::Int) +function reconstruct(::Tuple{}, val::AbstractVector, n::Int) return copy(val) end -function reconstruct(d::MultivariateDistribution, val::AbstractVector, n::Int) - return copy(reshape(val, size(d)[1], n)) +function reconstruct(s::NTuple{1}, val::AbstractVector, n::Int) + return copy(reshape(val, s[1], n)) end -function reconstruct(d::MatrixDistribution, val::AbstractVector, n::Int) - tmp = reshape(val, size(d)[1], size(d)[2], n) - orig = [tmp[:, :, i] for i in 1:size(tmp, 3)] +function 
reconstruct(s::NTuple{2}, val::AbstractVector, n::Int) + tmp = reshape(val, s..., n) + orig = [tmp[:, :, i] for i in 1:n] return orig end function reconstruct!(r, d::Distribution, val::AbstractVector, n::Int) diff --git a/src/varinfo.jl b/src/varinfo.jl index 64c122dc2..6815a5333 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -335,14 +335,15 @@ getall(vi::TypedVarInfo) = vcat(_getall(vi.metadata)...) end """ - setall!(vi::VarInfo, val) + setall!!(vi::VarInfo, val) -Set the values of all the variables in `vi` to `val`. +Set the values of all the variables in `vi` to `val`, +mutating if it makes sense. The values may or may not be transformed to Euclidean space. """ -setall!(vi::UntypedVarInfo, val) = vi.metadata.vals .= val -setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) +setall!!(vi::UntypedVarInfo, val) = vi.metadata.vals .= val +setall!!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val) @generated function _setall!(metadata::NamedTuple{names}, val, start=0) where {names} expr = Expr(:block) start = :(1) @@ -365,10 +366,18 @@ getgid(vi::VarInfo, vn::VarName) = getmetadata(vi, vn).gids[getidx(vi, vn)] """ settrans!(vi::VarInfo, trans::Bool, vn::VarName) -Set the `trans` flag value of `vn` in `vi`. +Set the `trans` flag value of `vn` in `vi`, mutating if it makes sense. """ -function settrans!(vi::AbstractVarInfo, trans::Bool, vn::VarName) - return trans ? 
set_flag!(vi, vn, "trans") : unset_flag!(vi, vn, "trans") +function settrans!( + vi::AbstractVarInfo, trans::Bool, vn::Union{VarName,AbstractArray{<:VarName}} +) + if trans + set_flag!(vi, vn, "trans") + else + unset_flag!(vi, vn, "trans") + end + + return trans end """ @@ -519,6 +528,14 @@ function set_flag!(vi::VarInfo, vn::VarName, flag::String) return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = true end +function set_flag!(vi::VarInfo, vns::AbstractArray{<:VarName}, flag::String) + foreach(vns) do vn + set_flag!(vi, vn, flag) + end + + return true +end + #### #### APIs for typed and untyped VarInfo #### @@ -593,16 +610,16 @@ end TypedVarInfo(vi::TypedVarInfo) = vi """ - empty!(vi::VarInfo) + empty!!(vi::VarInfo) Empty the fields of `vi.metadata` and reset `vi.logp[]` and `vi.num_produce[]` to -zeros. +zeros, mutating if it makes sense. This is useful when using a sampling algorithm that assumes an empty `vi`, e.g. `SMC`. """ -function empty!(vi::VarInfo) +function empty!!(vi::VarInfo) _empty!(vi.metadata) - resetlogp!(vi) + resetlogp!!(vi) reset_num_produce!(vi) return vi end @@ -660,34 +677,34 @@ Return the log of the joint probability of the observed data and parameters samp getlogp(vi::AbstractVarInfo) = vi.logp[] """ - setlogp!(vi::VarInfo, logp) + setlogp!!(vi::VarInfo, logp) Set the log of the joint probability of the observed data and parameters sampled in -`vi` to `logp`. +`vi` to `logp`, mutating if it makes sense. """ -function setlogp!(vi::VarInfo, logp) +function setlogp!!(vi::VarInfo, logp) vi.logp[] = logp return vi end """ - acclogp!(vi::VarInfo, logp) + acclogp!!(vi::VarInfo, logp) Add `logp` to the value of the log of the joint probability of the observed data and -parameters sampled in `vi`. +parameters sampled in `vi`, mutating if it makes sense. 
""" -function acclogp!(vi::VarInfo, logp) +function acclogp!!(vi::VarInfo, logp) vi.logp[] += logp return vi end """ - resetlogp!(vi::AbstractVarInfo) + resetlogp!!(vi::AbstractVarInfo) Reset the value of the log of the joint probability of the observed data and parameters -sampled in `vi` to 0. +sampled in `vi` to 0, mutating if it makes sense. """ -resetlogp!(vi::AbstractVarInfo) = setlogp!(vi, zero(getlogp(vi))) +resetlogp!!(vi::AbstractVarInfo) = setlogp!!(vi, zero(getlogp(vi))) """ get_num_produce(vi::VarInfo) @@ -735,13 +752,13 @@ end # X -> R for all variables associated with given sampler """ - link!(vi::VarInfo, spl::Sampler) + link!!(vi::VarInfo, spl::Sampler) Transform the values of the random variables sampled by `spl` in `vi` from the support of their distributions to the Euclidean space and set their corresponding `"trans"` flag values to `true`. """ -function link!(vi::UntypedVarInfo, spl::Sampler) +function link!!(vi::UntypedVarInfo, spl::Sampler) # TODO: Change to a lazy iterator over `vns` vns = _getvns(vi, spl) if ~istrans(vi, vns[1]) @@ -760,10 +777,10 @@ function link!(vi::UntypedVarInfo, spl::Sampler) @warn("[DynamicPPL] attempt to link a linked vi") end end -function link!(vi::TypedVarInfo, spl::AbstractSampler) - return link!(vi, spl, Val(getspace(spl))) +function link!!(vi::TypedVarInfo, spl::AbstractSampler) + return link!!(vi, spl, Val(getspace(spl))) end -function link!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) +function link!!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) vns = _getvns(vi, spl) return _link!(vi.metadata, vi, vns, spaceval) end @@ -804,13 +821,13 @@ end # R -> X for all variables associated with given sampler """ - invlink!(vi::VarInfo, spl::AbstractSampler) + invlink!!(vi::VarInfo, spl::AbstractSampler) Transform the values of the random variables sampled by `spl` in `vi` from the Euclidean space back to the support of their distributions and sets their corresponding `"trans"` flag values to 
`false`. """ -function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) +function invlink!!(vi::UntypedVarInfo, spl::AbstractSampler) vns = _getvns(vi, spl) if istrans(vi, vns[1]) for vn in vns @@ -827,10 +844,10 @@ function invlink!(vi::UntypedVarInfo, spl::AbstractSampler) @warn("[DynamicPPL] attempt to invlink an invlinked vi") end end -function invlink!(vi::TypedVarInfo, spl::AbstractSampler) - return invlink!(vi, spl, Val(getspace(spl))) +function invlink!!(vi::TypedVarInfo, spl::AbstractSampler) + return invlink!!(vi, spl, Val(getspace(spl))) end -function invlink!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) +function invlink!!(vi::TypedVarInfo, spl::AbstractSampler, spaceval::Val) vns = _getvns(vi, spl) return _invlink!(vi.metadata, vi, vns, spaceval) end @@ -960,7 +977,8 @@ Set the current value(s) of the random variable `vn` in `vi` to `val`. The value(s) may or may not be transformed to Euclidean space. """ -setindex!(vi::AbstractVarInfo, val, vn::VarName) = setval!(vi, val, vn) +setindex!(vi::AbstractVarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi) +setindex!!(vi::AbstractVarInfo, val, vn::VarName) = (setindex!(vi, val, vn); return vi) """ setindex!(vi::VarInfo, val, spl::Union{SampleFromPrior, Sampler}) @@ -969,13 +987,13 @@ Set the current value(s) of the random variables sampled by `spl` in `vi` to `va The value(s) may or may not be transformed to Euclidean space. 
""" -setindex!(vi::AbstractVarInfo, val, spl::SampleFromPrior) = setall!(vi, val) +setindex!(vi::AbstractVarInfo, val, spl::SampleFromPrior) = setall!!(vi, val) setindex!(vi::UntypedVarInfo, val, spl::Sampler) = setval!(vi, val, _getranges(vi, spl)) function setindex!(vi::TypedVarInfo, val, spl::Sampler) # Gets a `NamedTuple` mapping each symbol to the indices in the symbol's `vals` field sampled from the sampler `spl` ranges = _getranges(vi, spl) _setindex!(vi.metadata, val, ranges) - return val + return vi end # Recursively writes the entries of `val` to the `vals` fields of all the symbols as if they were a contiguous vector. @generated function _setindex!(metadata, val, ranges::NamedTuple{names}) where {names} @@ -1093,42 +1111,42 @@ function Base.show(io::IO, vi::UntypedVarInfo) end """ - push!(vi::VarInfo, vn::VarName, r, dist::Distribution) + push!!(vi::VarInfo, vn::VarName, r, dist::Distribution) Push a new random variable `vn` with a sampled value `r` from a distribution `dist` to -the `VarInfo` `vi`. +the `VarInfo` `vi`, mutating if it makes sense. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) - return push!(vi, vn, r, dist, Set{Selector}([])) +function push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution) + return push!!(vi, vn, r, dist, Set{Selector}([])) end """ - push!(vi::VarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) + push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler) Push a new random variable `vn` with a sampled value `r` sampled with a sampler `spl` -from a distribution `dist` to `VarInfo` `vi`. +from a distribution `dist` to `VarInfo` `vi`, mutating if it makes sense. The sampler is passed here to invalidate its cache where defined. 
""" -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler) - return push!(vi, vn, r, dist, spl.selector) +function push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::Sampler) + return push!!(vi, vn, r, dist, spl.selector) end -function push!( +function push!!( vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, spl::AbstractSampler ) - return push!(vi, vn, r, dist) + return push!!(vi, vn, r, dist) end """ - push!(vi::VarInfo, vn::VarName, r, dist::Distribution, gid::Selector) + push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, gid::Selector) Push a new random variable `vn` with a sampled value `r` sampled with a sampler of selector `gid` from a distribution `dist` to `VarInfo` `vi`. """ -function push!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) - return push!(vi, vn, r, dist, Set([gid])) +function push!!(vi::AbstractVarInfo, vn::VarName, r, dist::Distribution, gid::Selector) + return push!!(vi, vn, r, dist, Set([gid])) end -function push!(vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) +function push!!(vi::VarInfo, vn::VarName, r, dist::Distribution, gidset::Set{Selector}) if vi isa UntypedVarInfo @assert ~(vn in keys(vi)) "[push!] 
attempt to add an exisitng variable $(getsym(vn)) ($(vn)) to VarInfo (keys=$(keys(vi))) with dist=$dist, gid=$gidset" elseif vi isa TypedVarInfo @@ -1189,6 +1207,14 @@ function unset_flag!(vi::VarInfo, vn::VarName, flag::String) return getmetadata(vi, vn).flags[flag][getidx(vi, vn)] = false end +function unset_flag!(vi::VarInfo, vns::AbstractArray{<:VarName}, flag::String) + foreach(vns) do vn + unset_flag!(vi, vn, flag) + end + + return false +end + """ set_retained_vns_del_by_spl!(vi::VarInfo, spl::Sampler) @@ -1395,7 +1421,7 @@ function setval!( return setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end -function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) +function _setval_kernel!(vi::VarInfo, vn::VarName, values, keys) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices]) @@ -1414,7 +1440,7 @@ end Set the values in `vi` to the provided values and those which are not present in `x` or `chains` to *be* resampled. -Note that this does *not* resample the values not provided! It will call `setflag!(vi, vn, "del")` +Note that this does *not* resample the values not provided! It will call [`set_flag!(vi, vn, "del")`](@ref) for variables `vn` for which no values are provided, which means that the next time we call `model(vi)` these variables will be resampled. 
@@ -1476,7 +1502,7 @@ function setval_and_resample!( return setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) end -function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys) +function _setval_and_resample_kernel!(vi::VarInfo, vn::VarName, values, keys) indices = findall(Base.Fix1(subsumes_string, string(vn)), keys) if !isempty(indices) val = reduce(vcat, values[indices]) @@ -1490,3 +1516,32 @@ function _setval_and_resample_kernel!(vi::AbstractVarInfo, vn::VarName, values, return indices end + +""" + values_as(vi::AbstractVarInfo, ::Type{NamedTuple}) + values_as(vi::AbstractVarInfo, ::Type{Dict}) + +Return values in `vi` as the specified type. +""" +function values_as(vi::UntypedVarInfo, ::Type{NamedTuple}) + iter = values_from_metadata(vi.metadata) + return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) +end +values_as(vi::UntypedVarInfo, ::Type{Dict}) = Dict(values_from_metadata(vi.metadata)) + +function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{NamedTuple}) where {names} + iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) + return NamedTuple(map(p -> Symbol(p.first) => p.second, iter)) +end + +function values_as(vi::VarInfo{<:NamedTuple{names}}, ::Type{Dict}) where {names} + iter = Iterators.flatten(values_from_metadata(getfield(vi.metadata, n)) for n in names) + return Dict(iter) +end + +function values_from_metadata(md::Metadata) + return ( + vn => reconstruct(md.dists[md.idcs[vn]], md.vals[md.ranges[md.idcs[vn]]]) for + vn in md.vns + ) +end diff --git a/test/compiler.jl b/test/compiler.jl index 6f85e9453..2b9a2273b 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -364,8 +364,8 @@ end end @model function demo_useval(x, y) - x1 = @submodel sub1 demo_return(x) - x2 = @submodel sub2 demo_return(y) + @submodel sub1 x1 = demo_return(x) + @submodel sub2 x2 = demo_return(y) return z ~ Normal(x1 + x2 + 100, 1.0) end @@ -399,7 +399,7 @@ end num_steps = 
length(y[1]) num_obs = length(y) @inbounds for i in 1:num_obs - x = @submodel $(Symbol("ar1_$i")) AR1(num_steps, α, μ, σ) + @submodel $(Symbol("ar1_$i")) x = AR1(num_steps, α, μ, σ) y[i] ~ MvNormal(x, 0.1) end end diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index db01a0b9a..568a832ae 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -69,7 +69,7 @@ end @model function gdemo9() # Submodel prior - m = @submodel _prior_dot_assume() + @submodel m = _prior_dot_assume() for i in eachindex(m) 10.0 ~ Normal(m[i], 0.5) end @@ -110,7 +110,7 @@ const gdemo_models = ( ) @testset "loglikelihoods.jl" begin - for m in gdemo_models + @testset "$(m.name)" for m in gdemo_models vi = VarInfo(m) vns = vi.metadata.m.vns diff --git a/test/threadsafe.jl b/test/threadsafe.jl index 83c53ccd6..bd1f4f154 100644 --- a/test/threadsafe.jl +++ b/test/threadsafe.jl @@ -17,17 +17,17 @@ lp = getlogp(vi) @test getlogp(threadsafe_vi) == lp - acclogp!(threadsafe_vi, 42) + acclogp!!(threadsafe_vi, 42) @test threadsafe_vi.logps[Threads.threadid()][] == 42 @test getlogp(vi) == lp @test getlogp(threadsafe_vi) == lp + 42 - resetlogp!(threadsafe_vi) + resetlogp!!(threadsafe_vi) @test iszero(getlogp(vi)) @test iszero(getlogp(threadsafe_vi)) @test all(iszero(x[]) for x in threadsafe_vi.logps) - setlogp!(threadsafe_vi, 42) + setlogp!!(threadsafe_vi, 42) @test getlogp(vi) == 42 @test getlogp(threadsafe_vi) == 42 @test all(iszero(x[]) for x in threadsafe_vi.logps) diff --git a/test/varinfo.jl b/test/varinfo.jl index 4c8ec43cb..d642f1f48 100644 --- a/test/varinfo.jl +++ b/test/varinfo.jl @@ -33,7 +33,7 @@ end @testset "Base" begin # Test Base functions: - # string, Symbol, ==, hash, in, keys, haskey, isempty, push!, empty!, + # string, Symbol, ==, hash, in, keys, haskey, isempty, push!!, empty!!, # getindex, setindex!, getproperty, setproperty! 
csym = gensym() vn1 = @varname x[1][2] @@ -46,7 +46,7 @@ @test inspace(vn1, (:x,)) function test_base!(vi) - empty!(vi) + empty!!(vi) @test getlogp(vi) == 0 @test get_num_produce(vi) == 0 @@ -58,7 +58,7 @@ @test isempty(vi) @test ~haskey(vi, vn) @test !(vn in keys(vi)) - push!(vi, vn, r, dist, gid) + push!!(vi, vn, r, dist, gid) @test ~isempty(vi) @test haskey(vi, vn) @test vn in keys(vi) @@ -75,9 +75,9 @@ @test vi[vn] == 3 * r @test vi[SampleFromPrior()][1] == 3 * r - empty!(vi) + empty!!(vi) @test isempty(vi) - push!(vi, vn, r, dist, gid) + push!!(vi, vn, r, dist, gid) function test_inspace() space = (:x, :y, @varname(z[1]), @varname(M[1:10, :])) @@ -98,7 +98,7 @@ end vi = VarInfo() test_base!(vi) - test_base!(empty!(TypedVarInfo(vi))) + test_base!(empty!!(TypedVarInfo(vi))) end @testset "flags" begin # Test flag setting: @@ -109,7 +109,7 @@ r = rand(dist) gid = Selector() - push!(vi, vn_x, r, dist, gid) + push!!(vi, vn_x, r, dist, gid) # del is set by default @test !is_flagged(vi, vn_x, "del") @@ -122,7 +122,7 @@ end vi = VarInfo() test_varinfo!(vi) - test_varinfo!(empty!(TypedVarInfo(vi))) + test_varinfo!(empty!!(TypedVarInfo(vi))) end @testset "setgid!" begin vi = VarInfo() @@ -133,14 +133,14 @@ gid1 = Selector() gid2 = Selector(2, :HMC) - push!(vi, vn, r, dist, gid1) + push!!(vi, vn, r, dist, gid1) @test meta.gids[meta.idcs[vn]] == Set([gid1]) setgid!(vi, gid2, vn) @test meta.gids[meta.idcs[vn]] == Set([gid1, gid2]) - vi = empty!(TypedVarInfo(vi)) + vi = empty!!(TypedVarInfo(vi)) meta = vi.metadata - push!(vi, vn, r, dist, gid1) + push!!(vi, vn, r, dist, gid1) @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1]) setgid!(vi, gid2, vn) @test meta.x.gids[meta.x.idcs[vn]] == Set([gid1, gid2])