diff --git a/Project.toml b/Project.toml index 29e73e61fb..d678f4e9cc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.16.0" +version = "0.16.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -47,7 +47,7 @@ DocStringExtensions = "0.8" DynamicPPL = "0.11.0" EllipticalSliceSampling = "0.4" ForwardDiff = "0.10.3" -Libtask = "0.4, 0.5" +Libtask = "= 0.4.0, = 0.4.1, = 0.4.2, = 0.5.0, = 0.5.1" MCMCChains = "4" NamedArrays = "0.9" Reexport = "0.2, 1" diff --git a/src/inference/ess.jl b/src/inference/ess.jl index ebcfc4a17d..eeadbc40d2 100644 --- a/src/inference/ess.jl +++ b/src/inference/ess.jl @@ -112,7 +112,9 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior) sampler = p.sampler varinfo = p.varinfo vns = _getvns(varinfo, sampler) - set_flag!(varinfo, vns[1][1], "del") + for vn in Iterators.flatten(values(vns)) + set_flag!(varinfo, vn, "del") + end p.model(rng, varinfo, sampler) return varinfo[sampler] end @@ -155,6 +157,6 @@ function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, end end -function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) - return DynamicPPL.dot_tilde(rng, ctx, SampleFromPrior(), right, left, vi) +function DynamicPPL.dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) + return DynamicPPL.dot_tilde(ctx, SampleFromPrior(), right, left, vi) end diff --git a/src/variational/advi.jl b/src/variational/advi.jl index a048d9bcb5..8e98c9c4c7 100644 --- a/src/variational/advi.jl +++ b/src/variational/advi.jl @@ -34,14 +34,15 @@ function Bijectors.bijector( end bs = Bijectors.bijector.(tuple(dists...)) + rs = tuple(ranges...) if sym2ranges return ( - Bijectors.Stacked(bs, ranges), + Bijectors.Stacked(bs, rs), (; collect(zip(keys(sym_lookup), values(sym_lookup)))...), ) else - return Bijectors.Stacked(bs, ranges) + return Bijectors.Stacked(bs, rs) end end diff --git a/test/inference/ess.jl b/test/inference/ess.jl index 9a1fcd3c1a..e5e73c8f65 100644 --- a/test/inference/ess.jl +++ b/test/inference/ess.jl @@ -54,5 +54,9 @@ ESS(:mu1), ESS(:mu2)) chain = sample(MoGtest_default, alg, 6000) check_MoGtest_default(chain, atol = 0.1) + + # Different "equivalent" models. + Random.seed!(125) + check_gdemo_models(ESS(), 1_000) end end diff --git a/test/modes/ModeEstimation.jl b/test/modes/ModeEstimation.jl index 209ad7ab78..a10a938d62 100644 --- a/test/modes/ModeEstimation.jl +++ b/test/modes/ModeEstimation.jl @@ -96,4 +96,13 @@ @test isapprox(mle1.values.array, mle2.values.array) @test isapprox(map1.values.array, map2.values.array) end + + @testset "MAP on $(m.name)" for m in gdemo_models + result = optimize(m, MAP()) + @test mean(result.values) ≈ 8.0 rtol=0.05 + end + @testset "MLE on $(m.name)" for m in gdemo_models + result = optimize(m, MLE()) + @test mean(result.values) ≈ 10.0 rtol=0.05 + end end diff --git a/test/test_utils/models.jl b/test/test_utils/models.jl index af207621bb..cf3b3f25e7 100644 --- a/test/test_utils/models.jl +++ b/test/test_utils/models.jl @@ -51,3 +51,94 @@ MoGtest_default = MoGtest([1.0 1.0 4.0 4.0]) # Declare empty model to make the Sampler constructor work. @model empty_model() = begin x = 1; end + +# A collection of models for which the mean-of-means for the posterior should +# be same. +@model function gdemo1(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + # `dot_assume` and `observe` + m = TV(undef, length(x)) + m .~ Normal() + x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo2(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + # `assume` with indexing and `observe` + m = TV(undef, length(x)) + for i in eachindex(m) + m[i] ~ Normal() + end + x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo3(x = 10 * ones(2)) + # Multivariate `assume` and `observe` + m ~ MvNormal(length(x), 1.0) + x ~ MvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo4(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + # `dot_assume` and `observe` with indexing + m = TV(undef, length(x)) + m .~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m[i], 0.5) + end +end + +# Using vector of `length` 1 here so the posterior of `m` is the same +# as the others. +@model function gdemo5(x = 10 * ones(1)) + # `assume` and `dot_observe` + m ~ Normal() + x .~ Normal(m, 0.5) +end + +@model function gdemo6() + # `assume` and literal `observe` + m ~ MvNormal(2, 1.0) + [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) +end + +@model function gdemo7(::Type{TV} = Vector{Float64}) where {TV} + # `dot_assume` and literal `observe` with indexing + m = TV(undef, 2) + m .~ Normal() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +@model function gdemo8() + # `assume` and literal `dot_observe` + m ~ Normal() + [10.0, ] .~ Normal(m, 0.5) +end + +@model function _prior_dot_assume(::Type{TV} = Vector{Float64}) where {TV} + m = TV(undef, 2) + m .~ Normal() + + return m +end + +@model function gdemo9() + # Submodel prior + m = @submodel _prior_dot_assume() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +@model function _likelihood_dot_observe(m, x) + x ~ MvNormal(m, 0.5 * ones(length(m))) +end + +@model function gdemo10(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + m = TV(undef, length(x)) + m .~ Normal() + + # Submodel likelihood + @submodel _likelihood_dot_observe(m, x) +end + +const gdemo_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo6(), gdemo7(), gdemo8(), gdemo9(), gdemo10()) diff --git a/test/test_utils/numerical_tests.jl b/test/test_utils/numerical_tests.jl index 090dabb31a..c3f29fbffa 100644 --- a/test/test_utils/numerical_tests.jl +++ b/test/test_utils/numerical_tests.jl @@ -64,3 +64,13 @@ function check_MoGtest_default(chain; atol=0.2, rtol=0.0) [1.0, 1.0, 2.0, 2.0, 1.0, 4.0], atol=atol, rtol=rtol) end + +function check_gdemo_models(alg, nsamples, args...; atol=0.0, rtol=0.2, kwargs...) + @testset "$(alg) on $(m.name)" for m in gdemo_models + # Log this so that if something goes wrong, we can identify the + # algorithm and model. + μ = mean(Array(sample(m, alg, nsamples, args...; kwargs...))) + + @test μ ≈ 8.0 atol=atol rtol=rtol + end +end