diff --git a/Project.toml b/Project.toml index f40c46c0e..9d26a90ce 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.15.2" +version = "0.16.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -12,16 +12,18 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractMCMC = "2, 3.0" -AbstractPPL = "0.2" +AbstractPPL = "0.3" BangBang = "0.3" Bijectors = "0.5.2, 0.6, 0.7, 0.8, 0.9" ChainRulesCore = "0.9.7, 0.10, 1" Distributions = "0.23.8, 0.24, 0.25" MacroTools = "0.5.6" +Setfield = "0.7.1" ZygoteRules = "0.2" julia = "1.3" diff --git a/benchmarks/benchmark_body.jmd b/benchmarks/benchmark_body.jmd index f9c994dc9..9d9810dc2 100644 --- a/benchmarks/benchmark_body.jmd +++ b/benchmarks/benchmark_body.jmd @@ -8,8 +8,15 @@ m = time_model_def(model_def, data); ```julia suite = make_suite(m); -results = run(suite) -results +results = run(suite); +``` + +```julia +results["evaluation_untyped"] +``` + +```julia +results["evaluation_typed"] ``` ```julia; echo=false; results="hidden"; diff --git a/benchmarks/benchmarks.jmd b/benchmarks/benchmarks.jmd index 614afb2e9..5b86b261e 100644 --- a/benchmarks/benchmarks.jmd +++ b/benchmarks/benchmarks.jmd @@ -94,3 +94,37 @@ data = mapreduce(c -> rand(MvNormal([μs[c], μs[c]], 1.), N), hcat, 1:2); ```julia; echo=false weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) ``` + +### `demo4`: loads of indexing + +```julia +@model function demo4(n, ::Type{TV}=Vector{Float64}) where {TV} + m ~ Normal() + x = TV(undef, n) + for i in eachindex(x) + x[i] ~ Normal(m, 1.0) + end +end + +model_def = demo4 +data = (100_000, ); +``` + +```julia; echo=false +weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) +``` + +```julia +@model function demo4_dotted(n, ::Type{TV}=Vector{Float64}) where {TV} + m ~ Normal() + x = TV(undef, n) + x .~ Normal(m, 1.0) +end + +model_def = demo4_dotted +data = (100_000, ); +``` + +```julia; echo=false +weave_child(WEAVE_ARGS[:benchmarkbody], mod = @__MODULE__, args = WEAVE_ARGS) +``` diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 1238add26..15ee22325 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -11,6 +11,9 @@ using MacroTools: MacroTools using ZygoteRules: ZygoteRules using BangBang: BangBang +using Setfield: Setfield +using BangBang: BangBang + using Random: Random import Base: diff --git a/src/compiler.jl b/src/compiler.jl index 5c2369f6c..8ad248622 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -18,7 +18,7 @@ function isassumption(expr::Union{Symbol,Expr}) vn = gensym(:vn) return quote - let $vn = $(varname(expr)) + let $vn = $(AbstractPPL.drop_escape(varname(expr))) if $(DynamicPPL.contextual_isassumption)(__context__, $vn) # Considered an assumption by `__context__` which means either: # 1. We hit the default implementation, e.g. using `DefaultContext`, @@ -133,17 +133,17 @@ variables. # Example ```jldoctest; setup=:(using Distributions, LinearAlgebra) -julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); string(vns[end]) -"x[:,2]" +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(MvNormal(ones(2), I), randn(2, 2), @varname(x)); vns[end] +x[:,2] -julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); string(vns[end]) -"x[:][1,2]" +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x)); vns[end] +x[1,2] -julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); string(vns[end]) -"x[1][3]" +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2), @varname(x[:])); vns[end] +x[:][1,2] -julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(1, 2, 3), @varname(x)); string(vns[end]) -"x[1,2,3]" +julia> _, _, vns = DynamicPPL.unwrap_right_left_vns(Normal(), randn(3), @varname(x[1])); vns[end] +x[1][3] ``` """ unwrap_right_left_vns(right, left, vns) = right, left, vns @@ -158,7 +158,7 @@ function unwrap_right_left_vns( # for `i = size(left, 2)`. Hence the symbol should be `x[:, i]`, # and we therefore add the `Colon()` below. vns = map(axes(left, 2)) do i - return VarName(vn, (vn.indexing..., (Colon(), i))) + return vn ∘ Setfield.IndexLens((Colon(), i)) end return unwrap_right_left_vns(right, left, vns) end @@ -168,7 +168,7 @@ function unwrap_right_left_vns( vn::VarName, ) vns = map(CartesianIndices(left)) do i - return VarName(vn, (vn.indexing..., Tuple(i))) + return vn ∘ Setfield.IndexLens(Tuple(i)) end return unwrap_right_left_vns(right, left, vns) end @@ -317,6 +317,10 @@ function generate_mainbody!(mod, found, expr::Expr, warn) # Do not touch interpolated expressions expr.head === :$ && return expr.args[1] + # Do we don't want escaped expressions because we unfortunately + # escape the entire body afterwards. + Meta.isexpr(expr, :escape) && return generate_mainbody(mod, found, expr.args[1], warn) + # If it's a macro, we expand it if Meta.isexpr(expr, :macrocall) return generate_mainbody!(mod, found, macroexpand(mod, expr; recursive=true), warn) @@ -349,6 +353,15 @@ function generate_mainbody!(mod, found, expr::Expr, warn) return Expr(expr.head, map(x -> generate_mainbody!(mod, found, x, warn), expr.args)...) end +function generate_tilde_literal(left, right) + # If the LHS is a literal, it is always an observation + return quote + $(DynamicPPL.tilde_observe!)( + __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ + ) + end +end + """ generate_tilde(left, right) @@ -356,31 +369,20 @@ Generate an `observe` expression for data variables and `assume` expression for variables. """ function generate_tilde(left, right) - # If the LHS is a literal, it is always an observation - if isliteral(left) - return quote - $(DynamicPPL.tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ - ) - end - end + isliteral(left) && return generate_tilde_literal(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn inds isassumption + @gensym vn isassumption + + # HACK: Usage of `drop_escape` is unfortunate. It's a consequence of the fact + # that in DynamicPPL we the entire function body. Instead we should be + # more selective with our escape. Until that's the case, we remove them all. return quote - $vn = $(varname(left)) - $inds = $(vinds(left)) + $vn = $(AbstractPPL.drop_escape(varname(left))) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left = $(DynamicPPL.tilde_assume!)( - __context__, - $(DynamicPPL.unwrap_right_vn)( - $(DynamicPPL.check_tilde_rhs)($right), $vn - )..., - $inds, - __varinfo__, - ) + $(generate_tilde_assume(left, right, vn)) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) @@ -392,44 +394,46 @@ function generate_tilde(left, right) $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn, - $inds, __varinfo__, ) end end end +function generate_tilde_assume(left, right, vn) + expr = :( + $left = $(DynamicPPL.tilde_assume!)( + __context__, + $(DynamicPPL.unwrap_right_vn)($(DynamicPPL.check_tilde_rhs)($right), $vn)..., + __varinfo__, + ) + ) + + return if left isa Expr + AbstractPPL.drop_escape( + Setfield.setmacro(BangBang.prefermutation, expr; overwrite=true) + ) + else + return expr + end +end + """ generate_dot_tilde(left, right) Generate the expression that replaces `left .~ right` in the model body. """ function generate_dot_tilde(left, right) - # If the LHS is a literal, it is always an observation - if isliteral(left) - return quote - $(DynamicPPL.dot_tilde_observe!)( - __context__, $(DynamicPPL.check_tilde_rhs)($right), $left, __varinfo__ - ) - end - end + isliteral(left) && return generate_tilde_literal(left, right) # Otherwise it is determined by the model or its value, # if the LHS represents an observation - @gensym vn inds isassumption + @gensym vn isassumption return quote - $vn = $(varname(left)) - $inds = $(vinds(left)) + $vn = $(AbstractPPL.drop_escape(varname(left))) $isassumption = $(DynamicPPL.isassumption(left)) if $isassumption - $left .= $(DynamicPPL.dot_tilde_assume!)( - __context__, - $(DynamicPPL.unwrap_right_left_vns)( - $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn - )..., - $inds, - __varinfo__, - ) + $(generate_dot_tilde_assume(left, right, vn)) else # If `vn` is not in `argnames`, we need to make sure that the variable is defined. if !$(DynamicPPL.inargnames)($vn, __model__) @@ -441,13 +445,27 @@ function generate_dot_tilde(left, right) $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn, - $inds, __varinfo__, ) end end end +function generate_dot_tilde_assume(left, right, vn) + # We don't need to use `Setfield.@set` here since + # `.=` is always going to be inplace + needs `left` to + # be something that supports `.=`. + return :( + $left .= $(DynamicPPL.dot_tilde_assume!)( + __context__, + $(DynamicPPL.unwrap_right_left_vns)( + $(DynamicPPL.check_tilde_rhs)($right), $(maybe_view(left)), $vn + )..., + __varinfo__, + ) + ) +end + const FloatOrArrayType = Type{<:Union{AbstractFloat,AbstractArray}} hasmissing(T::Type{<:AbstractArray{TA}}) where {TA<:AbstractArray} = hasmissing(TA) hasmissing(T::Type{<:AbstractArray{>:Missing}}) = true diff --git a/src/context_implementations.jl b/src/context_implementations.jl index 498e83492..19b5ce061 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -14,21 +14,9 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg))) require_gradient(spl::Sampler) = false require_particles(spl::Sampler) = false -_getindex(x, inds::Tuple) = _getindex(Base.maybeview(x, first(inds)...), Base.tail(inds)) -_getindex(x, inds::Tuple{}) = x -_getvalue(x, vn::VarName{sym}) where {sym} = _getindex(getproperty(x, sym), vn.indexing) -function _getvalue(x, vns::AbstractVector{<:VarName{sym}}) where {sym} - val = getproperty(x, sym) - - # This should work with both cartesian and linear indexing. - return map(vns) do vn - _getindex(val, vn) - end -end - # assume """ - tilde_assume(context::SamplingContext, right, vn, inds, vi) + tilde_assume(context::SamplingContext, right, vn, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value with a context associated @@ -36,18 +24,18 @@ with a sampler. Falls back to ```julia -tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) +tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) ``` """ -function tilde_assume(context::SamplingContext, right, vn, inds, vi) - return tilde_assume(context.rng, context.context, context.sampler, right, vn, inds, vi) +function tilde_assume(context::SamplingContext, right, vn, vi) + return tilde_assume(context.rng, context.context, context.sampler, right, vn, vi) end # Leaf contexts function tilde_assume(context::AbstractContext, args...) return tilde_assume(NodeTrait(tilde_assume, context), context, args...) end -function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vinds, vi) +function tilde_assume(::IsLeaf, context::AbstractContext, right, vn, vi) return assume(right, vn, vi) end function tilde_assume(::IsParent, context::AbstractContext, args...) @@ -57,44 +45,36 @@ end function tilde_assume(rng, context::AbstractContext, args...) return tilde_assume(NodeTrait(tilde_assume, context), rng, context, args...) end -function tilde_assume( - ::IsLeaf, rng, context::AbstractContext, sampler, right, vn, vinds, vi -) +function tilde_assume(::IsLeaf, rng, context::AbstractContext, sampler, right, vn, vi) return assume(rng, sampler, right, vn, vi) end function tilde_assume(::IsParent, rng, context::AbstractContext, args...) return tilde_assume(rng, childcontext(context), args...) end -function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, inds, vi) +function tilde_assume(context::PriorContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + vi[vn] = vectorize(right, get(context.vars, vn)) settrans!(vi, false, vn) end - return tilde_assume(PriorContext(), right, vn, inds, vi) + return tilde_assume(PriorContext(), right, vn, vi) end function tilde_assume( - rng::Random.AbstractRNG, - context::PriorContext{<:NamedTuple}, - sampler, - right, - vn, - inds, - vi, + rng::Random.AbstractRNG, context::PriorContext{<:NamedTuple}, sampler, right, vn, vi ) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + vi[vn] = vectorize(right, get(context.vars, vn)) settrans!(vi, false, vn) end - return tilde_assume(rng, PriorContext(), sampler, right, vn, inds, vi) + return tilde_assume(rng, PriorContext(), sampler, right, vn, vi) end -function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, inds, vi) +function tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, vn, vi) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + vi[vn] = vectorize(right, get(context.vars, vn)) settrans!(vi, false, vn) end - return tilde_assume(LikelihoodContext(), right, vn, inds, vi) + return tilde_assume(LikelihoodContext(), right, vn, vi) end function tilde_assume( rng::Random.AbstractRNG, @@ -102,42 +82,39 @@ function tilde_assume( sampler, right, vn, - inds, vi, ) if haskey(context.vars, getsym(vn)) - vi[vn] = vectorize(right, _getindex(getfield(context.vars, getsym(vn)), inds)) + vi[vn] = vectorize(right, get(context.vars, vn)) settrans!(vi, false, vn) end - return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, inds, vi) + return tilde_assume(rng, LikelihoodContext(), sampler, right, vn, vi) end -function tilde_assume(::LikelihoodContext, right, vn, inds, vi) +function tilde_assume(::LikelihoodContext, right, vn, vi) return assume(NoDist(right), vn, vi) end -function tilde_assume( - rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, inds, vi -) +function tilde_assume(rng::Random.AbstractRNG, ::LikelihoodContext, sampler, right, vn, vi) return assume(rng, sampler, NoDist(right), vn, vi) end -function tilde_assume(context::PrefixContext, right, vn, inds, vi) - return tilde_assume(context.context, right, prefix(context, vn), inds, vi) +function tilde_assume(context::PrefixContext, right, vn, vi) + return tilde_assume(context.context, right, prefix(context, vn), vi) end -function tilde_assume(rng, context::PrefixContext, sampler, right, vn, inds, vi) - return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), inds, vi) +function tilde_assume(rng, context::PrefixContext, sampler, right, vn, vi) + return tilde_assume(rng, context.context, sampler, right, prefix(context, vn), vi) end """ - tilde_assume!(context, right, vn, inds, vi) + tilde_assume!(context, right, vn, vi) Handle assumed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the sampled value. -By default, calls `tilde_assume(context, right, vn, inds, vi)` and accumulates the log +By default, calls `tilde_assume(context, right, vn, vi)` and accumulates the log probability of `vi` with the returned value. """ -function tilde_assume!(context, right, vn, inds, vi) - value, logp = tilde_assume(context, right, vn, inds, vi) +function tilde_assume!(context, right, vn, vi) + value, logp = tilde_assume(context, right, vn, vi) acclogp!(vi, logp) return value end @@ -180,7 +157,7 @@ function tilde_observe(context::PrefixContext, right, left, vi) end """ - tilde_observe!(context, right, left, vname, vinds, vi) + tilde_observe!(context, right, left, vname, vi) Handle observed variables, e.g., `x ~ Normal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. @@ -188,7 +165,7 @@ accumulate the log probability, and return the observed value. Falls back to `tilde_observe!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function tilde_observe!(context, right, left, vname, vinds, vi) +function tilde_observe!(context, right, left, vname, vi) return tilde_observe!(context, right, left, vi) end @@ -260,7 +237,7 @@ end # assume """ - dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) + dot_tilde_assume(context::SamplingContext, right, left, vn, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value for a context @@ -268,12 +245,12 @@ associated with a sampler. Falls back to ```julia -dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, inds, vi) +dot_tilde_assume(context.rng, context.context, context.sampler, right, left, vn, vi) ``` """ -function dot_tilde_assume(context::SamplingContext, right, left, vn, inds, vi) +function dot_tilde_assume(context::SamplingContext, right, left, vn, vi) return dot_tilde_assume( - context.rng, context.context, context.sampler, right, left, vn, inds, vi + context.rng, context.context, context.sampler, right, left, vn, vi ) end @@ -285,12 +262,10 @@ function dot_tilde_assume(rng, context::AbstractContext, args...) return dot_tilde_assume(rng, NodeTrait(dot_tilde_assume, context), context, args...) end -function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, inds, vi) +function dot_tilde_assume(::IsLeaf, ::AbstractContext, right, left, vns, vi) return dot_assume(right, left, vns, vi) end -function dot_tilde_assume( - ::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, inds, vi -) +function dot_tilde_assume(::IsLeaf, rng, ::AbstractContext, sampler, right, left, vns, vi) return dot_assume(rng, sampler, right, vns, left, vi) end @@ -301,22 +276,20 @@ function dot_tilde_assume(rng, ::IsParent, context::AbstractContext, args...) return dot_tilde_assume(rng, childcontext(context), args...) end -function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, inds, vi) +function dot_tilde_assume(rng, ::DefaultContext, sampler, right, left, vns, vi) return dot_assume(rng, sampler, right, vns, left, vi) end # `LikelihoodContext` -function dot_tilde_assume( - context::LikelihoodContext{<:NamedTuple}, right, left, vn, inds, vi -) +function dot_tilde_assume(context::LikelihoodContext{<:NamedTuple}, right, left, vn, vi) return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) + var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, inds, vi) + dot_tilde_assume(LikelihoodContext(), _right, _left, _vns, vi) else - dot_tilde_assume(LikelihoodContext(), right, left, vn, inds, vi) + dot_tilde_assume(LikelihoodContext(), right, left, vn, vi) end end function dot_tilde_assume( @@ -326,38 +299,37 @@ function dot_tilde_assume( right, left, vn, - inds, vi, ) return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) + var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, inds, vi) + dot_tilde_assume(rng, LikelihoodContext(), sampler, _right, _left, _vns, vi) else - dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, inds, vi) + dot_tilde_assume(rng, LikelihoodContext(), sampler, right, left, vn, vi) end end -function dot_tilde_assume(context::LikelihoodContext, right, left, vn, inds, vi) +function dot_tilde_assume(context::LikelihoodContext, right, left, vn, vi) return dot_assume(NoDist.(right), left, vn, vi) end function dot_tilde_assume( - rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, inds, vi + rng::Random.AbstractRNG, context::LikelihoodContext, sampler, right, left, vn, vi ) return dot_assume(rng, sampler, NoDist.(right), vn, left, vi) end # `PriorContext` -function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, inds, vi) +function dot_tilde_assume(context::PriorContext{<:NamedTuple}, right, left, vn, vi) return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) + var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(PriorContext(), _right, _left, _vns, inds, vi) + dot_tilde_assume(PriorContext(), _right, _left, _vns, vi) else - dot_tilde_assume(PriorContext(), right, left, vn, inds, vi) + dot_tilde_assume(PriorContext(), right, left, vn, vi) end end function dot_tilde_assume( @@ -367,41 +339,40 @@ function dot_tilde_assume( right, left, vn, - inds, vi, ) return if haskey(context.vars, getsym(vn)) - var = _getindex(getfield(context.vars, getsym(vn)), inds) + var = get(context.vars, vn) _right, _left, _vns = unwrap_right_left_vns(right, var, vn) set_val!(vi, _vns, _right, _left) settrans!.(Ref(vi), false, _vns) - dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, inds, vi) + dot_tilde_assume(rng, PriorContext(), sampler, _right, _left, _vns, vi) else - dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, inds, vi) + dot_tilde_assume(rng, PriorContext(), sampler, right, left, vn, vi) end end # `PrefixContext` -function dot_tilde_assume(context::PrefixContext, right, left, vn, inds, vi) - return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), inds, vi) +function dot_tilde_assume(context::PrefixContext, right, left, vn, vi) + return dot_tilde_assume(context.context, right, prefix.(Ref(context), vn), vi) end -function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, inds, vi) +function dot_tilde_assume(rng, context::PrefixContext, sampler, right, left, vn, vi) return dot_tilde_assume( - rng, context.context, sampler, right, prefix.(Ref(context), vn), inds, vi + rng, context.context, sampler, right, prefix.(Ref(context), vn), vi ) end """ - dot_tilde_assume!(context, right, left, vn, inds, vi) + dot_tilde_assume!(context, right, left, vn, vi) Handle broadcasted assumed variables, e.g., `x .~ MvNormal()` (where `x` does not occur in the model inputs), accumulate the log probability, and return the sampled value. -Falls back to `dot_tilde_assume(context, right, left, vn, inds, vi)`. +Falls back to `dot_tilde_assume(context, right, left, vn, vi)`. """ -function dot_tilde_assume!(context, right, left, vn, inds, vi) - value, logp = dot_tilde_assume(context, right, left, vn, inds, vi) +function dot_tilde_assume!(context, right, left, vn, vi) + value, logp = dot_tilde_assume(context, right, left, vn, vi) acclogp!(vi, logp) return value end @@ -598,7 +569,7 @@ function dot_tilde_observe(context::PrefixContext, right, left, vi) end """ - dot_tilde_observe!(context, right, left, vname, vinds, vi) + dot_tilde_observe!(context, right, left, vname, vi) Handle broadcasted observed values, e.g., `x .~ MvNormal()` (where `x` does occur in the model inputs), accumulate the log probability, and return the observed value. @@ -606,7 +577,7 @@ accumulate the log probability, and return the observed value. Falls back to `dot_tilde_observe!(context, right, left, vi)` ignoring the information about variable name and indices; if needed, these can be accessed through this function, though. """ -function dot_tilde_observe!(context, right, left, vn, inds, vi) +function dot_tilde_observe!(context, right, left, vn, vi) return dot_tilde_observe!(context, right, left, vi) end diff --git a/src/contexts.jl b/src/contexts.jl index 98eb4b85d..03fc26245 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -246,9 +246,9 @@ end function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym} if @generated - return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(vn.indexing)) + return :(VarName{$(QuoteNode(Symbol(Prefix, PREFIX_SEPARATOR, Sym)))}(getlens(vn))) else - VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) + VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(getlens(vn)) end end @@ -311,7 +311,7 @@ Return value of `vn` in `context`. function getvalue(context::AbstractContext, vn) return error("context $(context) does not contain value for $vn") end -getvalue(context::ConditionContext, vn) = _getvalue(context.values, vn) +getvalue(context::ConditionContext, vn) = get(context.values, vn) """ hasvalue_nested(context, vn) diff --git a/src/loglikelihoods.jl b/src/loglikelihoods.jl index 4e221fd57..cd50811c1 100644 --- a/src/loglikelihoods.jl +++ b/src/loglikelihoods.jl @@ -71,7 +71,7 @@ function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi) # Defer literal `observe` to child-context. return tilde_observe!(context.context, right, left, vi) end -function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vinds, vi) +function tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `tilde_observe!`. logp = tilde_observe(context.context, right, left, vi) @@ -87,7 +87,7 @@ function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vi # Defer literal `observe` to child-context. return dot_tilde_observe!(context.context, right, left, vi) end -function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, inds, vi) +function dot_tilde_observe!(context::PointwiseLikelihoodContext, right, left, vn, vi) # Need the `logp` value, so we cannot defer `acclogp!` to child-context, i.e. # we have to intercept the call to `dot_tilde_observe!`. diff --git a/test/Project.toml b/test/Project.toml index 948d8e5af..3af6ef22d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -13,6 +13,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -20,7 +21,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] AbstractMCMC = "2.1, 3.0" -AbstractPPL = "0.2" +AbstractPPL = "0.3" Bijectors = "0.9.5" Distributions = "0.25" DistributionsAD = "0.6.3" @@ -28,6 +29,7 @@ Documenter = "0.26.1, 0.27" ForwardDiff = "0.10.12" MCMCChains = "4.0.4, 5" MacroTools = "0.5.5" +Setfield = "0.7.1" StableRNGs = "1" Tracker = "0.2.11" Zygote = "0.5.4, 0.6" diff --git a/test/compiler.jl b/test/compiler.jl index 0f072b468..9f6c0163f 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -28,6 +28,11 @@ macro mymodel2(ex) end end +# Used to test sampling of immutable types. +struct MyCoolStruct{T} + a::T +end + @testset "compiler.jl" begin @testset "model macro" begin @model function testmodel_comp(x, y) @@ -235,6 +240,51 @@ end @test haskey(vi.metadata, :x) vi = VarInfo(gdemo(x)) @test haskey(vi.metadata, :x) + + # Non-array variables + @model function testmodel_nonarray(x, y) + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + for i in 1:(length(x.a) - 1) + x.a[i] ~ Normal(m, √s) + end + + # Dynamic indexing + x.a[end] ~ Normal(100.0, 1.0) + + # Immutable set + y.a ~ Normal() + + # Dotted + z = Vector{Float64}(undef, 3) + z[1:2] .~ Normal() + z[end:end] .~ Normal() + + return (; s=s, m=m, x=x, y=y, z=z) + end + + m_nonarray = testmodel_nonarray( + MyCoolStruct([missing, missing]), MyCoolStruct(missing) + ) + result = m_nonarray() + @test !any(ismissing, result.x.a) + @test result.y.a !== missing + @test result.x.a[end] > 10 + + # Ensure that we can work with `Vector{Real}(undef, N)` which is the + # reason why we're using `BangBang.prefermutation` in `src/compiler.jl` + # rather than the default from Setfield.jl. + # Related: https://github.com/jw3126/Setfield.jl/issues/157 + @model function vdemo() + x = Vector{Real}(undef, 10) + for i in eachindex(x) + x[i] ~ Normal(0, sqrt(4)) + end + + return x + end + x = vdemo()() + @test all((isassigned(x, i) for i in eachindex(x))) end @testset "nested model" begin function makemodel(p) diff --git a/test/contexts.jl b/test/contexts.jl index c63535cb3..edf581d4d 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -1,4 +1,4 @@ -using Test, DynamicPPL +using Test, DynamicPPL, Setfield using DynamicPPL: leafcontext, setleafcontext, @@ -53,7 +53,7 @@ Return `vn` but now with the prefix removed. """ function remove_prefix(vn::VarName) return VarName{Symbol(split(string(vn), string(DynamicPPL.PREFIX_SEPARATOR))[end])}( - vn.indexing + getlens(vn) ) end @@ -65,11 +65,14 @@ e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1 """ varnames(vn::VarName, val::Real) = [vn] function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) - return (VarName(vn, (vn.indexing..., Tuple(I))) for I in CartesianIndices(val)) + return ( + VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for + I in CartesianIndices(val) + ) end function varnames(vn::VarName, val::AbstractArray) return Iterators.flatten( - varnames(VarName(vn, (vn.indexing..., Tuple(I))), val[I]) for + varnames(VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I]) for I in CartesianIndices(val) ) end @@ -183,7 +186,7 @@ end # Let's check elementwise. for vn_child in varnames(vn_without_prefix, val) - if DynamicPPL._getindex(val, vn_child.indexing) === missing + if get(val, getlens(vn_child)) === missing @test contextual_isassumption(context, vn_child) else @test !contextual_isassumption(context, vn_child) @@ -219,7 +222,7 @@ end @test hasvalue_nested(context, vn_child) # Value should be the same as extracted above. @test getvalue_nested(context, vn_child) === - DynamicPPL._getindex(val, vn_child.indexing) + get(val, getlens(vn_child)) end end end @@ -246,11 +249,11 @@ end vn = VarName{:x}() vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test vn_prefixed.indexing === vn.indexing + @test getlens(vn_prefixed) === getlens(vn) - vn = VarName{:x}((1,)) + vn = VarName{:x}(((1,),)) vn_prefixed = @inferred DynamicPPL.prefix(ctx, vn) @test DynamicPPL.getsym(vn_prefixed) == Symbol("a.b.c.d.e.f.x") - @test vn_prefixed.indexing === vn.indexing + @test getlens(vn_prefixed) === getlens(vn) end end diff --git a/test/turing/Project.toml b/test/turing/Project.toml index fe186816f..9d75e2dcb 100644 --- a/test/turing/Project.toml +++ b/test/turing/Project.toml @@ -5,6 +5,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [compat] -DynamicPPL = "0.15" +DynamicPPL = "0.16" Turing = "0.18" julia = "1.3" diff --git a/test/turing/varinfo.jl b/test/turing/varinfo.jl index cc8b61d04..892433779 100644 --- a/test/turing/varinfo.jl +++ b/test/turing/varinfo.jl @@ -184,29 +184,6 @@ chain = sample(priorsinarray(xs), HMC(0.01, 10), 10) end @testset "varname" begin - i, j, k = 1, 2, 3 - - vn1 = @varname x[1] - @test vn1 == VarName{:x}(((1,),)) - - # Symbol - v_sym = string(:x) - @test v_sym == "x" - - # Array - v_arr = @varname x[i] - @test v_arr.indexing == ((1,),) - - # Matrix - v_mat = @varname x[i, j] - @test v_mat.indexing == ((1, 2),) - - v_mat = @varname x[i, j, k] - @test v_mat.indexing == ((1, 2, 3),) - - v_mat = @varname x[1, 2][1 + 5][45][3][i] - @test v_mat.indexing == ((1, 2), (6,), (45,), (3,), (1,)) - @model function mat_name_test() p = Array{Any}(undef, 2, 2) for i in 1:2, j in 1:2 @@ -217,10 +194,6 @@ chain = sample(mat_name_test(), HMC(0.2, 4), 1000) check_numerical(chain, ["p[1,1]"], [0]; atol=0.25) - # Multi array - v_arrarr = @varname x[i][j] - @test v_arrarr.indexing == ((1,), (2,)) - @model function marr_name_test() p = Array{Array{Any}}(undef, 2) p[1] = Array{Any}(undef, 2)