diff --git a/Project.toml b/Project.toml index e1a1fddac..201dd62d9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.10.19" +version = "0.10.20" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 62189fa4c..2d6a61f33 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -79,6 +79,7 @@ export AbstractVarInfo, LikelihoodContext, PriorContext, MiniBatchContext, + PrefixContext, assume, dot_assume, observer, @@ -96,7 +97,9 @@ export AbstractVarInfo, logjoint, pointwise_loglikelihoods, # Convenience macros - @addlogprob! + @addlogprob!, + @submodel + # Reexport using Distributions: loglikelihood @@ -124,5 +127,6 @@ include("compiler.jl") include("prob_macro.jl") include("compat/ad.jl") include("loglikelihoods.jl") +include("submodel_macro.jl") end # module diff --git a/src/compiler.jl b/src/compiler.jl index 7f906265a..eb6804476 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -54,7 +54,7 @@ check_tilde_rhs(x::AbstractArray{<:Distribution}) = x ################# """ - @model(expr[, warn = true]) + @model(expr[, warn = false]) Macro to specify a probabilistic model. @@ -73,7 +73,7 @@ end To generate a `Model`, call `model(xvalue)` or `model(xvalue, yvalue)`. """ -macro model(expr, warn=true) +macro model(expr, warn=false) # include `LineNumberNode` with information about the call site in the # generated function for easier debugging and interpretation of error messages esc(model(__module__, __source__, expr, warn)) diff --git a/src/context_implementations.jl b/src/context_implementations.jl index bd3340761..5000b8bfe 100644 --- a/src/context_implementations.jl +++ b/src/context_implementations.jl @@ -39,6 +39,9 @@ end function tilde(rng, ctx::MiniBatchContext, sampler, right, left::VarName, inds, vi) return tilde(rng, ctx.ctx, sampler, right, left, inds, vi) end +function tilde(rng, ctx::PrefixContext, sampler, right, vn::VarName, inds, vi) + return tilde(rng, ctx.ctx, sampler, right, prefix(ctx, vn), inds, vi) +end """ tilde_assume(rng, ctx, sampler, right, vn, inds, vi) @@ -75,6 +78,9 @@ end function tilde(ctx::MiniBatchContext, sampler, right, left, vi) return ctx.loglike_scalar * tilde(ctx.ctx, sampler, right, left, vi) end +function tilde(ctx::PrefixContext, sampler, right, left, vi) + return tilde(ctx.ctx, sampler, right, left, vi) +end """ tilde_observe(ctx, sampler, right, left, vname, vinds, vi) diff --git a/src/contexts.jl b/src/contexts.jl index 2de05a034..0d7006bf0 100644 --- a/src/contexts.jl +++ b/src/contexts.jl @@ -52,3 +52,29 @@ end function MiniBatchContext(ctx = DefaultContext(); batch_size, npoints) return MiniBatchContext(ctx, npoints/batch_size) end + + +struct PrefixContext{Prefix, C} <: AbstractContext + ctx::C +end +PrefixContext{Prefix}(ctx::AbstractContext) where {Prefix} = PrefixContext{Prefix, typeof(ctx)}(ctx) + +const PREFIX_SEPARATOR = Symbol(".") + +function PrefixContext{PrefixInner}( + ctx::PrefixContext{PrefixOuter} +) where {PrefixInner, PrefixOuter} + if @generated + :(PrefixContext{$(QuoteNode(Symbol(PrefixOuter, _prefix_seperator, PrefixInner)))}(ctx.ctx)) + else + PrefixContext{Symbol(PrefixOuter, PREFIX_SEPARATOR, PrefixInner)}(ctx.ctx) + end +end + +function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix, Sym} + if @generated + return :(VarName{$(QuoteNode(Symbol(Prefix, _prefix_seperator, Sym)))}(vn.indexing)) + else + VarName{Symbol(Prefix, PREFIX_SEPARATOR, Sym)}(vn.indexing) + end +end diff --git a/src/submodel_macro.jl b/src/submodel_macro.jl new file mode 100644 index 000000000..267ec8933 --- /dev/null +++ b/src/submodel_macro.jl @@ -0,0 +1,23 @@ +macro submodel(expr) + return quote + _evaluate( + $(esc(:__rng__)), + $(esc(expr)), + $(esc(:__varinfo__)), + $(esc(:__sampler__)), + $(esc(:__context__)) + ) + end +end + +macro submodel(prefix, expr) + return quote + _evaluate( + $(esc(:__rng__)), + $(esc(expr)), + $(esc(:__varinfo__)), + $(esc(:__sampler__)), + PrefixContext{$(esc(Meta.quot(prefix)))}($(esc(:__context__))) + ) + end +end diff --git a/test/compiler.jl b/test/compiler.jl index ca94a53dc..244738c3d 100644 --- a/test/compiler.jl +++ b/test/compiler.jl @@ -314,6 +314,107 @@ end @test demo2()() == 42 end + @testset "submodel" begin + # No prefix, 1 level. + @model function demo1(x) + x ~ Normal() + end; + @model function demo2(x, y) + @submodel demo1(x) + y ~ Uniform() + end; + # No observation. + m = demo2(missing, missing); + vi = VarInfo(m); + ks = keys(vi) + @test VarName(:x) ∈ ks + @test VarName(:y) ∈ ks + + # Observation in top-level. + m = demo2(missing, 1.0); + vi = VarInfo(m); + ks = keys(vi) + @test VarName(:x) ∈ ks + @test VarName(:y) ∉ ks + + # Observation in nested model. + m = demo2(1000.0, missing); + vi = VarInfo(m); + ks = keys(vi) + @test VarName(:x) ∉ ks + @test VarName(:y) ∈ ks + + # Observe all. + m = demo2(1000.0, 0.5); + vi = VarInfo(m); + ks = keys(vi) + @test isempty(ks) + + # Check values makes sense. + @model function demo2(x, y) + @submodel demo1(x) + y ~ Normal(x) + end; + m = demo2(1000.0, missing); + # Mean of `y` should be close to 1000. + @test abs(mean([VarInfo(m)[VarName(:y)] for i = 1:10]) - 1000) ≤ 10; + + # Prefixed submodels and usage of submodel return values. + @model function demo_return(x) + x ~ Normal() + return x + end; + + @model function demo_useval(x, y) + x1 = @submodel sub1 demo_return(x) + x2 = @submodel sub2 demo_return(y) + + z ~ Normal(x1 + x2 + 100, 1.0) + end; + m = demo_useval(missing, missing) + vi = VarInfo(m); + ks = keys(vi) + @test VarName(Symbol("sub1.x")) ∈ ks + @test VarName(Symbol("sub2.x")) ∈ ks + @test VarName(:z) ∈ ks + @test abs(mean([VarInfo(m)[VarName(:z)] for i = 1:10]) - 100) ≤ 10 + + # AR1 model. Dynamic prefixing. + @model function AR1(num_steps, α, μ, σ, ::Type{TV} = Vector{Float64}) where {TV} + η ~ MvNormal(num_steps, 1.0) + δ = sqrt(1 - α^2) + + x = TV(undef, num_steps) + x[1] = η[1] + @inbounds for t = 2:num_steps + x[t] = @. α * x[t - 1] + δ * η[t] + end + + return @. μ + σ * x + end + + @model function demo(y) + α ~ Uniform() + μ ~ Normal() + σ ~ truncated(Normal(), 0, Inf) + + num_steps = length(y[1]) + num_obs = length(y) + @inbounds for i = 1:num_obs + x = @submodel $(Symbol("ar1_$i")) AR1(num_steps, α, μ, σ) + y[i] ~ MvNormal(x, 0.1) + end + end; + + ys = [randn(10), randn(10)]; + m = demo(ys); + vi = VarInfo(m); + + for k in [:α, :μ, :σ, Symbol("ar1_1.η"), Symbol("ar1_2.η")] + @test VarName(k) ∈ keys(vi) + end + end + @testset "check_tilde_rhs" begin @test_throws ArgumentError DynamicPPL.check_tilde_rhs(randn())