From 5d5bc8868b06689383cc1a0a95fa216882160375 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 30 Jun 2022 14:15:18 +0100 Subject: [PATCH 01/35] added example_values and posterior_mean_values methods to models in TestUtils --- src/test_utils.jl | 164 ++++++++++++++++++++++++++++++++++++++++- test/simple_varinfo.jl | 38 ++++------ 2 files changed, 175 insertions(+), 27 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 0b9a5526b..621bbfe30 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -6,6 +6,7 @@ using LinearAlgebra using Distributions using Test +using Random: Random using Bijectors: Bijectors """ @@ -87,6 +88,21 @@ See also: [`logprior_true`](@ref). """ function logprior_true_with_logabsdet_jacobian end +""" + example_values(model::Model) + +Return a `NamedTuple` compatible with `keys(model)` with values in support of `model`. +""" +example_values(model::Model) = example_values(Random.GLOBAL_RNG, model) + +""" + posterior_mean_values(model::Model) + +Return a `NamedTuple` compatible with `keys(model)` where the values represent +the posterior mean under `model`. 
+""" +function posterior_mean_values end + """ demo_dynamic_constraint() @@ -108,7 +124,12 @@ end function Base.keys(model::Model{typeof(demo_dynamic_constraint)}) return [@varname(m), @varname(x)] end - +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_dynamic_constraint)} +) + m = rand(rng, Normal()) + return (m=m, x=rand(rng, truncated(Normal(), m, Inf))) +end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_dynamic_constraint)}, m, x ) @@ -137,6 +158,16 @@ end function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe)}) return [@varname(m[1]), @varname(m[2])] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_dot_observe)} +) + return (m=rand(rng, Normal(), length(model.args.x)),) +end +function posterior_mean_values(model::Model{typeof(demo_dot_assume_dot_observe)}) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function demo_assume_index_observe( x=[10.0, 10.0], ::Type{TV}=Vector{Float64} @@ -159,6 +190,16 @@ end function Base.keys(model::Model{typeof(demo_assume_index_observe)}) return [@varname(m[1]), @varname(m[2])] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_assume_index_observe)} +) + return (m=rand(rng, Normal(), length(model.args.x)),) +end +function posterior_mean_values(model::Model{typeof(demo_assume_index_observe)}) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function demo_assume_multivariate_observe(x=[10.0, 10.0]) # Multivariate `assume` and `observe` @@ -176,6 +217,16 @@ end function Base.keys(model::Model{typeof(demo_assume_multivariate_observe)}) return [@varname(m)] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_assume_multivariate_observe)} +) + return (m=rand(rng, MvNormal(zero(model.args.x), I)),) +end +function posterior_mean_values(model::Model{typeof(demo_assume_multivariate_observe)}) + vals = 
example_values(model) + vals.m .= 8 + return vals +end @model function demo_dot_assume_observe_index( x=[10.0, 10.0], ::Type{TV}=Vector{Float64} @@ -198,6 +249,16 @@ end function Base.keys(model::Model{typeof(demo_dot_assume_observe_index)}) return [@varname(m[1]), @varname(m[2])] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_index)} +) + return (m=rand(rng, Normal(), length(model.args.x)),) +end +function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_index)}) + vals = example_values(model) + vals.m .= 8 + return vals +end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. @@ -217,6 +278,14 @@ end function Base.keys(model::Model{typeof(demo_assume_dot_observe)}) return [@varname(m)] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_assume_dot_observe)} +) + return (m=rand(rng, Normal()),) +end +function posterior_mean_values(model::Model{typeof(demo_assume_dot_observe)}) + return (m=8.0,) +end @model function demo_assume_observe_literal() # `assume` and literal `observe` @@ -234,6 +303,16 @@ end function Base.keys(model::Model{typeof(demo_assume_observe_literal)}) return [@varname(m)] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_assume_observe_literal)} +) + return (m=rand(rng, MvNormal(zeros(2), I)),) +end +function posterior_mean_values(model::Model{typeof(demo_assume_observe_literal)}) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing @@ -254,6 +333,16 @@ end function Base.keys(model::Model{typeof(demo_dot_assume_observe_index_literal)}) return [@varname(m[1]), @varname(m[2])] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_index_literal)} +) + return (m=rand(rng, Normal(), 2),) 
+end +function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_index_literal)}) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function demo_assume_literal_dot_observe() # `assume` and literal `dot_observe` @@ -271,6 +360,14 @@ end function Base.keys(model::Model{typeof(demo_assume_literal_dot_observe)}) return [@varname(m)] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_assume_literal_dot_observe)} +) + return (m=rand(rng, Normal()),) +end +function posterior_mean_values(model::Model{typeof(demo_assume_literal_dot_observe)}) + return (m=8.0,) +end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} m = TV(undef, 2) @@ -299,6 +396,19 @@ end function Base.keys(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) return [@varname(m[1]), @varname(m[2])] end +function example_values( + rng::Random.AbstractRNG, + model::Model{typeof(demo_assume_submodel_observe_index_literal)}, +) + return (m=rand(rng, Normal(), 2),) +end +function posterior_mean_values( + model::Model{typeof(demo_assume_submodel_observe_index_literal)} +) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function _likelihood_dot_observe(m, x) return x ~ MvNormal(m, 0.25 * I) @@ -324,6 +434,16 @@ end function Base.keys(model::Model{typeof(demo_dot_assume_observe_submodel)}) return [@varname(m[1]), @varname(m[2])] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_submodel)} +) + return (m=rand(rng, Normal(), length(model.args.x)),) +end +function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_submodel)}) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function demo_dot_assume_dot_observe_matrix( x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64} @@ -345,6 +465,16 @@ end function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) return [@varname(m[1]), @varname(m[2])] 
end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_dot_observe_matrix)} +) + return (m=rand(rng, Normal(), length(model.args.x)),) +end +function posterior_mean_values(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function demo_dot_assume_matrix_dot_observe_matrix( x=fill(10.0, 2, 1), ::Type{TV}=Array{Float64} @@ -369,6 +499,19 @@ end function Base.keys(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) return [@varname(m[:, 1]), @varname(m[:, 2])] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} +) + d = length(model.args.x) ÷ 2 + return (m=rand(rng, MvNormal(zeros(d), I), 2),) +end +function posterior_mean_values( + model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} +) + vals = example_values(model) + vals.m .= 8 + return vals +end @model function demo_dot_assume_array_dot_observe( x=[10.0, 10.0], ::Type{TV}=Vector{Float64} @@ -388,6 +531,16 @@ end function Base.keys(model::Model{typeof(demo_dot_assume_array_dot_observe)}) return [@varname(m[1]), @varname(m[2])] end +function example_values( + rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_array_dot_observe)} +) + return (m=rand(rng, Normal(), length(model.args.x)),) +end +function posterior_mean_values(model::Model{typeof(demo_dot_assume_array_dot_observe)}) + vals = example_values(model) + vals.m .= 8 + return vals +end const DEMO_MODELS = ( demo_dot_assume_dot_observe(), @@ -431,7 +584,6 @@ function test_sampler_demo_models( meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; - target=8.0, atol=1e-1, rtol=1e-3, kwargs..., @@ -439,7 +591,11 @@ function test_sampler_demo_models( @testset "$(nameof(typeof(sampler))) on $(nameof(m))" for model in DEMO_MODELS chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) 
μ = meanfunction(chain) - @test μ ≈ target atol = atol rtol = rtol + target_values = posterior_mean_values(model) + for vn in keys(model) + target = get(target_values, vn) + @test μ ≈ target atol = atol rtol = rtol + end end end @@ -458,7 +614,7 @@ end function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) # Default for `MCMCChains.Chains`. - return test_sampler_continuous(sampler, args...; kwargs...) do chain + return test_sampler_continuous(sampler, args...; kwargs...) do chain, vn mean(Array(chain)) end end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 9494ae6c1..2620be405 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -62,18 +62,20 @@ DynamicPPL.TestUtils.DEMO_MODELS # We might need to pre-allocate for the variable `m`, so we need # to see whether this is the case. - m = model().m - svi_nt = if m isa AbstractArray - SimpleVarInfo((m=similar(m),)) - else - SimpleVarInfo() - end + svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.example_values(model)) svi_dict = SimpleVarInfo(VarInfo(model), Dict) - @testset "$(nameof(typeof(svi.values)))" for svi in (svi_nt, svi_dict) + @testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in ( + svi_nt, + svi_dict, + DynamicPPL.settrans!!(svi_nt, true), + DynamicPPL.settrans!!(svi_dict, true), + ) + Random.seed!(42) + # Random seed is set in each `@testset`, so we need to sample # a new realization for `m` here. - m = model().m + retval = model() ### Sampling ### # Sample a new varinfo! @@ -81,11 +83,7 @@ # Realization for `m` should be different wp. 1. for vn in keys(model) - # `VarName` functions similarly to `PropertyLens` so - # we just strip this part from `vn` to get a lens we can use - # to extract the corresponding value of `m`. - l = getlens(vn) - @test svi_new[vn] != get(m, l) + @test svi_new[vn] != get(retval, vn) end # Logjoint should be non-zero wp. 1. @@ -93,17 +91,12 @@ ### Evaluation ### # Sample some random testing values. 
- m_eval = if m isa AbstractArray - randn!(similar(m)) - else - randn(eltype(m)) - end + values_eval = DynamicPPL.TestUtils.example_values(model) # Update the realizations in `svi_new`. svi_eval = svi_new for vn in keys(model) - l = getlens(vn) - svi_eval = DynamicPPL.setindex!!(svi_eval, get(m_eval, l), vn) + svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn) end # Reset the logp field. @@ -114,12 +107,11 @@ # Values should not have changed. for vn in keys(model) - l = getlens(vn) - @test svi_eval[vn] == get(m_eval, l) + @test svi_eval[vn] == get(values_eval, vn) end # Compute the true `logjoint` and compare. - logπ_true = DynamicPPL.TestUtils.logjoint_true(model, m_eval) + logπ_true = DynamicPPL.TestUtils.logjoint_true(model, values_eval...) @test logπ ≈ logπ_true end end From 0498336481b3e92463ab9d849768b18e76129c53 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 30 Jun 2022 15:20:42 +0100 Subject: [PATCH 02/35] demo models in TestUtils are now a bit more complex, including constrained variables --- src/test_utils.jl | 389 ++++++++++++++++++++++++----------------- test/simple_varinfo.jl | 6 +- 2 files changed, 233 insertions(+), 162 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 621bbfe30..e6f11e0ab 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -141,122 +141,157 @@ end # A collection of models for which the mean-of-means for the posterior should # be same. 
@model function demo_dot_assume_dot_observe( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} + x=[1.5, 1.5], ::Type{TV}=Vector{Float64} ) where {TV} # `dot_assume` and `observe` + s = TV(undef, length(x)) m = TV(undef, length(x)) - m .~ Normal() - x ~ MvNormal(m, 0.25 * I) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + s .~ InverseGamma(2, 3) + m .~ Normal.(0, sqrt.(s)) + + x ~ MvNormal(m, Diagonal(s)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, m) - return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) + return loglikelihood(MvNormal(m, Diagonal(s)), model.args.x) end function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_dot_observe)} ) - return (m=rand(rng, Normal(), length(model.args.x)),) + n = length(model.args.x) + s = rand(rng, InverseGamma(2, 3), n) + m = similar(s) + for i in eachindex(m, s) + m[i] = rand(rng, Normal(0, sqrt(s[i]))) + end + return (s=s, m=m) end function posterior_mean_values(model::Model{typeof(demo_dot_assume_dot_observe)}) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @model function demo_assume_index_observe( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} + x=[1.5, 1.5], ::Type{TV}=Vector{Float64} ) where {TV} # `assume` with indexing and `observe` + s = TV(undef, length(x)) + for i in eachindex(s) + s[i] ~ 
InverseGamma(2, 3) + end m = TV(undef, length(x)) for i in eachindex(m) - m[i] ~ Normal() + m[i] ~ Normal(0, sqrt(s[i])) end - x ~ MvNormal(m, 0.25 * I) + x ~ MvNormal(m, Diagonal(s)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_index_observe)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_assume_index_observe)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_assume_index_observe)}, m) - return logpdf(MvNormal(m, 0.25 * I), model.args.x) +function loglikelihood_true(model::Model{typeof(demo_assume_index_observe)}, s, m) + return logpdf(MvNormal(m, Diagonal(s)), model.args.x) end function Base.keys(model::Model{typeof(demo_assume_index_observe)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_assume_index_observe)} ) - return (m=rand(rng, Normal(), length(model.args.x)),) + n = length(model.args.x) + s = rand(rng, InverseGamma(2, 3), n) + m = similar(s) + for i in eachindex(m, s) + m[i] = rand(rng, Normal(0, sqrt(s[i]))) + end + return (s=s, m=m) end function posterior_mean_values(model::Model{typeof(demo_assume_index_observe)}) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @model function demo_assume_multivariate_observe(x=[10.0, 10.0]) # Multivariate `assume` and `observe` - m ~ MvNormal(zero(x), I) - x ~ MvNormal(m, 0.25 * I) + s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) + m ~ MvNormal(zero(x), Diagonal(s)) + x ~ MvNormal(m, Diagonal(s)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function 
logprior_true(model::Model{typeof(demo_assume_multivariate_observe)}, m) - return logpdf(MvNormal(zero(model.args.x), I), m) +function logprior_true(model::Model{typeof(demo_assume_multivariate_observe)}, s, m) + s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) + m_dist = MvNormal(zero(model.args.x), Diagonal(s)) + return logpdf(s_dist, s) + logpdf(m_dist, m) end -function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe)}, m) - return logpdf(MvNormal(m, 0.25 * I), model.args.x) +function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe)}, s, m) + return logpdf(MvNormal(m, Diagonal(s)), model.args.x) end function Base.keys(model::Model{typeof(demo_assume_multivariate_observe)}) - return [@varname(m)] + return [@varname(s), @varname(m)] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_assume_multivariate_observe)} ) - return (m=rand(rng, MvNormal(zero(model.args.x), I)),) + s = rand(rng, product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])) + return (s=s, m=rand(rng, MvNormal(zero(model.args.x), Diagonal(s)))) end function posterior_mean_values(model::Model{typeof(demo_assume_multivariate_observe)}) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @model function demo_dot_assume_observe_index( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} + x=[1.5, 1.5], ::Type{TV}=Vector{Float64} ) where {TV} # `dot_assume` and `observe` with indexing + s = TV(undef, length(x)) + s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal() + m .~ Normal.(0, sqrt.(s)) for i in eachindex(x) - x[i] ~ Normal(m[i], 0.5) + x[i] ~ Normal(m[i], sqrt(s[i])) end - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) - return loglikelihood(Normal(), m) +function 
logprior_true(model::Model{typeof(demo_dot_assume_observe_index)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, m) - return sum(logpdf.(Normal.(m, 0.5), model.args.x)) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, s, m) + return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) end function Base.keys(model::Model{typeof(demo_dot_assume_observe_index)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_index)} ) - return (m=rand(rng, Normal(), length(model.args.x)),) + n = length(model.args.x) + s = rand(rng, InverseGamma(2, 3), n) + m = similar(s) + for i in eachindex(m, s) + m[i] = rand(rng, Normal(0, sqrt(s[i]))) + end + return (s=s, m=m) end function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_index)}) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @@ -264,281 +299,314 @@ end # as the others. 
@model function demo_assume_dot_observe(x=[10.0]) # `assume` and `dot_observe` - m ~ Normal() - x .~ Normal(m, 0.5) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + x .~ Normal(m, sqrt(s)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, m) - return logpdf(Normal(), m) +function logprior_true(model::Model{typeof(demo_assume_dot_observe)}, s, m) + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) end -function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, m) - return sum(logpdf.(Normal.(m, 0.5), model.args.x)) +function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, s, m) + return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) end function Base.keys(model::Model{typeof(demo_assume_dot_observe)}) - return [@varname(m)] + return [@varname(s), @varname(m)] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_assume_dot_observe)} ) - return (m=rand(rng, Normal()),) + s = rand(rng, InverseGamma(2, 3)) + m = rand(rng, Normal(0, sqrt(s))) + return (s=s, m=m) end function posterior_mean_values(model::Model{typeof(demo_assume_dot_observe)}) - return (m=8.0,) + return (s=2.375, m=0.75) end @model function demo_assume_observe_literal() # `assume` and literal `observe` - m ~ MvNormal(zeros(2), I) - [10.0, 10.0] ~ MvNormal(m, 0.25 * I) + s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) + m ~ MvNormal(zeros(2), Diagonal(s)) + [1.5, 1.5] ~ MvNormal(m, Diagonal(s)) - return (; m=m, x=[10.0, 10.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 1.5], logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, m) - return logpdf(MvNormal(zeros(2), I), m) +function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) + s_dist = product_distribution([InverseGamma(2, 3), 
InverseGamma(2, 3)]) + m_dist = MvNormal(zeros(2), Diagonal(s)) + return logpdf(s_dist, s) + logpdf(m_dist, m) end -function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, m) - return logpdf(MvNormal(m, 0.25 * I), [10.0, 10.0]) +function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) + return logpdf(MvNormal(m, Diagonal(s)), [1.5, 1.5]) end function Base.keys(model::Model{typeof(demo_assume_observe_literal)}) - return [@varname(m)] + return [@varname(s), @varname(m)] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_assume_observe_literal)} ) - return (m=rand(rng, MvNormal(zeros(2), I)),) + s = rand(rng, product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])) + return (s=s, m=rand(rng, MvNormal(zeros(2), Diagonal(s)))) end function posterior_mean_values(model::Model{typeof(demo_assume_observe_literal)}) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing + s = TV(undef, 2) m = TV(undef, 2) - m .~ Normal() + s .~ InverseGamma(2, 3) + m .~ Normal.(0, sqrt.(s)) + for i in eachindex(m) - 10.0 ~ Normal(m[i], 0.5) + 1.5 ~ Normal(m[i], sqrt(s[i])) end - return (; m=m, x=fill(10.0, length(m)), logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=fill(1.5, length(m)), logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, m) - return sum(logpdf.(Normal.(m, 0.5), fill(10.0, length(m)))) +function loglikelihood_true( + 
model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m +) + return sum(logpdf.(Normal.(m, sqrt.(s)), fill(1.5, length(m)))) end function Base.keys(model::Model{typeof(demo_dot_assume_observe_index_literal)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_index_literal)} ) - return (m=rand(rng, Normal(), 2),) + n = 2 + s = rand(rng, InverseGamma(2, 3), n) + m = similar(s) + for i in eachindex(m, s) + m[i] = rand(rng, Normal(0, sqrt(s[i]))) + end + return (s=s, m=m) end function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_index_literal)}) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @model function demo_assume_literal_dot_observe() # `assume` and literal `dot_observe` - m ~ Normal() - [10.0] .~ Normal(m, 0.5) + s ~ InverseGamma(2, 3) + m ~ Normal(0, sqrt(s)) + [1.5] .~ Normal(m, sqrt(s)) - return (; m=m, x=[10.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5], logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, m) - return logpdf(Normal(), m) +function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) + return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) end -function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, m) - return logpdf(Normal(m, 0.5), 10.0) +function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) + return logpdf(Normal(m, sqrt(s)), 1.5) end function Base.keys(model::Model{typeof(demo_assume_literal_dot_observe)}) - return [@varname(m)] + return [@varname(s), @varname(m)] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_assume_literal_dot_observe)} ) - return (m=rand(rng, Normal()),) + s = rand(rng, 
InverseGamma(2, 3)) + m = rand(rng, Normal(0, sqrt(s))) + return (s=s, m=m) end function posterior_mean_values(model::Model{typeof(demo_assume_literal_dot_observe)}) - return (m=8.0,) + return (s=2.375, m=0.75) end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} + s = TV(undef, 2) + s .~ InverseGamma(2, 3) m = TV(undef, 2) - m .~ Normal() + m .~ Normal.(0, sqrt.(s)) - return m + return s, m end @model function demo_assume_submodel_observe_index_literal() # Submodel prior - @submodel m = _prior_dot_assume() - for i in eachindex(m) - 10.0 ~ Normal(m[i], 0.5) + @submodel s, m = _prior_dot_assume() + for i in eachindex(m, s) + 1.5 ~ Normal(m[i], sqrt(s[i])) end - return (; m=m, x=[10.0], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 1.5], logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_assume_submodel_observe_index_literal)}, m) - return loglikelihood(Normal(), m) +function logprior_true( + model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m +) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end function loglikelihood_true( - model::Model{typeof(demo_assume_submodel_observe_index_literal)}, m + model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m ) - return sum(logpdf.(Normal.(m, 0.5), 10.0)) + return sum(logpdf.(Normal.(m, sqrt.(s)), 1.5)) end function Base.keys(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_assume_submodel_observe_index_literal)}, ) - return (m=rand(rng, Normal(), 2),) + n = 2 + s = rand(rng, InverseGamma(2, 3), n) + m = similar(s) + for i in eachindex(m, s) + m[i] = rand(rng, Normal(0, sqrt(s[i]))) + end + return (s=s, m=m) end function posterior_mean_values( 
model::Model{typeof(demo_assume_submodel_observe_index_literal)} ) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end -@model function _likelihood_dot_observe(m, x) - return x ~ MvNormal(m, 0.25 * I) +@model function _likelihood_mltivariate_observe(s, m, x) + return x ~ MvNormal(m, Diagonal(s)) end @model function demo_dot_assume_observe_submodel( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} + x=[1.5, 1.5], ::Type{TV}=Vector{Float64} ) where {TV} + s = TV(undef, length(x)) + s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal() + m .~ Normal.(0, sqrt.(s)) # Submodel likelihood - @submodel _likelihood_dot_observe(m, x) + @submodel _likelihood_mltivariate_observe(s, m, x) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, m) - return logpdf(MvNormal(m, 0.25 * I), model.args.x) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m) + return logpdf(MvNormal(m, Diagonal(s)), model.args.x) end function Base.keys(model::Model{typeof(demo_dot_assume_observe_submodel)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_submodel)} ) - return (m=rand(rng, Normal(), length(model.args.x)),) + n = length(model.args.x) + s = rand(rng, InverseGamma(2, 3), n) + m = similar(s) + for i in eachindex(m, s) + m[i] = rand(rng, Normal(0, sqrt(s[i]))) + end + return (s=s, m=m) end function 
posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_submodel)}) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @model function demo_dot_assume_dot_observe_matrix( - x=fill(10.0, 2, 1), ::Type{TV}=Vector{Float64} + x=fill(1.5, 2, 1), ::Type{TV}=Vector{Float64} ) where {TV} + s = TV(undef, length(x)) + s .~ InverseGamma(2, 3) m = TV(undef, length(x)) - m .~ Normal() + m .~ Normal.(0, sqrt.(s)) # Dotted observe for `Matrix`. - x .~ MvNormal(m, 0.25 * I) + x .~ MvNormal(m, Diagonal(s)) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, m) - return loglikelihood(Normal(), m) +function logprior_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) + return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, m) - return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) +function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) + return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) end function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) - return [@varname(m[1]), @varname(m[2])] + return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_dot_observe_matrix)} ) - return (m=rand(rng, Normal(), length(model.args.x)),) + n = length(model.args.x) + s = rand(rng, InverseGamma(2, 3), n) + m = similar(s) + for i in eachindex(m, s) + m[i] = rand(rng, Normal(0, sqrt(s[i]))) + end + return (s=s, m=m) end function posterior_mean_values(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @model 
function demo_dot_assume_matrix_dot_observe_matrix( - x=fill(10.0, 2, 1), ::Type{TV}=Array{Float64} + x=fill(1.5, 2, 1), ::Type{TV}=Array{Float64} ) where {TV} d = length(x) ÷ 2 + s = TV(undef, d, 2) + s .~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) m = TV(undef, d, 2) m .~ MvNormal(zeros(d), I) # Dotted observe for `Matrix`. - x .~ MvNormal(vec(m), 0.25 * I) + x .~ MvNormal(vec(m), Diagonal(vec(s))) - return (; m=m, x=x, logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end -function logprior_true(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, m) - return loglikelihood(Normal(), vec(m)) +function logprior_true( + model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m +) + return loglikelihood(InverseGamma(2, 3), vec(s)) + loglikelihood(Normal(), vec(m)) end function loglikelihood_true( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, m + model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m ) - return loglikelihood(MvNormal(vec(m), 0.25 * I), model.args.x) + return loglikelihood(MvNormal(vec(m), Diagonal(vec(s))), model.args.x) end function Base.keys(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) - return [@varname(m[:, 1]), @varname(m[:, 2])] + return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m[:, 1]), @varname(m[:, 2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} ) d = length(model.args.x) ÷ 2 - return (m=rand(rng, MvNormal(zeros(d), I), 2),) + s = rand(rng, product_distribution([InverseGamma(2, 3) for _ in 1:d]), 2) + m = similar(s) + for i in 1:size(m, 2) + m[:, i] = rand(rng, MvNormal(zeros(d), Diagonal(vec(s[:, i])))) + end + return (s=s, m=m) end function posterior_mean_values( model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} ) vals = example_values(model) - vals.m .= 8 - return vals -end - -@model function 
demo_dot_assume_array_dot_observe( - x=[10.0, 10.0], ::Type{TV}=Vector{Float64} -) where {TV} - # `dot_assume` and `observe` - m = TV(undef, length(x)) - m .~ [Normal() for _ in 1:length(x)] - x ~ MvNormal(m, 0.25 * I) - return (; m=m, x=x, logp=getlogp(__varinfo__)) -end -function logprior_true(model::Model{typeof(demo_dot_assume_array_dot_observe)}, m) - return loglikelihood(Normal(), m) -end -function loglikelihood_true(model::Model{typeof(demo_dot_assume_array_dot_observe)}, m) - return loglikelihood(MvNormal(m, 0.25 * I), model.args.x) -end -function Base.keys(model::Model{typeof(demo_dot_assume_array_dot_observe)}) - return [@varname(m[1]), @varname(m[2])] -end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_array_dot_observe)} -) - return (m=rand(rng, Normal(), length(model.args.x)),) -end -function posterior_mean_values(model::Model{typeof(demo_dot_assume_array_dot_observe)}) - vals = example_values(model) - vals.m .= 8 + vals.s .= 2.375 + vals.m .= 0.75 return vals end @@ -555,7 +623,6 @@ const DEMO_MODELS = ( demo_dot_assume_observe_submodel(), demo_dot_assume_dot_observe_matrix(), demo_dot_assume_matrix_dot_observe_matrix(), - demo_dot_assume_array_dot_observe(), ) # TODO: Is this really the best/most convenient "default" test method? @@ -590,6 +657,8 @@ function test_sampler_demo_models( ) @testset "$(nameof(typeof(sampler))) on $(nameof(m))" for model in DEMO_MODELS chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) + # TODO(torfjelde): Move `meanfunction` into loop below, and have it also + # take `vn` as input. μ = meanfunction(chain) target_values = posterior_mean_values(model) for vn in keys(model) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 2620be405..9a2c8a134 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -111,8 +111,10 @@ end # Compute the true `logjoint` and compare. - logπ_true = DynamicPPL.TestUtils.logjoint_true(model, values_eval...) 
- @test logπ ≈ logπ_true + if !DynamicPPL.istrans(svi) + logπ_true = DynamicPPL.TestUtils.logjoint_true(model, values_eval...) + @test logπ ≈ logπ_true + end end end From f86f264b5d591e2c3b387fec8b3558ba1486964a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 09:04:32 +0100 Subject: [PATCH 03/35] added logprior_true_with_logabsdet_jacobian for demo models --- src/test_utils.jl | 71 ++++++++++++++++++++++++++++++++++++++++-- test/simple_varinfo.jl | 20 ++++++------ 2 files changed, 80 insertions(+), 11 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index e6f11e0ab..6bcc58a31 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -138,8 +138,15 @@ function logprior_true_with_logabsdet_jacobian( return (m=m, x=x_unconstrained), logprior_true(model, m, x) - Δlogp end -# A collection of models for which the mean-of-means for the posterior should -# be same. +# A collection of models for which the posterior should be "similar". +# Some utility methods for these. 
+function _demo_logprior_true_with_logabsdet_jacobian(model, s, m) + b = Bijectors.bijector(InverseGamma(2, 3)) + s_unconstrained = b.(s) + Δlogp = sum(Base.Fix1(Bijectors.logabsdetjac, b).(s)) + return (s=s_unconstrained, m=m), logprior_true(model, s, m) - Δlogp +end + @model function demo_dot_assume_dot_observe( x=[1.5, 1.5], ::Type{TV}=Vector{Float64} ) where {TV} @@ -158,6 +165,11 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe)}, s, m) return loglikelihood(MvNormal(m, Diagonal(s)), model.args.x) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_dot_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @@ -201,6 +213,11 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_index_observe)}, s, m) return logpdf(MvNormal(m, Diagonal(s)), model.args.x) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_index_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_assume_index_observe)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @@ -238,6 +255,11 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_multivariate_observe)}, s, m) return logpdf(MvNormal(m, Diagonal(s)), model.args.x) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_multivariate_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_assume_multivariate_observe)}) return [@varname(s), @varname(m)] end @@ -274,6 +296,11 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_index)}, s, m) return sum(logpdf.(Normal.(m, sqrt.(s)), 
model.args.x)) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_observe_index)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_dot_assume_observe_index)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @@ -311,6 +338,11 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe)}, s, m) return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_dot_observe)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_assume_dot_observe)}) return [@varname(s), @varname(m)] end @@ -341,6 +373,11 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) return logpdf(MvNormal(m, Diagonal(s)), [1.5, 1.5]) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_observe_literal)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_assume_observe_literal)}) return [@varname(s), @varname(m)] end @@ -378,6 +415,11 @@ function loglikelihood_true( ) return sum(logpdf.(Normal.(m, sqrt.(s)), fill(1.5, length(m)))) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_dot_assume_observe_index_literal)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @@ -413,6 +455,11 @@ end function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) return logpdf(Normal(m, sqrt(s)), 1.5) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_literal_dot_observe)}, s, m +) + return 
_demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_assume_literal_dot_observe)}) return [@varname(s), @varname(m)] end @@ -455,6 +502,11 @@ function loglikelihood_true( ) return sum(logpdf.(Normal.(m, sqrt.(s)), 1.5)) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @@ -502,6 +554,11 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m) return logpdf(MvNormal(m, Diagonal(s)), model.args.x) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_observe_submodel)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_dot_assume_observe_submodel)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @@ -542,6 +599,11 @@ end function loglikelihood_true(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m) return sum(logpdf.(Normal.(m, sqrt.(s)), model.args.x)) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_dot_observe_matrix)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end @@ -587,6 +649,11 @@ function loglikelihood_true( ) return loglikelihood(MvNormal(vec(m), Diagonal(vec(s))), model.args.x) end +function logprior_true_with_logabsdet_jacobian( + model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m +) + return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) +end function 
Base.keys(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m[:, 1]), @varname(m[:, 2])] end diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 9a2c8a134..ed5919f5a 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -71,8 +71,6 @@ DynamicPPL.settrans!!(svi_nt, true), DynamicPPL.settrans!!(svi_dict, true), ) - Random.seed!(42) - # Random seed is set in each `@testset`, so we need to sample # a new realization for `m` here. retval = model() @@ -90,8 +88,15 @@ @test getlogp(svi_new) != 0 ### Evaluation ### - # Sample some random testing values. - values_eval = DynamicPPL.TestUtils.example_values(model) + values_eval_constrained = DynamicPPL.TestUtils.example_values(model) + if DynamicPPL.istrans(svi) + values_eval, logπ_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( + model, values_eval_constrained... + ) + else + values_eval = values_eval_constrained + logπ_true = DynamicPPL.TestUtils.logjoint_true(model, values_eval...) + end # Update the realizations in `svi_new`. svi_eval = svi_new @@ -110,11 +115,8 @@ @test svi_eval[vn] == get(values_eval, vn) end - # Compute the true `logjoint` and compare. - if !DynamicPPL.istrans(svi) - logπ_true = DynamicPPL.TestUtils.logjoint_true(model, values_eval...) - @test logπ ≈ logπ_true - end + # Compare `logjoint` computations. 
+ @test logπ ≈ logπ_true end end From 0d31137f90ee74fb1803a1f45eb2002c2e832ad7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 09:52:02 +0100 Subject: [PATCH 04/35] fixed mistakes in a couple of models in TestUtils --- src/test_utils.jl | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 6bcc58a31..cdc3da191 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -239,7 +239,7 @@ function posterior_mean_values(model::Model{typeof(demo_assume_index_observe)}) return vals end -@model function demo_assume_multivariate_observe(x=[10.0, 10.0]) +@model function demo_assume_multivariate_observe(x=[1.5, 1.5]) # Multivariate `assume` and `observe` s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) m ~ MvNormal(zero(x), Diagonal(s)) @@ -324,7 +324,7 @@ end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. -@model function demo_assume_dot_observe(x=[10.0]) +@model function demo_assume_dot_observe(x=[1.5]) # `assume` and `dot_observe` s ~ InverseGamma(2, 3) m ~ Normal(0, sqrt(s)) @@ -628,26 +628,29 @@ end @model function demo_dot_assume_matrix_dot_observe_matrix( x=fill(1.5, 2, 1), ::Type{TV}=Array{Float64} ) where {TV} + n = length(x) d = length(x) ÷ 2 s = TV(undef, d, 2) s .~ product_distribution([InverseGamma(2, 3) for _ in 1:d]) - m = TV(undef, d, 2) - m .~ MvNormal(zeros(d), I) + s_vec = vec(s) + m ~ MvNormal(zeros(n), Diagonal(s_vec)) # Dotted observe for `Matrix`. 
- x .~ MvNormal(vec(m), Diagonal(vec(s))) + x .~ MvNormal(m, Diagonal(s_vec)) return (; s=s, m=m, x=x, logp=getlogp(__varinfo__)) end function logprior_true( model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m ) - return loglikelihood(InverseGamma(2, 3), vec(s)) + loglikelihood(Normal(), vec(m)) + n = length(model.args.x) + s_vec = vec(s) + return loglikelihood(InverseGamma(2, 3), s_vec) + logpdf(MvNormal(zeros(n), s_vec), m) end function loglikelihood_true( model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m ) - return loglikelihood(MvNormal(vec(m), Diagonal(vec(s))), model.args.x) + return loglikelihood(MvNormal(m, Diagonal(vec(s))), model.args.x) end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m @@ -655,17 +658,15 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function Base.keys(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) - return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m[:, 1]), @varname(m[:, 2])] + return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m[1]), @varname(m[2])] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} ) - d = length(model.args.x) ÷ 2 + n = length(model.args.x) + d = n ÷ 2 s = rand(rng, product_distribution([InverseGamma(2, 3) for _ in 1:d]), 2) - m = similar(s) - for i in 1:size(m, 2) - m[:, i] = rand(rng, MvNormal(zeros(d), Diagonal(vec(s[:, i])))) - end + m = rand(rng, MvNormal(zeros(n), Diagonal(vec(s)))) return (s=s, m=m) end function posterior_mean_values( From c52630b37761f1bd19e13eac7f5b18f8b6d086b6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 09:52:24 +0100 Subject: [PATCH 05/35] moved varnames method which creates iterator of leaf varnames into TestUtils and starting using this in test_continuous_models --- src/test_utils.jl | 39 
++++++++++++++++++++++++++++++++------- test/contexts.jl | 23 ++--------------------- 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index cdc3da191..1a6aff5f9 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -8,6 +8,28 @@ using Test using Random: Random using Bijectors: Bijectors +using Setfield: Setfield + +""" + varnames(vn::VarName, val) + +Return iterator over all varnames that are represented by `vn` on `val`, +e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. +""" +varnames(vn::VarName, val::Real) = [vn] +function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) + return ( + VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for + I in CartesianIndices(val) + ) +end +function varnames(vn::VarName, val::AbstractArray) + return Iterators.flatten( + varnames( + VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I] + ) for I in CartesianIndices(val) + ) +end """ logprior_true(model, θ) @@ -723,15 +745,17 @@ function test_sampler_demo_models( rtol=1e-3, kwargs..., ) - @testset "$(nameof(typeof(sampler))) on $(nameof(m))" for model in DEMO_MODELS + @testset "$(typeof(sampler)) on $(nameof(model))" for model in DEMO_MODELS chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) - # TODO(torfjelde): Move `meanfunction` into loop below, and have it also - # take `vn` as input. - μ = meanfunction(chain) target_values = posterior_mean_values(model) for vn in keys(model) - target = get(target_values, vn) - @test μ ≈ target atol = atol rtol = rtol + # We want to compare elementwise which can be achieved by + # extracting the leaves of the `VarName` and the corresponding value. 
+ for vn_leaf in varnames(vn, get(target_values, vn)) + target_value = get(target_values, vn_leaf) + chain_mean_value = meanfunction(chain, vn_leaf) + @test chain_mean_value ≈ target_value atol = atol rtol = rtol + end end end end @@ -752,7 +776,8 @@ end function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) # Default for `MCMCChains.Chains`. return test_sampler_continuous(sampler, args...; kwargs...) do chain, vn - mean(Array(chain)) + # HACK(torfjelde): This assumes that we can index into `chain` with `Symbol(vn)`. + mean(Array(chain[Symbol(vn)])) end end diff --git a/test/contexts.jl b/test/contexts.jl index 65629afec..24b039852 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -57,25 +57,6 @@ function remove_prefix(vn::VarName) ) end -""" - varnames(vn::VarName, val) - -Return iterator over all varnames that are represented by `vn` on `val`, -e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. -""" -varnames(vn::VarName, val::Real) = [vn] -function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) - return ( - VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for - I in CartesianIndices(val) - ) -end -function varnames(vn::VarName, val::AbstractArray) - return Iterators.flatten( - varnames(VarName(vn, getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I]) for - I in CartesianIndices(val) - ) -end @testset "contexts.jl" begin child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] @@ -185,7 +166,7 @@ end vn_without_prefix = remove_prefix(vn) # Let's check elementwise. - for vn_child in varnames(vn_without_prefix, val) + for vn_child in DynamicPPL.TestUtils.varnames(vn_without_prefix, val) if get(val, getlens(vn_child)) === missing @test contextual_isassumption(context, vn_child) else @@ -217,7 +198,7 @@ end # `ConditionContext` with the conditioned variable. 
vn_without_prefix = remove_prefix(vn) - for vn_child in varnames(vn_without_prefix, val) + for vn_child in DynamicPPL.TestUtils.varnames(vn_without_prefix, val) # `vn_child` should be in `context`. @test hasvalue_nested(context, vn_child) # Value should be the same as extracted above. From fff060c74ad8f804835bf0283c163bea3c59c9e7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 09:57:51 +0100 Subject: [PATCH 06/35] updated docstring for test_sampler_demo_models --- src/test_utils.jl | 6 +++--- test/contexts.jl | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 1a6aff5f9..69579073b 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -722,8 +722,9 @@ const DEMO_MODELS = ( Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. In short, this method iterators through `demo_models`, calls `AbstractMCMC.sample` on the -`model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain)` against `target` -provided in `kwargs...`. +`model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain, vn)` +for every (leaf) varname `vn` against the corresponding value returned by +[`posterior_mean_values`](@ref) for each model. # Arguments - `meanfunction`: A callable which computes the mean of the marginal means from the @@ -732,7 +733,6 @@ provided in `kwargs...`. - `args...`: Arguments forwarded to `sample`. # Keyword arguments -- `target`: Value to compare result of `meanfunction(chain)` to. - `atol=1e-1`: Absolute tolerance used in `@test`. - `rtol=1e-3`: Relative tolerance used in `@test`. - `kwargs...`: Keyword arguments forwarded to `sample`. 
diff --git a/test/contexts.jl b/test/contexts.jl index 24b039852..ef916b18c 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -57,7 +57,6 @@ function remove_prefix(vn::VarName) ) end - @testset "contexts.jl" begin child_contexts = [DefaultContext(), PriorContext(), LikelihoodContext()] From e21958c5bdc27c7899c22af61532b25421b9ecfc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:47:30 +0100 Subject: [PATCH 07/35] renamed varnames to varname_leaves and renamed keys(model) to varnames(model) --- src/test_utils.jl | 66 +++++++++++++++++++++++++++--------------- test/simple_varinfo.jl | 8 ++--- 2 files changed, 47 insertions(+), 27 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 69579073b..78d279067 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -11,21 +11,21 @@ using Bijectors: Bijectors using Setfield: Setfield """ - varnames(vn::VarName, val) + varname_leaves(vn::VarName, val) Return iterator over all varnames that are represented by `vn` on `val`, -e.g. `varnames(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. +e.g. `varname_leaves(@varname(x), rand(2))` results in an iterator over `[@varname(x[1]), @varname(x[2])]`. """ -varnames(vn::VarName, val::Real) = [vn] -function varnames(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) +varname_leaves(vn::VarName, val::Real) = [vn] +function varname_leaves(vn::VarName, val::AbstractArray{<:Union{Real,Missing}}) return ( VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))) for I in CartesianIndices(val) ) end -function varnames(vn::VarName, val::AbstractArray) +function varname_leaves(vn::VarName, val::AbstractArray) return Iterators.flatten( - varnames( + varname_leaves( VarName(vn, DynamicPPL.getlens(vn) ∘ Setfield.IndexLens(Tuple(I))), val[I] ) for I in CartesianIndices(val) ) @@ -110,18 +110,38 @@ See also: [`logprior_true`](@ref). 
""" function logprior_true_with_logabsdet_jacobian end +""" + varnames(model::Model) + +Return a collection of `VarName` as they are expected to appear in the model. + +Even though it is recommended to implement this by hand for a particular `Model`, +a default implementation using [`SimpleVarInfo{<:Dict}`](@ref) is provided. +""" +function varnames(model::Model) + return collect( + keys(last(DynamicPPL.evaluate!!(model, SimpleVarInfo(Dict()), SamplingContext()))) + ) +end + """ example_values(model::Model) -Return a `NamedTuple` compatible with `keys(model)` with values in support of `model`. +Return a `NamedTuple` compatible with `varnames(model)` with values in support of `model`. + +Compatible means that a `varname` from `varnames(model)` can be used to extract the +corresponding value using the call `get(example_values(model), varname)`. """ example_values(model::Model) = example_values(Random.GLOBAL_RNG, model) """ posterior_mean_values(model::Model) -Return a `NamedTuple` compatible with `keys(model)` where the values represent +Return a `NamedTuple` compatible with `varnames(model)` where the values represent the posterior mean under `model`. + +Compatible means that a `varname` from `varnames(model)` can be used to extract the +corresponding value using the call `get(posterior_mean_values(model), varname)`. 
""" function posterior_mean_values end @@ -143,7 +163,7 @@ end function loglikelihood_true(model::Model{typeof(demo_dynamic_constraint)}, m, x) return zero(float(eltype(m))) end -function Base.keys(model::Model{typeof(demo_dynamic_constraint)}) +function varnames(model::Model{typeof(demo_dynamic_constraint)}) return [@varname(m), @varname(x)] end function example_values( @@ -192,7 +212,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe)}) +function varnames(model::Model{typeof(demo_dot_assume_dot_observe)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( @@ -240,7 +260,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_assume_index_observe)}) +function varnames(model::Model{typeof(demo_assume_index_observe)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( @@ -282,7 +302,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_assume_multivariate_observe)}) +function varnames(model::Model{typeof(demo_assume_multivariate_observe)}) return [@varname(s), @varname(m)] end function example_values( @@ -323,7 +343,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_dot_assume_observe_index)}) +function varnames(model::Model{typeof(demo_dot_assume_observe_index)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( @@ -365,7 +385,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) 
end -function Base.keys(model::Model{typeof(demo_assume_dot_observe)}) +function varnames(model::Model{typeof(demo_assume_dot_observe)}) return [@varname(s), @varname(m)] end function example_values( @@ -400,7 +420,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_assume_observe_literal)}) +function varnames(model::Model{typeof(demo_assume_observe_literal)}) return [@varname(s), @varname(m)] end function example_values( @@ -442,7 +462,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_dot_assume_observe_index_literal)}) +function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( @@ -482,7 +502,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_assume_literal_dot_observe)}) +function varnames(model::Model{typeof(demo_assume_literal_dot_observe)}) return [@varname(s), @varname(m)] end function example_values( @@ -529,7 +549,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) +function varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( @@ -581,7 +601,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_dot_assume_observe_submodel)}) +function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)}) return 
[@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( @@ -626,7 +646,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) +function varnames(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end function example_values( @@ -679,7 +699,7 @@ function logprior_true_with_logabsdet_jacobian( ) return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end -function Base.keys(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) +function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m[1]), @varname(m[2])] end function example_values( @@ -748,10 +768,10 @@ function test_sampler_demo_models( @testset "$(typeof(sampler)) on $(nameof(model))" for model in DEMO_MODELS chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) target_values = posterior_mean_values(model) - for vn in keys(model) + for vn in varnames(model) # We want to compare elementwise which can be achieved by # extracting the leaves of the `VarName` and the corresponding value. - for vn_leaf in varnames(vn, get(target_values, vn)) + for vn_leaf in varname_leaves(vn, get(target_values, vn)) target_value = get(target_values, vn_leaf) chain_mean_value = meanfunction(chain, vn_leaf) @test chain_mean_value ≈ target_value atol = atol rtol = rtol diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index ed5919f5a..5e598217a 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -80,7 +80,7 @@ _, svi_new = DynamicPPL.evaluate!!(model, svi, SamplingContext()) # Realization for `m` should be different wp. 1. 
- for vn in keys(model) + for vn in DynamicPPL.TestUtils.varnames(model) @test svi_new[vn] != get(retval, vn) end @@ -100,7 +100,7 @@ # Update the realizations in `svi_new`. svi_eval = svi_new - for vn in keys(model) + for vn in DynamicPPL.TestUtils.varnames(model) svi_eval = DynamicPPL.setindex!!(svi_eval, get(values_eval, vn), vn) end @@ -111,7 +111,7 @@ logπ = logjoint(model, svi_eval) # Values should not have changed. - for vn in keys(model) + for vn in DynamicPPL.TestUtils.varnames(model) @test svi_eval[vn] == get(values_eval, vn) end @@ -141,7 +141,7 @@ ) # Realizations from model should all be equal to the unconstrained realization. - for vn in keys(model) + for vn in DynamicPPL.TestUtils.varnames(model) @test get(retval_unconstrained, vn) ≈ svi[vn] rtol = 1e-6 end From 9669345f2b89e6ac50cf17323e2143fca6f4d2b7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:48:12 +0100 Subject: [PATCH 08/35] added test_sampler_on_models as a generalization of test_sampler_demo_models --- src/test_utils.jl | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 78d279067..75661f24e 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -735,13 +735,12 @@ const DEMO_MODELS = ( demo_dot_assume_matrix_dot_observe_matrix(), ) -# TODO: Is this really the best/most convenient "default" test method? """ - test_sampler_demo_models(meanfunction, sampler, args...; kwargs...) + test_sampler_on_models(meanfunction, models, sampler, args...; kwargs...) -Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. +Test that `sampler` produces correct marginal posterior means on each model in `models`. 
-In short, this method iterators through `demo_models`, calls `AbstractMCMC.sample` on the +In short, this method iterates through `models`, calls `AbstractMCMC.sample` on the `model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain, vn)` for every (leaf) varname `vn` against the corresponding value returned by [`posterior_mean_values`](@ref) for each model. @@ -749,6 +748,7 @@ for every (leaf) varname `vn` against the corresponding value returned by # Arguments - `meanfunction`: A callable which computes the mean of the marginal means from the chain resulting from the `sample` call. +- `models`: A collection of instances of [`DynamicPPL.Model`](@ref) to test on. - `sampler`: The `AbstractMCMC.AbstractSampler` to test. - `args...`: Arguments forwarded to `sample`. @@ -757,15 +757,16 @@ for every (leaf) varname `vn` against the corresponding value returned by - `rtol=1e-3`: Relative tolerance used in `@test`. - `kwargs...`: Keyword arguments forwarded to `sample`. """ -function test_sampler_demo_models( +function test_sampler_on_models( meanfunction, + models, sampler::AbstractMCMC.AbstractSampler, args...; atol=1e-1, rtol=1e-3, kwargs..., ) - @testset "$(typeof(sampler)) on $(nameof(model))" for model in DEMO_MODELS + @testset "$(typeof(sampler)) on $(nameof(model))" for model in models chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) target_values = posterior_mean_values(model) for vn in varnames(model) @@ -780,17 +781,30 @@ function test_sampler_demo_models( end end +""" + test_sampler_on_demo_models(meanfunction, sampler, args...; kwargs...) + +Test `sampler` on every model in [`DEMO_MODELS`](@ref). + +This is just a proxy for `test_sampler_on_models(meanfunction, DEMO_MODELS, sampler, args...; kwargs...)`. +""" +function test_sampler_on_demo_models( + meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs... +) + return test_sampler_on_models(meanfunction, DEMO_MODELS, sampler, args...; kwargs...) 
+end + """ test_sampler_continuous([meanfunction, ]sampler, args...; kwargs...) Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. -As of right now, this is just an alias for [`test_sampler_demo_models`](@ref). +As of right now, this is just an alias for [`test_sampler_on_demo_models`](@ref). """ function test_sampler_continuous( meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs... ) - return test_sampler_demo_models(meanfunction, sampler, args...; kwargs...) + return test_sampler_on_demo_models(meanfunction, sampler, args...; kwargs...) end function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) From 7e027356415d6b3c19b4a0e1cd9cfdb249aa2ae2 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:48:40 +0100 Subject: [PATCH 09/35] updated docs --- docs/src/api.md | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 133b86e9b..debad2944 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -103,10 +103,15 @@ NamedDist DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule. ```@docs -DynamicPPL.TestUtils.test_sampler_demo_models +DynamicPPL.TestUtils.test_sampler_on_models +DynamicPPL.TestUtils.test_sampler_on_demo_models DynamicPPL.TestUtils.test_sampler_continuous ``` +```@docs +DynamicPPL.TestUtils.DEMO_MODELS +``` + For every demo model, one can define the true log prior, log likelihood, and log joint probabilities. 
```@docs @@ -115,6 +120,21 @@ DynamicPPL.TestUtils.loglikelihood_true DynamicPPL.TestUtils.logjoint_true ``` +And in the case where the model might include constrained variables, it can also be useful to define + +```@docs +DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian +DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian +``` + +Finally, the following methods can also be of use: + +```@docs +DynamicPPL.TestUtils.varnames +DynamicPPL.TestUtils.example_values +DynamicPPL.TestUtils.posterior_mean_values +``` + ## Advanced ### Variable names From a412029736905657c543ac56ea03be0695352acc Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:48:44 +0100 Subject: [PATCH 10/35] added docs for TestUtils.DEMO_MODELS --- src/test_utils.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/test_utils.jl b/src/test_utils.jl index 75661f24e..3f6ff5cf1 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -720,6 +720,21 @@ function posterior_mean_values( return vals end +""" +A collection of models corresponding to the posterior distribution defined by +the generative process + + s ~ InverseGamma(2, 3) + m ~ Normal(0, √s) + 1.5 ~ Normal(m, √s) + +_or_ a product of such distributions. + +The posterior for both `s` and `m` here is known in closed form. In particular, + + mean(s) == 19 / 8 + mean(m) == 3 / 4 +""" const DEMO_MODELS = ( demo_dot_assume_dot_observe(), demo_assume_index_observe(), From f3818c30ef79e106187e7e25c16d0dd459203331 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:51:12 +0100 Subject: [PATCH 11/35] updated some tests --- test/contexts.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/contexts.jl b/test/contexts.jl index ef916b18c..edcf5d0f3 100644 --- a/test/contexts.jl +++ b/test/contexts.jl @@ -165,7 +165,8 @@ end vn_without_prefix = remove_prefix(vn) # Let's check elementwise. 
- for vn_child in DynamicPPL.TestUtils.varnames(vn_without_prefix, val) + for vn_child in + DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) if get(val, getlens(vn_child)) === missing @test contextual_isassumption(context, vn_child) else @@ -197,7 +198,8 @@ end # `ConditionContext` with the conditioned variable. vn_without_prefix = remove_prefix(vn) - for vn_child in DynamicPPL.TestUtils.varnames(vn_without_prefix, val) + for vn_child in + DynamicPPL.TestUtils.varname_leaves(vn_without_prefix, val) # `vn_child` should be in `context`. @test hasvalue_nested(context, vn_child) # Value should be the same as extracted above. From 8b799a4f1665a8570d5f52ec262fc40830c0aea7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:53:23 +0100 Subject: [PATCH 12/35] fixed docstrings --- src/test_utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 3f6ff5cf1..0cecd9d61 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -32,9 +32,9 @@ function varname_leaves(vn::VarName, val::AbstractArray) end """ - logprior_true(model, θ) + logprior_true(model, args...) -Return the `logprior` of `model` for `θ`. +Return the `logprior` of `model` for `args...`. This should generally be implemented by hand for every specific `model`. @@ -43,9 +43,9 @@ See also: [`logjoint_true`](@ref), [`loglikelihood_true`](@ref). function logprior_true end """ - loglikelihood_true(model, θ) + loglikelihood_true(model, args...) -Return the `loglikelihood` of `model` for `θ`. +Return the `loglikelihood` of `model` for `args...`. This should generally be implemented by hand for every specific `model`. 
From 93cb298ee1ffd921e3841c7e62a2d9cc093a78ef Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:53:55 +0100 Subject: [PATCH 13/35] fixed docstrings --- src/test_utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 0cecd9d61..b015ce8fa 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -34,7 +34,7 @@ end """ logprior_true(model, args...) -Return the `logprior` of `model` for `args...`. +Return the `logprior` of `model` for `args`. This should generally be implemented by hand for every specific `model`. @@ -45,7 +45,7 @@ function logprior_true end """ loglikelihood_true(model, args...) -Return the `loglikelihood` of `model` for `args...`. +Return the `loglikelihood` of `model` for `args`. This should generally be implemented by hand for every specific `model`. @@ -56,7 +56,7 @@ function loglikelihood_true end """ logjoint_true(model, args...) -Return the `logjoint` of `model` for `args...`. +Return the `logjoint` of `model` for `args`. Defaults to `logprior_true(model, args...) + loglikelihood_true(model, args..)`. @@ -77,7 +77,7 @@ end """ logjoint_true_with_logabsdet_jacobian(model::Model, args...) -Return a tuple `(args_unconstrained, logjoint)` of `model` for `args...`. +Return a tuple `(args_unconstrained, logjoint)` of `model` for `args`. Unlike [`logjoint_true`](@ref), the returned logjoint computation includes the log-absdet-jacobian adjustment, thus computing logjoint for the unconstrained variables. 
From ba5852b85eed5a8a13ffd3c0ef76346927670d63 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:54:55 +0100 Subject: [PATCH 14/35] imprvoed docstring --- src/test_utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index b015ce8fa..fd42e86ab 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -129,7 +129,7 @@ end Return a `NamedTuple` compatible with `varnames(model)` with values in support of `model`. -Compatible means that a `varname` from `varnames(model)` can be used to extract the +\"Compatible\" means that a `varname` from `varnames(model)` can be used to extract the corresponding value using the call `get(example_values(model), varname)`. """ example_values(model::Model) = example_values(Random.GLOBAL_RNG, model) @@ -140,7 +140,7 @@ example_values(model::Model) = example_values(Random.GLOBAL_RNG, model) Return a `NamedTuple` compatible with `varnames(model)` where the values represent the posterior mean under `model`. -Compatible means that a `varname` from `varnames(model)` can be used to extract the +\"Compatible\" means that a `varname` from `varnames(model)` can be used to extract the corresponding value using the call `get(posterior_mean_values(model), varname)`. """ function posterior_mean_values end From 328f7134b6ce2be0f332b6ca38f7173b49f8f9b5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 10:57:03 +0100 Subject: [PATCH 15/35] improved docstrings --- src/test_utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index fd42e86ab..1eca56350 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -129,8 +129,8 @@ end Return a `NamedTuple` compatible with `varnames(model)` with values in support of `model`. -\"Compatible\" means that a `varname` from `varnames(model)` can be used to extract the -corresponding value using the call `get(example_values(model), varname)`. 
+"Compatible" means that a `varname` from `varnames(model)` can be used to extract the +corresponding value using `get`, e.g. `get(example_values(model), varname)`. """ example_values(model::Model) = example_values(Random.GLOBAL_RNG, model) @@ -140,8 +140,8 @@ example_values(model::Model) = example_values(Random.GLOBAL_RNG, model) Return a `NamedTuple` compatible with `varnames(model)` where the values represent the posterior mean under `model`. -\"Compatible\" means that a `varname` from `varnames(model)` can be used to extract the -corresponding value using the call `get(posterior_mean_values(model), varname)`. +"Compatible" means that a `varname` from `varnames(model)` can be used to extract the +corresponding value using `get`, e.g. `get(posterior_mean_values(model), varname)`. """ function posterior_mean_values end From 58436994f423fff5077020766fa708e693f42566 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Fri, 1 Jul 2022 15:22:21 +0100 Subject: [PATCH 16/35] fixed tests of pointwise_loglikelihoods --- src/test_utils.jl | 2 +- test/loglikelihoods.jl | 22 ++++++++-------------- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 1eca56350..7ce3cf3b8 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -700,7 +700,7 @@ function logprior_true_with_logabsdet_jacobian( return _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) - return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m[1]), @varname(m[2])] + return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m)] end function example_values( rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index eaf1e00bd..bd04a76a5 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -1,15 +1,14 @@ @testset "loglikelihoods.jl" begin @testset "$(m.f)" for m 
in DynamicPPL.TestUtils.DEMO_MODELS - vi = VarInfo(m) + example_values = DynamicPPL.TestUtils.example_values(m) + # Instantiate a `VarInfo` with the example values. + vi = VarInfo(m) for vn in DynamicPPL.TestUtils.varnames(m) - if vi[vn] isa Real - vi = DynamicPPL.setindex!!(vi, 1.0, vn) - else - vi = DynamicPPL.setindex!!(vi, ones(size(vi[vn])), vn) - end + vi = DynamicPPL.setindex!!(vi, get(example_values, vn), vn) end + # Compute the pointwise loglikelihoods. lls = pointwise_loglikelihoods(m, vi) if isempty(lls) @@ -17,14 +16,9 @@ continue end - loglikelihood = if length(keys(lls)) == 1 && length(m.args.x) == 1 - # Only have one observation, so we need to double it - # for comparison with other models. - 2 * sum(lls[first(keys(lls))]) - else - sum(sum, values(lls)) - end + loglikelihood = sum(sum, values(lls)) + loglikelihood_true = DynamicPPL.TestUtils.loglikelihood_true(m, example_values...) - @test loglikelihood ≈ -324.45158270528947 + @test loglikelihood ≈ loglikelihood_true end end From 912d7f847fd4642837a55d1cabe4428e79be1146 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 2 Jul 2022 12:03:45 +0100 Subject: [PATCH 17/35] Apply suggestions from code review Co-authored-by: David Widmann --- src/test_utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 7ce3cf3b8..ad28585d7 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -185,7 +185,7 @@ end function _demo_logprior_true_with_logabsdet_jacobian(model, s, m) b = Bijectors.bijector(InverseGamma(2, 3)) s_unconstrained = b.(s) - Δlogp = sum(Base.Fix1(Bijectors.logabsdetjac, b).(s)) + Δlogp = sum(Base.Fix1(Bijectors.logabsdetjac, b), s) return (s=s_unconstrained, m=m), logprior_true(model, s, m) - Δlogp end From a276e4a662795b9726e91823a52c8f5b20188965 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 2 Jul 2022 12:05:56 +0100 Subject: [PATCH 18/35] renamed posterior_mean_values to posterior_mean --- docs/src/api.md | 2 
+- src/test_utils.jl | 34 +++++++++++++++++----------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index debad2944..ab7e7fe60 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -132,7 +132,7 @@ Finally, the following methods can also be of use: ```@docs DynamicPPL.TestUtils.varnames DynamicPPL.TestUtils.example_values -DynamicPPL.TestUtils.posterior_mean_values +DynamicPPL.TestUtils.posterior_mean ``` ## Advanced diff --git a/src/test_utils.jl b/src/test_utils.jl index ad28585d7..bbc974677 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -135,15 +135,15 @@ corresponding value using `get`, e.g. `get(example_values(model), varname)`. example_values(model::Model) = example_values(Random.GLOBAL_RNG, model) """ - posterior_mean_values(model::Model) + posterior_mean(model::Model) Return a `NamedTuple` compatible with `varnames(model)` where the values represent the posterior mean under `model`. "Compatible" means that a `varname` from `varnames(model)` can be used to extract the -corresponding value using `get`, e.g. `get(posterior_mean_values(model), varname)`. +corresponding value using `get`, e.g. `get(posterior_mean(model), varname)`. 
""" -function posterior_mean_values end +function posterior_mean end """ demo_dynamic_constraint() @@ -226,7 +226,7 @@ function example_values( end return (s=s, m=m) end -function posterior_mean_values(model::Model{typeof(demo_dot_assume_dot_observe)}) +function posterior_mean(model::Model{typeof(demo_dot_assume_dot_observe)}) vals = example_values(model) vals.s .= 2.375 vals.m .= 0.75 @@ -274,7 +274,7 @@ function example_values( end return (s=s, m=m) end -function posterior_mean_values(model::Model{typeof(demo_assume_index_observe)}) +function posterior_mean(model::Model{typeof(demo_assume_index_observe)}) vals = example_values(model) vals.s .= 2.375 vals.m .= 0.75 @@ -311,7 +311,7 @@ function example_values( s = rand(rng, product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])) return (s=s, m=rand(rng, MvNormal(zero(model.args.x), Diagonal(s)))) end -function posterior_mean_values(model::Model{typeof(demo_assume_multivariate_observe)}) +function posterior_mean(model::Model{typeof(demo_assume_multivariate_observe)}) vals = example_values(model) vals.s .= 2.375 vals.m .= 0.75 @@ -357,7 +357,7 @@ function example_values( end return (s=s, m=m) end -function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_index)}) +function posterior_mean(model::Model{typeof(demo_dot_assume_observe_index)}) vals = example_values(model) vals.s .= 2.375 vals.m .= 0.75 @@ -395,7 +395,7 @@ function example_values( m = rand(rng, Normal(0, sqrt(s))) return (s=s, m=m) end -function posterior_mean_values(model::Model{typeof(demo_assume_dot_observe)}) +function posterior_mean(model::Model{typeof(demo_assume_dot_observe)}) return (s=2.375, m=0.75) end @@ -429,7 +429,7 @@ function example_values( s = rand(rng, product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])) return (s=s, m=rand(rng, MvNormal(zeros(2), Diagonal(s)))) end -function posterior_mean_values(model::Model{typeof(demo_assume_observe_literal)}) +function 
posterior_mean(model::Model{typeof(demo_assume_observe_literal)}) vals = example_values(model) vals.s .= 2.375 vals.m .= 0.75 @@ -476,7 +476,7 @@ function example_values( end return (s=s, m=m) end -function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_index_literal)}) +function posterior_mean(model::Model{typeof(demo_dot_assume_observe_index_literal)}) vals = example_values(model) vals.s .= 2.375 vals.m .= 0.75 @@ -512,7 +512,7 @@ function example_values( m = rand(rng, Normal(0, sqrt(s))) return (s=s, m=m) end -function posterior_mean_values(model::Model{typeof(demo_assume_literal_dot_observe)}) +function posterior_mean(model::Model{typeof(demo_assume_literal_dot_observe)}) return (s=2.375, m=0.75) end @@ -564,7 +564,7 @@ function example_values( end return (s=s, m=m) end -function posterior_mean_values( +function posterior_mean( model::Model{typeof(demo_assume_submodel_observe_index_literal)} ) vals = example_values(model) @@ -615,7 +615,7 @@ function example_values( end return (s=s, m=m) end -function posterior_mean_values(model::Model{typeof(demo_dot_assume_observe_submodel)}) +function posterior_mean(model::Model{typeof(demo_dot_assume_observe_submodel)}) vals = example_values(model) vals.s .= 2.375 vals.m .= 0.75 @@ -660,7 +660,7 @@ function example_values( end return (s=s, m=m) end -function posterior_mean_values(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) +function posterior_mean(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) vals = example_values(model) vals.s .= 2.375 vals.m .= 0.75 @@ -711,7 +711,7 @@ function example_values( m = rand(rng, MvNormal(zeros(n), Diagonal(vec(s)))) return (s=s, m=m) end -function posterior_mean_values( +function posterior_mean( model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} ) vals = example_values(model) @@ -758,7 +758,7 @@ Test that `sampler` produces correct marginal posterior means on each model in ` In short, this method iterates through `models`, calls 
`AbstractMCMC.sample` on the `model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain, vn)` for every (leaf) varname `vn` against the corresponding value returned by -[`posterior_mean_values`](@ref) for each model. +[`posterior_mean`](@ref) for each model. # Arguments - `meanfunction`: A callable which computes the mean of the marginal means from the @@ -783,7 +783,7 @@ function test_sampler_on_models( ) @testset "$(typeof(sampler)) on $(nameof(model))" for model in models chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) - target_values = posterior_mean_values(model) + target_values = posterior_mean(model) for vn in varnames(model) # We want to compare elementwise which can be achieved by # extracting the leaves of the `VarName` and the corresponding value. From 626eea212d8e500c0af51fb63a4b216c51a2a436 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 2 Jul 2022 12:55:10 +0100 Subject: [PATCH 19/35] made demo models a bit more complex, now including different observations --- src/test_utils.jl | 175 +++++++++++++++++++++++++++++++--------------- 1 file changed, 119 insertions(+), 56 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index bbc974677..f88414a7d 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -190,7 +190,7 @@ function _demo_logprior_true_with_logabsdet_jacobian(model, s, m) end @model function demo_dot_assume_dot_observe( - x=[1.5, 1.5], ::Type{TV}=Vector{Float64} + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} ) where {TV} # `dot_assume` and `observe` s = TV(undef, length(x)) @@ -228,13 +228,18 @@ function example_values( end function posterior_mean(model::Model{typeof(demo_dot_assume_dot_observe)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end @model function demo_assume_index_observe( - x=[1.5, 1.5], ::Type{TV}=Vector{Float64} + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} 
) where {TV} # `assume` with indexing and `observe` s = TV(undef, length(x)) @@ -276,12 +281,17 @@ function example_values( end function posterior_mean(model::Model{typeof(demo_assume_index_observe)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end -@model function demo_assume_multivariate_observe(x=[1.5, 1.5]) +@model function demo_assume_multivariate_observe(x=[1.5, 2.0]) # Multivariate `assume` and `observe` s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) m ~ MvNormal(zero(x), Diagonal(s)) @@ -313,13 +323,18 @@ function example_values( end function posterior_mean(model::Model{typeof(demo_assume_multivariate_observe)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end @model function demo_dot_assume_observe_index( - x=[1.5, 1.5], ::Type{TV}=Vector{Float64} + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} ) where {TV} # `dot_assume` and `observe` with indexing s = TV(undef, length(x)) @@ -359,14 +374,19 @@ function example_values( end function posterior_mean(model::Model{typeof(demo_dot_assume_observe_index)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. 
-@model function demo_assume_dot_observe(x=[1.5]) +@model function demo_assume_dot_observe(x=[1.5, 2.0]) # `assume` and `dot_observe` s ~ InverseGamma(2, 3) m ~ Normal(0, sqrt(s)) @@ -396,16 +416,16 @@ function example_values( return (s=s, m=m) end function posterior_mean(model::Model{typeof(demo_assume_dot_observe)}) - return (s=2.375, m=0.75) + return (s=49 / 24, m=7 / 6) end @model function demo_assume_observe_literal() # `assume` and literal `observe` s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) m ~ MvNormal(zeros(2), Diagonal(s)) - [1.5, 1.5] ~ MvNormal(m, Diagonal(s)) + [1.5, 2.0] ~ MvNormal(m, Diagonal(s)) - return (; s=s, m=m, x=[1.5, 1.5], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) end function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)]) @@ -413,7 +433,7 @@ function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) return logpdf(s_dist, s) + logpdf(m_dist, m) end function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m) - return logpdf(MvNormal(m, Diagonal(s)), [1.5, 1.5]) + return logpdf(MvNormal(m, Diagonal(s)), [1.5, 2.0]) end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_assume_observe_literal)}, s, m @@ -431,8 +451,13 @@ function example_values( end function posterior_mean(model::Model{typeof(demo_assume_observe_literal)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end @@ -443,11 +468,10 @@ end s .~ InverseGamma(2, 3) m .~ Normal.(0, sqrt.(s)) - for i in eachindex(m) - 1.5 ~ Normal(m[i], sqrt(s[i])) - end + 1.5 ~ Normal(m[1], sqrt(s[1])) + 2.0 ~ Normal(m[2], sqrt(s[2])) - return (; s=s, m=m, x=fill(1.5, length(m)), logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0], 
logp=getlogp(__varinfo__)) end function logprior_true(model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m) return loglikelihood(InverseGamma(2, 3), s) + sum(logpdf.(Normal.(0, sqrt.(s)), m)) @@ -455,7 +479,7 @@ end function loglikelihood_true( model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m ) - return sum(logpdf.(Normal.(m, sqrt.(s)), fill(1.5, length(m)))) + return sum(logpdf.(Normal.(m, sqrt.(s)), [1.5, 2.0])) end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_dot_assume_observe_index_literal)}, s, m @@ -478,8 +502,13 @@ function example_values( end function posterior_mean(model::Model{typeof(demo_dot_assume_observe_index_literal)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end @@ -487,15 +516,15 @@ end # `assume` and literal `dot_observe` s ~ InverseGamma(2, 3) m ~ Normal(0, sqrt(s)) - [1.5] .~ Normal(m, sqrt(s)) + [1.5, 2.0] .~ Normal(m, sqrt(s)) - return (; s=s, m=m, x=[1.5], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) end function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) end function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) - return logpdf(Normal(m, sqrt(s)), 1.5) + return logpdf(Normal(m, sqrt(s)), [1.5, 2.0]) end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_assume_literal_dot_observe)}, s, m @@ -513,7 +542,7 @@ function example_values( return (s=s, m=m) end function posterior_mean(model::Model{typeof(demo_assume_literal_dot_observe)}) - return (s=2.375, m=0.75) + return (s=49 / 24, m=7 / 6) end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} @@ -528,11 +557,10 @@ end @model function demo_assume_submodel_observe_index_literal() # Submodel prior 
@submodel s, m = _prior_dot_assume() - for i in eachindex(m, s) - 1.5 ~ Normal(m[i], sqrt(s[i])) - end + 1.5 ~ Normal(m[1], sqrt(s[1])) + 2.0 ~ Normal(m[2], sqrt(s[2])) - return (; s=s, m=m, x=[1.5, 1.5], logp=getlogp(__varinfo__)) + return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__)) end function logprior_true( model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m @@ -542,7 +570,7 @@ end function loglikelihood_true( model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m ) - return sum(logpdf.(Normal.(m, sqrt.(s)), 1.5)) + return sum(logpdf.(Normal.(m, sqrt.(s)), [1.5, 2.0])) end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_assume_submodel_observe_index_literal)}, s, m @@ -564,12 +592,15 @@ function example_values( end return (s=s, m=m) end -function posterior_mean( - model::Model{typeof(demo_assume_submodel_observe_index_literal)} -) +function posterior_mean(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end @@ -578,7 +609,7 @@ end end @model function demo_dot_assume_observe_submodel( - x=[1.5, 1.5], ::Type{TV}=Vector{Float64} + x=[1.5, 2.0], ::Type{TV}=Vector{Float64} ) where {TV} s = TV(undef, length(x)) s .~ InverseGamma(2, 3) @@ -617,13 +648,18 @@ function example_values( end function posterior_mean(model::Model{typeof(demo_dot_assume_observe_submodel)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end @model function demo_dot_assume_dot_observe_matrix( - x=fill(1.5, 2, 1), ::Type{TV}=Vector{Float64} + x=transpose([1.5 2.0;]), ::Type{TV}=Vector{Float64} ) where {TV} s = TV(undef, length(x)) s .~ InverseGamma(2, 3) @@ -662,13 +698,18 @@ function example_values( end function 
posterior_mean(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end @model function demo_dot_assume_matrix_dot_observe_matrix( - x=fill(1.5, 2, 1), ::Type{TV}=Array{Float64} + x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} ) where {TV} n = length(x) d = length(x) ÷ 2 @@ -711,12 +752,15 @@ function example_values( m = rand(rng, MvNormal(zeros(n), Diagonal(vec(s)))) return (s=s, m=m) end -function posterior_mean( - model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} -) +function posterior_mean(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) vals = example_values(model) - vals.s .= 2.375 - vals.m .= 0.75 + + vals.s[1] = 19 / 8 + vals.m[1] = 3 / 4 + + vals.s[2] = 8 / 3 + vals.m[2] = 1 + return vals end @@ -727,13 +771,32 @@ the generative process s ~ InverseGamma(2, 3) m ~ Normal(0, √s) 1.5 ~ Normal(m, √s) + 2.0 ~ Normal(m, √s) + +or by + + s[1] ~ InverseGamma(2, 3) + s[2] ~ InverseGamma(2, 3) + m[1] ~ Normal(0, √s[1]) + m[2] ~ Normal(0, √s[2]) + 1.5 ~ Normal(m[1], √s[1]) + 2.0 ~ Normal(m[2], √s[2]) + +These are examples of a Normal-InverseGamma conjugate prior with Normal likelihood, +for which the posterior is known in closed form. + +In particular, for the univariate model (the former one): + + mean(s) == 49 / 24 + mean(m) == 7 / 6 -_or_ a product of such distributions. +And for the multivariate one (the latter one): -The posterior for both `s` and `m` here is known in closed form. 
In particular, + mean(s[1]) == 19 / 8 + mean(m[1]) == 3 / 4 + mean(s[2]) == 8 / 3 + mean(m[2]) == 1 - mean(s) == 19 / 8 - mean(m) == 3 / 4 """ const DEMO_MODELS = ( demo_dot_assume_dot_observe(), From 15589247ed2d26ff3144b9a4ddb7102f6e5fc047 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 2 Jul 2022 18:11:34 +0100 Subject: [PATCH 20/35] Update docs/src/api.md Co-authored-by: David Widmann --- docs/src/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index ab7e7fe60..9f4dae3f5 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -120,7 +120,7 @@ DynamicPPL.TestUtils.loglikelihood_true DynamicPPL.TestUtils.logjoint_true ``` -And in the case where the model might include constrained variables, it can also be useful to define +And in the case where the model includes constrained variables, it can also be useful to define ```@docs DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian From a62c881ef3de794f47e8084d9b0b2112a9b5066b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 2 Jul 2022 19:43:20 +0100 Subject: [PATCH 21/35] reduce number of method definitions by defining some useful type unions in TestUtils --- src/test_utils.jl | 145 +++++++++++++--------------------------------- 1 file changed, 39 insertions(+), 106 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index f88414a7d..ebab44d91 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -226,17 +226,6 @@ function example_values( end return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_dot_assume_dot_observe)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end @model function demo_assume_index_observe( x=[1.5, 2.0], ::Type{TV}=Vector{Float64} @@ -279,17 +268,6 @@ function example_values( end return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_assume_index_observe)}) - vals = 
example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end @model function demo_assume_multivariate_observe(x=[1.5, 2.0]) # Multivariate `assume` and `observe` @@ -321,17 +299,6 @@ function example_values( s = rand(rng, product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])) return (s=s, m=rand(rng, MvNormal(zero(model.args.x), Diagonal(s)))) end -function posterior_mean(model::Model{typeof(demo_assume_multivariate_observe)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end @model function demo_dot_assume_observe_index( x=[1.5, 2.0], ::Type{TV}=Vector{Float64} @@ -372,17 +339,6 @@ function example_values( end return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_dot_assume_observe_index)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. 
@@ -415,9 +371,6 @@ function example_values( m = rand(rng, Normal(0, sqrt(s))) return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_assume_dot_observe)}) - return (s=49 / 24, m=7 / 6) -end @model function demo_assume_observe_literal() # `assume` and literal `observe` @@ -449,17 +402,6 @@ function example_values( s = rand(rng, product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])) return (s=s, m=rand(rng, MvNormal(zeros(2), Diagonal(s)))) end -function posterior_mean(model::Model{typeof(demo_assume_observe_literal)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end @model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing @@ -500,17 +442,6 @@ function example_values( end return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_dot_assume_observe_index_literal)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end @model function demo_assume_literal_dot_observe() # `assume` and literal `dot_observe` @@ -541,9 +472,6 @@ function example_values( m = rand(rng, Normal(0, sqrt(s))) return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_assume_literal_dot_observe)}) - return (s=49 / 24, m=7 / 6) -end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} s = TV(undef, 2) @@ -592,17 +520,6 @@ function example_values( end return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end @model function _likelihood_mltivariate_observe(s, m, x) return x ~ MvNormal(m, Diagonal(s)) @@ -646,17 +563,6 @@ function example_values( end return (s=s, m=m) end -function 
posterior_mean(model::Model{typeof(demo_dot_assume_observe_submodel)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end @model function demo_dot_assume_dot_observe_matrix( x=transpose([1.5 2.0;]), ::Type{TV}=Vector{Float64} @@ -696,17 +602,6 @@ function example_values( end return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) - vals = example_values(model) - - vals.s[1] = 19 / 8 - vals.m[1] = 3 / 4 - - vals.s[2] = 8 / 3 - vals.m[2] = 1 - - return vals -end @model function demo_dot_assume_matrix_dot_observe_matrix( x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} @@ -752,7 +647,45 @@ function example_values( m = rand(rng, MvNormal(zeros(n), Diagonal(vec(s)))) return (s=s, m=m) end -function posterior_mean(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) + +const DemoModels = Union{ + Model{typeof(demo_dot_assume_dot_observe)}, + Model{typeof(demo_assume_index_observe)}, + Model{typeof(demo_assume_multivariate_observe)}, + Model{typeof(demo_dot_assume_observe_index)}, + Model{typeof(demo_assume_dot_observe)}, + Model{typeof(demo_assume_literal_dot_observe)}, + Model{typeof(demo_assume_observe_literal)}, + Model{typeof(demo_dot_assume_observe_index_literal)}, + Model{typeof(demo_assume_submodel_observe_index_literal)}, + Model{typeof(demo_dot_assume_observe_submodel)}, + Model{typeof(demo_dot_assume_dot_observe_matrix)}, + Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, +} +_observations(model::DemoModels) = [1.5, 2.0] + +const UnivariateAssumeDemoModels = Union{ + Model{typeof(demo_assume_dot_observe)}, + Model{typeof(demo_assume_literal_dot_observe)}, +} +function posterior_mean(model::UnivariateAssumeDemoModels) + return (s=49 / 24, m=7 / 6) +end + +const MultivariateAssumeDemoModels = Union{ + Model{typeof(demo_dot_assume_dot_observe)}, + Model{typeof(demo_assume_index_observe)}, + 
Model{typeof(demo_assume_multivariate_observe)}, + Model{typeof(demo_dot_assume_observe_index)}, + Model{typeof(demo_assume_observe_literal)}, + Model{typeof(demo_dot_assume_observe_index_literal)}, + Model{typeof(demo_assume_submodel_observe_index_literal)}, + Model{typeof(demo_dot_assume_observe_submodel)}, + Model{typeof(demo_dot_assume_dot_observe_matrix)}, + Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, +} +function posterior_mean(model::MultivariateAssumeDemoModels) + # Get some containers to fill. vals = example_values(model) vals.s[1] = 19 / 8 From 5cc195aea99e8a0df65abc9de9bf92494e0ed4f9 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 2 Jul 2022 22:42:27 +0100 Subject: [PATCH 22/35] removed unnecessary method --- src/test_utils.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index ebab44d91..8515fe4e6 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -662,7 +662,6 @@ const DemoModels = Union{ Model{typeof(demo_dot_assume_dot_observe_matrix)}, Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, } -_observations(model::DemoModels) = [1.5, 2.0] const UnivariateAssumeDemoModels = Union{ Model{typeof(demo_assume_dot_observe)}, From 702f2ff942d109e170eeb68355dddfa7dc6e5043 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 00:45:21 +0100 Subject: [PATCH 23/35] fixed a couple of loglikelihood_true definitions --- src/test_utils.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 8515fe4e6..7ece1e56e 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -455,7 +455,7 @@ function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m) end function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m) - return logpdf(Normal(m, sqrt(s)), [1.5, 2.0]) + return loglikelihood(Normal(m, sqrt(s)), 
[1.5, 2.0]) end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_assume_literal_dot_observe)}, s, m @@ -623,7 +623,8 @@ function logprior_true( ) n = length(model.args.x) s_vec = vec(s) - return loglikelihood(InverseGamma(2, 3), s_vec) + logpdf(MvNormal(zeros(n), s_vec), m) + return loglikelihood(InverseGamma(2, 3), s_vec) + + logpdf(MvNormal(zeros(n), Diagonal(s_vec)), m) end function loglikelihood_true( model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, s, m From d8f497019113c4d35793a1f394b5a1b89ae3e125 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 00:45:38 +0100 Subject: [PATCH 24/35] style --- src/test_utils.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 7ece1e56e..496a40e50 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -665,8 +665,7 @@ const DemoModels = Union{ } const UnivariateAssumeDemoModels = Union{ - Model{typeof(demo_assume_dot_observe)}, - Model{typeof(demo_assume_literal_dot_observe)}, + Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)} } function posterior_mean(model::UnivariateAssumeDemoModels) return (s=49 / 24, m=7 / 6) From 56f30bc45b1a326f3edc715beb1e62647c648b73 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 00:46:08 +0100 Subject: [PATCH 25/35] added tests for logprior and loglikelihood computation for SimpleVarInfo --- test/simple_varinfo.jl | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 5e598217a..955c3676b 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -90,14 +90,27 @@ ### Evaluation ### values_eval_constrained = DynamicPPL.TestUtils.example_values(model) if DynamicPPL.istrans(svi) + _, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( + model, values_eval_constrained... 
+ ) values_eval, logπ_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( model, values_eval_constrained... ) else + logpri_true = DynamicPPL.TestUtils.logprior_true( + model, values_eval_constrained... + ) + logπ_true = DynamicPPL.TestUtils.logjoint_true( + model, values_eval_constrained... + ) values_eval = values_eval_constrained - logπ_true = DynamicPPL.TestUtils.logjoint_true(model, values_eval...) end + # No logabsdet-jacobian correction needed for the likelihood. + loglik_true = DynamicPPL.TestUtils.loglikelihood_true( + model, values_eval_constrained... + ) + # Update the realizations in `svi_new`. svi_eval = svi_new for vn in DynamicPPL.TestUtils.varnames(model) @@ -109,13 +122,19 @@ # Compute `logjoint` using the varinfo. logπ = logjoint(model, svi_eval) + logpri = logprior(model, svi_eval) + loglik = loglikelihood(model, svi_eval) + + retval_svi, _ = DynamicPPL.evaluate!!(model, svi, LikelihoodContext()) # Values should not have changed. for vn in DynamicPPL.TestUtils.varnames(model) @test svi_eval[vn] == get(values_eval, vn) end - # Compare `logjoint` computations. + # Compare log-probability computations. 
+ @test logpri ≈ logpri_true + @test loglik ≈ loglik_true @test logπ ≈ logπ_true end end From 2eaef02d9e47fa3cd9de9b34689ab51652bf7dd4 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 00:46:52 +0100 Subject: [PATCH 26/35] fixed implementation of logpdf_with_trans for NoDist --- src/distribution_wrappers.jl | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/src/distribution_wrappers.jl b/src/distribution_wrappers.jl index 65479f035..d8968a68e 100644 --- a/src/distribution_wrappers.jl +++ b/src/distribution_wrappers.jl @@ -51,9 +51,21 @@ Distributions.logpdf(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0 Distributions.minimum(d::NoDist) = minimum(d.dist) Distributions.maximum(d::NoDist) = maximum(d.dist) -Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real) = 0 -Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}) = 0 -function Bijectors.logpdf_with_trans(d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}) +Bijectors.logpdf_with_trans(d::NoDist{<:Univariate}, ::Real, ::Bool) = 0 +function Bijectors.logpdf_with_trans( + d::NoDist{<:Multivariate}, ::AbstractVector{<:Real}, ::Bool +) + return 0 +end +function Bijectors.logpdf_with_trans( + d::NoDist{<:Multivariate}, x::AbstractMatrix{<:Real}, ::Bool +) return zeros(Int, size(x, 2)) end -Bijectors.logpdf_with_trans(d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}) = 0 +function Bijectors.logpdf_with_trans( + d::NoDist{<:Matrixvariate}, ::AbstractMatrix{<:Real}, ::Bool +) + return 0 +end + +Bijectors.bijector(d::NoDist) = Bijectors.bijector(d.dist) From 78f22e177bcc071b42618eaf8696cc37daf5b2bb Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 01:01:24 +0100 Subject: [PATCH 27/35] removed unused variable --- test/simple_varinfo.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 955c3676b..7163d3106 100644 --- a/test/simple_varinfo.jl 
+++ b/test/simple_varinfo.jl @@ -125,8 +125,6 @@ logpri = logprior(model, svi_eval) loglik = loglikelihood(model, svi_eval) - retval_svi, _ = DynamicPPL.evaluate!!(model, svi, LikelihoodContext()) - # Values should not have changed. for vn in DynamicPPL.TestUtils.varnames(model) @test svi_eval[vn] == get(values_eval, vn) From 025a4d47e79ca4bddad6d11547407971190e4926 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 01:03:39 +0100 Subject: [PATCH 28/35] added test for transformed values for the logprior_true and loglikelihood_true methods --- test/simple_varinfo.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 7163d3106..175f264d4 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -90,12 +90,15 @@ ### Evaluation ### values_eval_constrained = DynamicPPL.TestUtils.example_values(model) if DynamicPPL.istrans(svi) - _, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( + _values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( model, values_eval_constrained... ) values_eval, logπ_true = DynamicPPL.TestUtils.logjoint_true_with_logabsdet_jacobian( model, values_eval_constrained... ) + # Make sure that these two computation paths provide the same + # transformed values. + @test values_eval == _values_prior else logpri_true = DynamicPPL.TestUtils.logprior_true( model, values_eval_constrained... 
From f5c60aec1eb7ccad4aa7ec14b83f6c350b246140 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 01:17:21 +0100 Subject: [PATCH 29/35] renamed test_sampler_on_models to test_sampler --- src/test_utils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 496a40e50..4e728884c 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -746,7 +746,7 @@ const DEMO_MODELS = ( ) """ - test_sampler_on_models(meanfunction, models, sampler, args...; kwargs...) + test_sampler(meanfunction, models, sampler, args...; kwargs...) Test that `sampler` produces correct marginal posterior means on each model in `models`. @@ -767,7 +767,7 @@ for every (leaf) varname `vn` against the corresponding value returned by - `rtol=1e-3`: Relative tolerance used in `@test`. - `kwargs...`: Keyword arguments forwarded to `sample`. """ -function test_sampler_on_models( +function test_sampler( meanfunction, models, sampler::AbstractMCMC.AbstractSampler, @@ -796,12 +796,12 @@ end Test `sampler` on every model in [`DEMO_MODELS`](@ref). -This is just a proxy for `test_sampler_on_models(meanfunction, DEMO_MODELS, sampler, args...; kwargs...)`. +This is just a proxy for `test_sampler(meanfunction, DEMO_MODELS, sampler, args...; kwargs...)`. """ function test_sampler_on_demo_models( meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs... ) - return test_sampler_on_models(meanfunction, DEMO_MODELS, sampler, args...; kwargs...) + return test_sampler(meanfunction, DEMO_MODELS, sampler, args...; kwargs...) 
end """ From 25f05de007337c0b56799272a96948024909f7c6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 01:37:44 +0100 Subject: [PATCH 30/35] updated docs --- docs/src/api.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/api.md b/docs/src/api.md index 9f4dae3f5..9aa481cc4 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -103,7 +103,7 @@ NamedDist DynamicPPL provides several demo models and helpers for testing samplers in the `DynamicPPL.TestUtils` submodule. ```@docs -DynamicPPL.TestUtils.test_sampler_on_models +DynamicPPL.TestUtils.test_sampler DynamicPPL.TestUtils.test_sampler_on_demo_models DynamicPPL.TestUtils.test_sampler_continuous ``` From e05fa291e0b3dd835e887df1c44c95720a2995be Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sun, 3 Jul 2022 12:13:24 +0100 Subject: [PATCH 31/35] share implementation of example_values --- src/test_utils.jl | 139 +++++++--------------------------------------- 1 file changed, 20 insertions(+), 119 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 4e728884c..7b3758f40 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -166,12 +166,6 @@ end function varnames(model::Model{typeof(demo_dynamic_constraint)}) return [@varname(m), @varname(x)] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_dynamic_constraint)} -) - m = rand(rng, Normal()) - return (m=m, x=rand(rng, truncated(Normal(), m, Inf))) -end function logprior_true_with_logabsdet_jacobian( model::Model{typeof(demo_dynamic_constraint)}, m, x ) @@ -215,17 +209,6 @@ end function varnames(model::Model{typeof(demo_dot_assume_dot_observe)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_dot_observe)} -) - n = length(model.args.x) - s = rand(rng, InverseGamma(2, 3), n) - m = similar(s) - for i in eachindex(m, s) - m[i] = rand(rng, Normal(0, 
sqrt(s[i]))) - end - return (s=s, m=m) -end @model function demo_assume_index_observe( x=[1.5, 2.0], ::Type{TV}=Vector{Float64} @@ -257,17 +240,6 @@ end function varnames(model::Model{typeof(demo_assume_index_observe)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_assume_index_observe)} -) - n = length(model.args.x) - s = rand(rng, InverseGamma(2, 3), n) - m = similar(s) - for i in eachindex(m, s) - m[i] = rand(rng, Normal(0, sqrt(s[i]))) - end - return (s=s, m=m) -end @model function demo_assume_multivariate_observe(x=[1.5, 2.0]) # Multivariate `assume` and `observe` @@ -293,12 +265,6 @@ end function varnames(model::Model{typeof(demo_assume_multivariate_observe)}) return [@varname(s), @varname(m)] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_assume_multivariate_observe)} -) - s = rand(rng, product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])) - return (s=s, m=rand(rng, MvNormal(zero(model.args.x), Diagonal(s)))) -end @model function demo_dot_assume_observe_index( x=[1.5, 2.0], ::Type{TV}=Vector{Float64} @@ -328,17 +294,6 @@ end function varnames(model::Model{typeof(demo_dot_assume_observe_index)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_index)} -) - n = length(model.args.x) - s = rand(rng, InverseGamma(2, 3), n) - m = similar(s) - for i in eachindex(m, s) - m[i] = rand(rng, Normal(0, sqrt(s[i]))) - end - return (s=s, m=m) -end # Using vector of `length` 1 here so the posterior of `m` is the same # as the others. 
@@ -364,13 +319,6 @@ end function varnames(model::Model{typeof(demo_assume_dot_observe)}) return [@varname(s), @varname(m)] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_assume_dot_observe)} -) - s = rand(rng, InverseGamma(2, 3)) - m = rand(rng, Normal(0, sqrt(s))) - return (s=s, m=m) -end @model function demo_assume_observe_literal() # `assume` and literal `observe` @@ -396,12 +344,6 @@ end function varnames(model::Model{typeof(demo_assume_observe_literal)}) return [@varname(s), @varname(m)] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_assume_observe_literal)} -) - s = rand(rng, product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])) - return (s=s, m=rand(rng, MvNormal(zeros(2), Diagonal(s)))) -end @model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing @@ -431,17 +373,6 @@ end function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_index_literal)} -) - n = 2 - s = rand(rng, InverseGamma(2, 3), n) - m = similar(s) - for i in eachindex(m, s) - m[i] = rand(rng, Normal(0, sqrt(s[i]))) - end - return (s=s, m=m) -end @model function demo_assume_literal_dot_observe() # `assume` and literal `dot_observe` @@ -465,13 +396,6 @@ end function varnames(model::Model{typeof(demo_assume_literal_dot_observe)}) return [@varname(s), @varname(m)] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_assume_literal_dot_observe)} -) - s = rand(rng, InverseGamma(2, 3)) - m = rand(rng, Normal(0, sqrt(s))) - return (s=s, m=m) -end @model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV} s = TV(undef, 2) @@ -508,18 +432,6 @@ end function 
varnames(model::Model{typeof(demo_assume_submodel_observe_index_literal)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -function example_values( - rng::Random.AbstractRNG, - model::Model{typeof(demo_assume_submodel_observe_index_literal)}, -) - n = 2 - s = rand(rng, InverseGamma(2, 3), n) - m = similar(s) - for i in eachindex(m, s) - m[i] = rand(rng, Normal(0, sqrt(s[i]))) - end - return (s=s, m=m) -end @model function _likelihood_mltivariate_observe(s, m, x) return x ~ MvNormal(m, Diagonal(s)) @@ -552,17 +464,6 @@ end function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_observe_submodel)} -) - n = length(model.args.x) - s = rand(rng, InverseGamma(2, 3), n) - m = similar(s) - for i in eachindex(m, s) - m[i] = rand(rng, Normal(0, sqrt(s[i]))) - end - return (s=s, m=m) -end @model function demo_dot_assume_dot_observe_matrix( x=transpose([1.5 2.0;]), ::Type{TV}=Vector{Float64} @@ -591,17 +492,6 @@ end function varnames(model::Model{typeof(demo_dot_assume_dot_observe_matrix)}) return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_dot_observe_matrix)} -) - n = length(model.args.x) - s = rand(rng, InverseGamma(2, 3), n) - m = similar(s) - for i in eachindex(m, s) - m[i] = rand(rng, Normal(0, sqrt(s[i]))) - end - return (s=s, m=m) -end @model function demo_dot_assume_matrix_dot_observe_matrix( x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64} @@ -639,15 +529,6 @@ end function varnames(model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}) return [@varname(s[:, 1]), @varname(s[:, 2]), @varname(m)] end -function example_values( - rng::Random.AbstractRNG, model::Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)} -) - n = 
length(model.args.x) - d = n ÷ 2 - s = rand(rng, product_distribution([InverseGamma(2, 3) for _ in 1:d]), 2) - m = rand(rng, MvNormal(zeros(n), Diagonal(vec(s)))) - return (s=s, m=m) -end const DemoModels = Union{ Model{typeof(demo_dot_assume_dot_observe)}, @@ -670,6 +551,12 @@ const UnivariateAssumeDemoModels = Union{ function posterior_mean(model::UnivariateAssumeDemoModels) return (s=49 / 24, m=7 / 6) end +function example_values(rng::Random.AbstractRNG, model::UnivariateAssumeDemoModels) + s = rand(rng, InverseGamma(2, 3)) + m = rand(rng, Normal(0, sqrt(s))) + + return (s=s, m=m) +end const MultivariateAssumeDemoModels = Union{ Model{typeof(demo_dot_assume_dot_observe)}, @@ -695,6 +582,20 @@ function posterior_mean(model::MultivariateAssumeDemoModels) return vals end +function example_values( + rng::Random.AbstractRNG, model::MultivariateAssumeDemoModels +) + # Get template values from `model`. + retval = model(rng) + vals = (s = retval.s, m = retval.m) + # Fill containers with realizations from prior. 
+ for i in LinearIndices(vals.s) + vals.s[i] = rand(rng, InverseGamma(2, 3)) + vals.m[i] = rand(rng, Normal(0, sqrt(vals.s[i]))) + end + + return vals +end """ A collection of models corresponding to the posterior distribution defined by From 431664dbe0e2aa21e0c6bdb7dd3984659187f01b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 4 Jul 2022 11:45:32 +0100 Subject: [PATCH 32/35] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/test_utils.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 7b3758f40..f68e8cafd 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -582,12 +582,10 @@ function posterior_mean(model::MultivariateAssumeDemoModels) return vals end -function example_values( - rng::Random.AbstractRNG, model::MultivariateAssumeDemoModels -) +function example_values(rng::Random.AbstractRNG, model::MultivariateAssumeDemoModels) # Get template values from `model`. retval = model(rng) - vals = (s = retval.s, m = retval.m) + vals = (s=retval.s, m=retval.m) # Fill containers with realizations from prior. for i in LinearIndices(vals.s) vals.s[i] = rand(rng, InverseGamma(2, 3)) From b3499a37d4316fc0fe3819aaee580e3cf3ff1158 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 12 Jul 2022 12:48:54 +0100 Subject: [PATCH 33/35] added marginal_mean_of_samples according to suggestions --- src/test_utils.jl | 37 ++++++++++++++++++------------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index f68e8cafd..53068efbf 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -645,18 +645,26 @@ const DEMO_MODELS = ( ) """ - test_sampler(meanfunction, models, sampler, args...; kwargs...) + marginal_mean_of_samples(chain, varname) + +Return the mean of variable represented by `varname` in `chain`. 
+""" +marginal_mean_of_samples(chain, varname) = mean(Array(chain[Symbol(varname)])) + +""" + test_sampler(models, sampler, args...; kwargs...) Test that `sampler` produces correct marginal posterior means on each model in `models`. In short, this method iterates through `models`, calls `AbstractMCMC.sample` on the -`model` and `sampler` to produce a `chain`, and then checks `meanfunction(chain, vn)` +`model` and `sampler` to produce a `chain`, and then checks [`marginal_mean_of_samples(chain, vn)`](@ref) for every (leaf) varname `vn` against the corresponding value returned by [`posterior_mean`](@ref) for each model. +To change how comparison is done for a particular `chain` type, one can overload +[`marginal_mean_of_samples(chain, vn)`](@ref) for the corresponding type. + # Arguments -- `meanfunction`: A callable which computes the mean of the marginal means from the - chain resulting from the `sample` call. - `models`: A collection of instaces of [`DynamicPPL.Model`](@ref) to test on. - `sampler`: The `AbstractMCMC.AbstractSampler` to test. - `args...`: Arguments forwarded to `sample`. @@ -667,7 +675,6 @@ for every (leaf) varname `vn` against the corresponding value returned by - `kwargs...`: Keyword arguments forwarded to `sample`. """ function test_sampler( - meanfunction, models, sampler::AbstractMCMC.AbstractSampler, args...; @@ -683,7 +690,7 @@ function test_sampler( # extracting the leaves of the `VarName` and the corresponding value. for vn_leaf in varname_leaves(vn, get(target_values, vn)) target_value = get(target_values, vn_leaf) - chain_mean_value = meanfunction(chain, vn_leaf) + chain_mean_value = marginal_mean_of_samples(chain, vn_leaf) @test chain_mean_value ≈ target_value atol = atol rtol = rtol end end @@ -698,30 +705,22 @@ Test `sampler` on every model in [`DEMO_MODELS`](@ref). This is just a proxy for `test_sampler(meanfunction, DEMO_MODELS, sampler, args...; kwargs...)`. 
""" function test_sampler_on_demo_models( - meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs... + sampler::AbstractMCMC.AbstractSampler, args...; kwargs... ) - return test_sampler(meanfunction, DEMO_MODELS, sampler, args...; kwargs...) + return test_sampler(DEMO_MODELS, sampler, args...; kwargs...) end """ - test_sampler_continuous([meanfunction, ]sampler, args...; kwargs...) + test_sampler_continuous(sampler, args...; kwargs...) Test that `sampler` produces the correct marginal posterior means on all models in `demo_models`. As of right now, this is just an alias for [`test_sampler_on_demo_models`](@ref). """ function test_sampler_continuous( - meanfunction, sampler::AbstractMCMC.AbstractSampler, args...; kwargs... + sampler::AbstractMCMC.AbstractSampler, args...; kwargs... ) - return test_sampler_on_demo_models(meanfunction, sampler, args...; kwargs...) -end - -function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) - # Default for `MCMCChains.Chains`. - return test_sampler_continuous(sampler, args...; kwargs...) do chain, vn - # HACK(torfjelde): This assumes that we can index into `chain` with `Symbol(vn)`. - mean(Array(chain[Symbol(vn)])) - end + return test_sampler_on_demo_models(sampler, args...; kwargs...) 
end end From 2bd5dcddcef5bd2aa0ab79197cd72f1026fc9383 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 13 Jul 2022 10:17:05 +0100 Subject: [PATCH 34/35] removed example_values in favour of rand with NamedTuple --- docs/src/api.md | 1 - src/test_utils.jl | 37 +++++++++++++++---------------------- test/loglikelihoods.jl | 2 +- test/simple_varinfo.jl | 4 ++-- 4 files changed, 18 insertions(+), 26 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index 9aa481cc4..c7133a5f9 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -131,7 +131,6 @@ Finally, the following methods can also be of use: ```@docs DynamicPPL.TestUtils.varnames -DynamicPPL.TestUtils.example_values DynamicPPL.TestUtils.posterior_mean ``` diff --git a/src/test_utils.jl b/src/test_utils.jl index 53068efbf..508eb275f 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -124,16 +124,6 @@ function varnames(model::Model) ) end -""" - example_values(model::Model) - -Return a `NamedTuple` compatible with `varnames(model)` with values in support of `model`. - -"Compatible" means that a `varname` from `varnames(model)` can be used to extract the -corresponding value using `get`, e.g. `get(example_values(model), varname)`. -""" -example_values(model::Model) = example_values(Random.GLOBAL_RNG, model) - """ posterior_mean(model::Model) @@ -545,13 +535,21 @@ const DemoModels = Union{ Model{typeof(demo_dot_assume_matrix_dot_observe_matrix)}, } +# We require demo models to have explict impleentations of `rand` since we want +# these to be considered as ground truth. 
+function Random.rand(rng::Random.AbstractRNG, ::Type{NamedTuple}, model::DemoModels) + return error("demo models requires explicit implementation of rand") +end + const UnivariateAssumeDemoModels = Union{ Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)} } function posterior_mean(model::UnivariateAssumeDemoModels) return (s=49 / 24, m=7 / 6) end -function example_values(rng::Random.AbstractRNG, model::UnivariateAssumeDemoModels) +function Random.rand( + rng::Random.AbstractRNG, ::Type{NamedTuple}, model::UnivariateAssumeDemoModels +) s = rand(rng, InverseGamma(2, 3)) m = rand(rng, Normal(0, sqrt(s))) @@ -572,7 +570,7 @@ const MultivariateAssumeDemoModels = Union{ } function posterior_mean(model::MultivariateAssumeDemoModels) # Get some containers to fill. - vals = example_values(model) + vals = Random.rand(model) vals.s[1] = 19 / 8 vals.m[1] = 3 / 4 @@ -582,7 +580,9 @@ function posterior_mean(model::MultivariateAssumeDemoModels) return vals end -function example_values(rng::Random.AbstractRNG, model::MultivariateAssumeDemoModels) +function Random.rand( + rng::Random.AbstractRNG, ::Type{NamedTuple}, model::MultivariateAssumeDemoModels +) # Get template values from `model`. retval = model(rng) vals = (s=retval.s, m=retval.m) @@ -675,12 +675,7 @@ To change how comparison is done for a particular `chain` type, one can overload - `kwargs...`: Keyword arguments forwarded to `sample`. """ function test_sampler( - models, - sampler::AbstractMCMC.AbstractSampler, - args...; - atol=1e-1, - rtol=1e-3, - kwargs..., + models, sampler::AbstractMCMC.AbstractSampler, args...; atol=1e-1, rtol=1e-3, kwargs... ) @testset "$(typeof(sampler)) on $(nameof(model))" for model in models chain = AbstractMCMC.sample(model, sampler, args...; kwargs...) @@ -717,9 +712,7 @@ Test that `sampler` produces the correct marginal posterior means on all models As of right now, this is just an alias for [`test_sampler_on_demo_models`](@ref). 
""" -function test_sampler_continuous( - sampler::AbstractMCMC.AbstractSampler, args...; kwargs... -) +function test_sampler_continuous(sampler::AbstractMCMC.AbstractSampler, args...; kwargs...) return test_sampler_on_demo_models(sampler, args...; kwargs...) end diff --git a/test/loglikelihoods.jl b/test/loglikelihoods.jl index bd04a76a5..b390997af 100644 --- a/test/loglikelihoods.jl +++ b/test/loglikelihoods.jl @@ -1,6 +1,6 @@ @testset "loglikelihoods.jl" begin @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS - example_values = DynamicPPL.TestUtils.example_values(m) + example_values = rand(NamedTuple, m) # Instantiate a `VarInfo` with the example values. vi = VarInfo(m) diff --git a/test/simple_varinfo.jl b/test/simple_varinfo.jl index 175f264d4..6a8c545ca 100644 --- a/test/simple_varinfo.jl +++ b/test/simple_varinfo.jl @@ -62,7 +62,7 @@ DynamicPPL.TestUtils.DEMO_MODELS # We might need to pre-allocate for the variable `m`, so we need # to see whether this is the case. - svi_nt = SimpleVarInfo(DynamicPPL.TestUtils.example_values(model)) + svi_nt = SimpleVarInfo(rand(NamedTuple, model)) svi_dict = SimpleVarInfo(VarInfo(model), Dict) @testset "$(nameof(typeof(DynamicPPL.values_as(svi))))" for svi in ( @@ -88,7 +88,7 @@ @test getlogp(svi_new) != 0 ### Evaluation ### - values_eval_constrained = DynamicPPL.TestUtils.example_values(model) + values_eval_constrained = rand(NamedTuple, model) if DynamicPPL.istrans(svi) _values_prior, logpri_true = DynamicPPL.TestUtils.logprior_true_with_logabsdet_jacobian( model, values_eval_constrained... 
From 61a594cb6c7283b8079982c32d3ef9ee4b22c063 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Wed, 13 Jul 2022 10:39:22 +0100 Subject: [PATCH 35/35] updated docs --- docs/src/api.md | 1 + src/test_utils.jl | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index c7133a5f9..809e6c49e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -106,6 +106,7 @@ DynamicPPL provides several demo models and helpers for testing samplers in the DynamicPPL.TestUtils.test_sampler DynamicPPL.TestUtils.test_sampler_on_demo_models DynamicPPL.TestUtils.test_sampler_continuous +DynamicPPL.TestUtils.marginal_mean_of_samples ``` ```@docs diff --git a/src/test_utils.jl b/src/test_utils.jl index 508eb275f..ef314fa91 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -657,12 +657,12 @@ marginal_mean_of_samples(chain, varname) = mean(Array(chain[Symbol(varname)])) Test that `sampler` produces correct marginal posterior means on each model in `models`. In short, this method iterates through `models`, calls `AbstractMCMC.sample` on the -`model` and `sampler` to produce a `chain`, and then checks [`marginal_mean_of_samples(chain, vn)`](@ref) +`model` and `sampler` to produce a `chain`, and then checks `marginal_mean_of_samples(chain, vn)` for every (leaf) varname `vn` against the corresponding value returned by [`posterior_mean`](@ref) for each model. To change how comparison is done for a particular `chain` type, one can overload -[`marginal_mean_of_samples(chain, vn)`](@ref) for the corresponding type. +[`marginal_mean_of_samples`](@ref) for the corresponding type. # Arguments - `models`: A collection of instaces of [`DynamicPPL.Model`](@ref) to test on.