From 6c39a79a9e5d8c9f7b6a144880822281d055c996 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 06:22:43 +0100 Subject: [PATCH 01/26] removed unnecessary exports --- src/modes/ModeEstimation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/modes/ModeEstimation.jl b/src/modes/ModeEstimation.jl index a4626b9667..2816e0aa1f 100644 --- a/src/modes/ModeEstimation.jl +++ b/src/modes/ModeEstimation.jl @@ -6,7 +6,7 @@ import ..AbstractMCMC: AbstractSampler import ..DynamicPPL import ..DynamicPPL: Model, AbstractContext, VarInfo, AbstractContext, VarName, _getindex, getsym, getfield, settrans!, setorder!, - get_and_set_val!, istrans, tilde, dot_tilde, get_vns_and_dist + get_and_set_val!, istrans import .Optim import .Optim: optimize import ..ForwardDiff From a158e739a1771bf463b63734babff5578b6e16f7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 06:22:58 +0100 Subject: [PATCH 02/26] updated OptimizationContext --- src/modes/ModeEstimation.jl | 41 +++++++++++++++---------------------- 1 file changed, 17 insertions(+), 24 deletions(-) diff --git a/src/modes/ModeEstimation.jl b/src/modes/ModeEstimation.jl index 2816e0aa1f..40ab2c4469 100644 --- a/src/modes/ModeEstimation.jl +++ b/src/modes/ModeEstimation.jl @@ -29,67 +29,60 @@ struct OptimizationContext{C<:AbstractContext} <: AbstractContext end # assume -function DynamicPPL.tilde(rng, ctx::OptimizationContext, spl, dist, vn::VarName, inds, vi) - return DynamicPPL.tilde(ctx, spl, dist, vn, inds, vi) +function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationContext, spl, dist, vn, inds, vi) + return DynamicPPL.tilde_assume(ctx, spl, dist, vn, inds, vi) end -function DynamicPPL.tilde(ctx::OptimizationContext{<:LikelihoodContext}, spl, dist, vn::VarName, inds, vi) +function DynamicPPL.tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, spl, dist, vn, inds, vi) r = vi[vn] return r, 0 end -function DynamicPPL.tilde(ctx::OptimizationContext, spl, dist, vn::VarName, inds, vi) +function DynamicPPL.tilde_assume(ctx::OptimizationContext, spl, dist, vn, inds, vi) r = vi[vn] return r, Distributions.logpdf(dist, r) end # observe -function DynamicPPL.tilde(rng, ctx::OptimizationContext, sampler, right, left, vi) - return DynamicPPL.tilde(ctx, sampler, right, left, vi) +function DynamicPPL.tilde_observe(ctx::OptimizationContext, sampler, right, left, vi) + return DynamicPPL.observe(right, left, vi) end -function DynamicPPL.tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi) +function DynamicPPL.tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi) return 0 end -function DynamicPPL.tilde(ctx::OptimizationContext, sampler, dist, value, vi) - return Distributions.logpdf(dist, value) -end - # dot assume -function DynamicPPL.dot_tilde(rng, ctx::OptimizationContext, sampler, right, left, vn::VarName, inds, vi) - return DynamicPPL.dot_tilde(ctx, sampler, right, left, vn, inds, vi) +function DynamicPPL.dot_tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationContext, sampler, right, left, vns, inds, vi) + return DynamicPPL.dot_tilde_assume(ctx, sampler, right, left, vns, inds, vi) end -function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:LikelihoodContext}, sampler, right, left, vn::VarName, _, vi) - vns, dist = get_vns_and_dist(right, left, vn) +function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:LikelihoodContext}, sampler, right, left, vns, _, vi) r = getval(vi, vns) return r, 0 end -function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, right, left, vn::VarName, _, vi) - vns, dist = get_vns_and_dist(right, left, vn) +function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, sampler, right, left, vns, _, vi) r = getval(vi, vns) - return r, loglikelihood(dist, r) + return r, loglikelihood(right, r) end # dot observe -function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vn, _, vi) +function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vn, _, vi) return 0 end -function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi) +function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi) return 0 end -function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, right, left, vn, _, vi) - vns, dist = get_vns_and_dist(right, left, vn) +function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext, sampler, right, left, vns, _, vi) r = getval(vi, vns) - return loglikelihood(dist, r) + return loglikelihood(right, r) end -function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, dists, value, vi) +function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext, sampler, dists, value, vi) return sum(Distributions.logpdf.(dists, value)) end From a2673c56f0226dad3a96af8458f5abcdb80c30c6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 06:23:13 +0100 Subject: [PATCH 03/26] updated ESS smapler --- src/inference/ess.jl | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/inference/ess.jl b/src/inference/ess.jl index ebcfc4a17d..a7289e019b 100644 --- a/src/inference/ess.jl +++ b/src/inference/ess.jl @@ -135,26 +135,27 @@ function (ℓ::ESSLogLikelihood)(f) return getlogp(varinfo) end -function DynamicPPL.tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn::VarName, inds, vi) +function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn, inds, vi) if inspace(vn, sampler) - return DynamicPPL.tilde(rng, LikelihoodContext(), SampleFromPrior(), right, vn, inds, vi) + return DynamicPPL.tilde_assume(rng, LikelihoodContext(), SampleFromPrior(), right, vn, inds, vi) else - return DynamicPPL.tilde(rng, ctx, SampleFromPrior(), right, vn, inds, vi) + return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, inds, vi) end end -function DynamicPPL.tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) - return DynamicPPL.tilde(ctx, SampleFromPrior(), right, left, vi) +function DynamicPPL.tilde_observe(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) + return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi) end -function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vn::VarName, inds, vi) - if inspace(vn, sampler) - return DynamicPPL.dot_tilde(rng, LikelihoodContext(), SampleFromPrior(), right, left, vn, inds, vi) +function DynamicPPL.dot_tilde_assume(rng::Random.AbstractRNG, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vns, inds, vi) + # TODO: Or should we do `all(Base.Fix2(inspace, sampler), vns)`? + if inspace(first(vns), sampler) + return DynamicPPL.dot_tilde_assume(rng, LikelihoodContext(), SampleFromPrior(), right, left, vns, inds, vi) else - return DynamicPPL.dot_tilde(rng, ctx, SampleFromPrior(), right, left, vn, inds, vi) + return DynamicPPL.dot_tilde_assume(rng, ctx, SampleFromPrior(), right, left, vns, inds, vi) end end -function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) - return DynamicPPL.dot_tilde(rng, ctx, SampleFromPrior(), right, left, vi) +function DynamicPPL.dot_tilde_observe(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) + return DynamicPPL.dot_tilde_observe(ctx, SampleFromPrior(), right, left, vi) end From 48b8463ffc9e3f5c5534ee2ffd22a33931588cac Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 07:35:50 +0100 Subject: [PATCH 04/26] fixed #1633 --- src/inference/ess.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/inference/ess.jl b/src/inference/ess.jl index ebcfc4a17d..484d557116 100644 --- a/src/inference/ess.jl +++ b/src/inference/ess.jl @@ -112,7 +112,9 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior) sampler = p.sampler varinfo = p.varinfo vns = _getvns(varinfo, sampler) - set_flag!(varinfo, vns[1][1], "del") + for vn in Iterators.flatten(values(vns)) + set_flag!(varinfo, vn, "del") + end p.model(rng, varinfo, sampler) return varinfo[sampler] end From ca81eb00e6e4a433ce31d8ce352ffc859aeacfb0 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 07:36:02 +0100 Subject: [PATCH 05/26] fixed bug where ESS didnt support dot_observe --- src/inference/ess.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/inference/ess.jl b/src/inference/ess.jl index 484d557116..eeadbc40d2 100644 --- a/src/inference/ess.jl +++ b/src/inference/ess.jl @@ -157,6 +157,6 @@ function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, end end -function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) - return DynamicPPL.dot_tilde(rng, ctx, SampleFromPrior(), right, left, vi) +function DynamicPPL.dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) + return DynamicPPL.dot_tilde(ctx, SampleFromPrior(), right, left, vi) end From df8bb42a71a8a41c72d8f7db7a68f54fd1226072 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 07:36:26 +0100 Subject: [PATCH 06/26] added some additional models to test against --- test/test_utils/models.jl | 91 ++++++++++++++++++++++++++++++ test/test_utils/numerical_tests.jl | 17 ++++++ 2 files changed, 108 insertions(+) diff --git a/test/test_utils/models.jl b/test/test_utils/models.jl index af207621bb..0f9f62cfdc 100644 --- a/test/test_utils/models.jl +++ b/test/test_utils/models.jl @@ -51,3 +51,94 @@ MoGtest_default = MoGtest([1.0 1.0 4.0 4.0]) # Declare empty model to make the Sampler constructor work. @model empty_model() = begin x = 1; end + +# A collection of models for which the mean-of-means for the posterior should +# be same. +@model function gdemo1(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + # `dot_assume` and `observe` + m = TV(undef, length(x)) + m .~ Normal() + x ~ TuringDiagMvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo2(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + # `assume` with indexing and `observe` + m = TV(undef, length(x)) + for i in eachindex(m) + m[i] ~ Normal() + end + x ~ TuringDiagMvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo3(x = 10 * ones(2)) + # Multivariate `assume` and `observe` + m ~ MvNormal(length(x), 1.0) + x ~ TuringDiagMvNormal(m, 0.5 * ones(length(x))) +end + +@model function gdemo4(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + # `dot_assume` and `observe` with indexing + m = TV(undef, length(x)) + m .~ Normal() + for i in eachindex(x) + x[i] ~ Normal(m[i], 0.5) + end +end + +# Using vector of `length` 1 here so the posterior of `m` is the same +# as the others. +@model function gdemo5(x = 10 * ones(1)) + # `assume` and `dot_observe` + m ~ Normal() + x .~ Normal(m, 0.5) +end + +# @model function gdemo6(::Type{TV} = Vector{Float64}) where {TV} +# # `assume` and literal `observe` +# m ~ MvNormal(length(x), 1.0) +# [10.0, 10.0] ~ TuringDiagMvNormal(m, 0.5 * ones(2)) +# end + +@model function gdemo7(::Type{TV} = Vector{Float64}) where {TV} + # `dot_assume` and literal `observe` with indexing + m = TV(undef, 2) + m .~ Normal() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +# @model function gdemo8(::Type{TV} = Vector{Float64}) where {TV} +# # `assume` and literal `dot_observe` +# m ~ Normal() +# [10.0, ] .~ Normal(m, 0.5) +# end + +@model function _prior_dot_assume(::Type{TV} = Vector{Float64}) where {TV} + m = TV(undef, 2) + m .~ Normal() + + return m +end + +@model function gdemo9() + # Submodel prior + m = @submodel _prior_dot_assume() + for i in eachindex(m) + 10.0 ~ Normal(m[i], 0.5) + end +end + +@model function _likelihood_dot_observe(m, x) + x ~ TuringDiagMvNormal(m, 0.5 * ones(length(m))) +end + +@model function gdemo10(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} + m = TV(undef, length(x)) + m .~ Normal() + + # Submodel likelihood + @submodel _likelihood_dot_observe(m, x) +end + +const mean_of_mean_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10()) diff --git a/test/test_utils/numerical_tests.jl b/test/test_utils/numerical_tests.jl index 090dabb31a..7f81288f81 100644 --- a/test/test_utils/numerical_tests.jl +++ b/test/test_utils/numerical_tests.jl @@ -64,3 +64,20 @@ function check_MoGtest_default(chain; atol=0.2, rtol=0.0) [1.0, 1.0, 2.0, 2.0, 1.0, 4.0], atol=atol, rtol=rtol) end + +function check_mean_of_mean_models(alg, nsamples, args...; atol=0.0, rtol=0.2, kwargs...) + means = [] + for m in mean_of_mean_models + # Log this so that if something goes wrong, we can identify the + # algorithm and model. + @info "Testing $(alg) on $(m.name)" + μ = mean(Array(sample(m, alg, nsamples, args...; kwargs...))) + push!(means, μ) + end + + for i in 1:length(means) + for j = i + 1:length(means) + @test means[i] ≈ means[j] atol=atol rtol=rtol + end + end +end From 989886591785ddb9a233b8071d1d2ed9cc38b6fa Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 07:36:39 +0100 Subject: [PATCH 07/26] added test for ESS on the mean-of-mean models --- test/inference/ess.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/inference/ess.jl b/test/inference/ess.jl index 9a1fcd3c1a..5f8f396e60 100644 --- a/test/inference/ess.jl +++ b/test/inference/ess.jl @@ -54,5 +54,9 @@ ESS(:mu1), ESS(:mu2)) chain = sample(MoGtest_default, alg, 6000) check_MoGtest_default(chain, atol = 0.1) + + # Mean of means models + Random.seed!(125) + check_mean_of_mean_models(ESS(), 1_000) end end From 735156224c38cdc49800fd2aa65a060a0537f78d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 07:37:32 +0100 Subject: [PATCH 08/26] patch version bump --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 29e73e61fb..d773a7b6d7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.16.0" +version = "0.16.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 20267cee1cd8ad6dfb2fd279afc259323105416a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 08:11:04 +0100 Subject: [PATCH 09/26] added tests on mean_of_mean_models for optimization methods too --- test/modes/ModeEstimation.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/modes/ModeEstimation.jl b/test/modes/ModeEstimation.jl index 209ad7ab78..e716c5ff7b 100644 --- a/test/modes/ModeEstimation.jl +++ b/test/modes/ModeEstimation.jl @@ -96,4 +96,16 @@ @test isapprox(mle1.values.array, mle2.values.array) @test isapprox(map1.values.array, map2.values.array) end + + @testset "Mean of mean models" begin + for m in mean_of_mean_models + @info "Testing MAP on $(m)" + result = optimize(m, MAP()) + @test mean(result.values) ≈ 8.0 rtol=0.05 + + @info "Testing MLE on $(m)" + result = optimize(m, MLE()) + @test mean(result.values) ≈ 10.0 rtol=0.05 + end + end end From 4a931c113804988d199e77332aa4b088a5e718ac Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 08:20:58 +0100 Subject: [PATCH 10/26] fixed bug in bijector after recent update to Bijectors.jl --- src/variational/advi.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/variational/advi.jl b/src/variational/advi.jl index a048d9bcb5..8e98c9c4c7 100644 --- a/src/variational/advi.jl +++ b/src/variational/advi.jl @@ -34,14 +34,15 @@ function Bijectors.bijector( end bs = Bijectors.bijector.(tuple(dists...)) + rs = tuple(ranges...) if sym2ranges return ( - Bijectors.Stacked(bs, ranges), + Bijectors.Stacked(bs, rs), (; collect(zip(keys(sym_lookup), values(sym_lookup)))...), ) else - return Bijectors.Stacked(bs, ranges) + return Bijectors.Stacked(bs, rs) end end From 48030eb468d8c75fe97a54a596df6642ef146f72 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 08:22:42 +0100 Subject: [PATCH 11/26] use exact value in check_mean_of_mean_models --- test/test_utils/numerical_tests.jl | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/test/test_utils/numerical_tests.jl b/test/test_utils/numerical_tests.jl index 7f81288f81..78cd3e015e 100644 --- a/test/test_utils/numerical_tests.jl +++ b/test/test_utils/numerical_tests.jl @@ -66,18 +66,12 @@ function check_MoGtest_default(chain; atol=0.2, rtol=0.0) end function check_mean_of_mean_models(alg, nsamples, args...; atol=0.0, rtol=0.2, kwargs...) - means = [] for m in mean_of_mean_models # Log this so that if something goes wrong, we can identify the # algorithm and model. @info "Testing $(alg) on $(m.name)" μ = mean(Array(sample(m, alg, nsamples, args...; kwargs...))) - push!(means, μ) - end - for i in 1:length(means) - for j = i + 1:length(means) - @test means[i] ≈ means[j] atol=atol rtol=rtol - end + @test μ ≈ 8.0 atol=atol rtol=rtol end end From d3c51d9cec1f7198873f92053dd9e24c5022a57b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 08:36:23 +0100 Subject: [PATCH 12/26] fixed bug in OptimizationContext --- src/modes/ModeEstimation.jl | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/src/modes/ModeEstimation.jl b/src/modes/ModeEstimation.jl index 40ab2c4469..92564770c1 100644 --- a/src/modes/ModeEstimation.jl +++ b/src/modes/ModeEstimation.jl @@ -58,13 +58,17 @@ function DynamicPPL.dot_tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationC return DynamicPPL.dot_tilde_assume(ctx, sampler, right, left, vns, inds, vi) end -function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:LikelihoodContext}, sampler, right, left, vns, _, vi) - r = getval(vi, vns) +function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, sampler::SampleFromPrior, right, left, vns, _, vi) + # Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't + # affect anything. + r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler) return r, 0 end -function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, sampler, right, left, vns, _, vi) - r = getval(vi, vns) +function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, sampler::SampleFromPrior, right, left, vns, _, vi) + # Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't + # affect anything. + r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler) return r, loglikelihood(right, r) end @@ -77,8 +81,10 @@ function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:PriorContext}, return 0 end -function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext, sampler, right, left, vns, _, vi) - r = getval(vi, vns) +function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext, sampler::SampleFromPrior, right, left, vns, _, vi) + # Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't + # affect anything. + r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler) return loglikelihood(right, r) end @@ -86,22 +92,6 @@ function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext, sampler, dists, return sum(Distributions.logpdf.(dists, value)) end -function getval( - vi, - vns::AbstractVector{<:VarName}, -) - r = vi[vns] - return r -end - -function getval( - vi, - vns::AbstractArray{<:VarName}, -) - r = reshape(vi[vec(vns)], size(vns)) - return r -end - """ OptimLogDensity{M<:Model,C<:Context,V<:VarInfo} From 92cabdff74f4321bc12ce4ddd4a2c4d08ece0387 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 13:05:25 +0100 Subject: [PATCH 13/26] just use MvNormal instead of TuringDiagMvNormal in test models --- test/modes/ModeEstimation.jl | 4 ++-- test/test_utils/models.jl | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/test/modes/ModeEstimation.jl b/test/modes/ModeEstimation.jl index e716c5ff7b..efd026386e 100644 --- a/test/modes/ModeEstimation.jl +++ b/test/modes/ModeEstimation.jl @@ -99,11 +99,11 @@ @testset "Mean of mean models" begin for m in mean_of_mean_models - @info "Testing MAP on $(m)" + @info "Testing MAP on $(m.name)" result = optimize(m, MAP()) @test mean(result.values) ≈ 8.0 rtol=0.05 - @info "Testing MLE on $(m)" + @info "Testing MLE on $(m.name)" result = optimize(m, MLE()) @test mean(result.values) ≈ 10.0 rtol=0.05 end diff --git a/test/test_utils/models.jl b/test/test_utils/models.jl index 0f9f62cfdc..b88d7d71f4 100644 --- a/test/test_utils/models.jl +++ b/test/test_utils/models.jl @@ -58,7 +58,7 @@ MoGtest_default = MoGtest([1.0 1.0 4.0 4.0]) # `dot_assume` and `observe` m = TV(undef, length(x)) m .~ Normal() - x ~ TuringDiagMvNormal(m, 0.5 * ones(length(x))) + x ~ MvNormal(m, 0.5 * ones(length(x))) end @model function gdemo2(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} @@ -67,13 +67,13 @@ end for i in eachindex(m) m[i] ~ Normal() end - x ~ TuringDiagMvNormal(m, 0.5 * ones(length(x))) + x ~ MvNormal(m, 0.5 * ones(length(x))) end @model function gdemo3(x = 10 * ones(2)) # Multivariate `assume` and `observe` m ~ MvNormal(length(x), 1.0) - x ~ TuringDiagMvNormal(m, 0.5 * ones(length(x))) + x ~ MvNormal(m, 0.5 * ones(length(x))) end @model function gdemo4(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} @@ -96,7 +96,7 @@ end # @model function gdemo6(::Type{TV} = Vector{Float64}) where {TV} # # `assume` and literal `observe` # m ~ MvNormal(length(x), 1.0) -# [10.0, 10.0] ~ TuringDiagMvNormal(m, 0.5 * ones(2)) +# [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) # end @model function gdemo7(::Type{TV} = Vector{Float64}) where {TV} @@ -130,7 +130,7 @@ end end @model function _likelihood_dot_observe(m, x) - x ~ TuringDiagMvNormal(m, 0.5 * ones(length(m))) + x ~ MvNormal(m, 0.5 * ones(length(m))) end @model function gdemo10(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV} From ce4d8dd2a6d5b985fbf34a71613edeba6fa06e0a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 14:25:27 +0100 Subject: [PATCH 14/26] renamed the mean_of_mean models used tests --- test/inference/ess.jl | 2 +- test/modes/ModeEstimation.jl | 2 +- test/test_utils/models.jl | 2 +- test/test_utils/numerical_tests.jl | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/inference/ess.jl b/test/inference/ess.jl index 5f8f396e60..e236b9c709 100644 --- a/test/inference/ess.jl +++ b/test/inference/ess.jl @@ -57,6 +57,6 @@ # Mean of means models Random.seed!(125) - check_mean_of_mean_models(ESS(), 1_000) + check_gdemo_models(ESS(), 1_000) end end diff --git a/test/modes/ModeEstimation.jl b/test/modes/ModeEstimation.jl index efd026386e..cdacaa51e5 100644 --- a/test/modes/ModeEstimation.jl +++ b/test/modes/ModeEstimation.jl @@ -98,7 +98,7 @@ end @testset "Mean of mean models" begin - for m in mean_of_mean_models + for m in gdemo_models @info "Testing MAP on $(m.name)" result = optimize(m, MAP()) @test mean(result.values) ≈ 8.0 rtol=0.05 diff --git a/test/test_utils/models.jl b/test/test_utils/models.jl index b88d7d71f4..9dc792aa75 100644 --- a/test/test_utils/models.jl +++ b/test/test_utils/models.jl @@ -141,4 +141,4 @@ end @submodel _likelihood_dot_observe(m, x) end -const mean_of_mean_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10()) +const gdemo_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10()) diff --git a/test/test_utils/numerical_tests.jl b/test/test_utils/numerical_tests.jl index 78cd3e015e..8bdb42270d 100644 --- a/test/test_utils/numerical_tests.jl +++ b/test/test_utils/numerical_tests.jl @@ -65,8 +65,8 @@ function check_MoGtest_default(chain; atol=0.2, rtol=0.0) atol=atol, rtol=rtol) end -function check_mean_of_mean_models(alg, nsamples, args...; atol=0.0, rtol=0.2, kwargs...) - for m in mean_of_mean_models +function check_gdemo_models(alg, nsamples, args...; atol=0.0, rtol=0.2, kwargs...) + for m in gdemo_models # Log this so that if something goes wrong, we can identify the # algorithm and model. @info "Testing $(alg) on $(m.name)" From 966b724e16fa8cd20ebd0839de710bf857963e0a Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 20:02:32 +0100 Subject: [PATCH 15/26] renamed the mean_of_mean_models in tests to gdemo_models --- test/inference/ess.jl | 4 ++-- test/modes/ModeEstimation.jl | 7 +++---- test/test_utils/models.jl | 2 +- test/test_utils/numerical_tests.jl | 5 ++--- 4 files changed, 8 insertions(+), 10 deletions(-) diff --git a/test/inference/ess.jl b/test/inference/ess.jl index 5f8f396e60..e5e73c8f65 100644 --- a/test/inference/ess.jl +++ b/test/inference/ess.jl @@ -55,8 +55,8 @@ chain = sample(MoGtest_default, alg, 6000) check_MoGtest_default(chain, atol = 0.1) - # Mean of means models + # Different "equivalent" models. Random.seed!(125) - check_mean_of_mean_models(ESS(), 1_000) + check_gdemo_models(ESS(), 1_000) end end diff --git a/test/modes/ModeEstimation.jl b/test/modes/ModeEstimation.jl index efd026386e..fd32e17407 100644 --- a/test/modes/ModeEstimation.jl +++ b/test/modes/ModeEstimation.jl @@ -98,12 +98,11 @@ end @testset "Mean of mean models" begin - for m in mean_of_mean_models - @info "Testing MAP on $(m.name)" + @testset "MAP on $(m.name)" for m in gdemo_models result = optimize(m, MAP()) @test mean(result.values) ≈ 8.0 rtol=0.05 - - @info "Testing MLE on $(m.name)" + end + @testset "MLE on $(m.name)" for m in gdemo_models result = optimize(m, MLE()) @test mean(result.values) ≈ 10.0 rtol=0.05 end diff --git a/test/test_utils/models.jl b/test/test_utils/models.jl index b88d7d71f4..9dc792aa75 100644 --- a/test/test_utils/models.jl +++ b/test/test_utils/models.jl @@ -141,4 +141,4 @@ end @submodel _likelihood_dot_observe(m, x) end -const mean_of_mean_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10()) +const gdemo_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10()) diff --git a/test/test_utils/numerical_tests.jl b/test/test_utils/numerical_tests.jl index 78cd3e015e..c3f29fbffa 100644 --- a/test/test_utils/numerical_tests.jl +++ b/test/test_utils/numerical_tests.jl @@ -65,11 +65,10 @@ function check_MoGtest_default(chain; atol=0.2, rtol=0.0) atol=atol, rtol=rtol) end -function check_mean_of_mean_models(alg, nsamples, args...; atol=0.0, rtol=0.2, kwargs...) - for m in mean_of_mean_models +function check_gdemo_models(alg, nsamples, args...; atol=0.0, rtol=0.2, kwargs...) + @testset "$(alg) on $(m.name)" for m in gdemo_models # Log this so that if something goes wrong, we can identify the # algorithm and model. - @info "Testing $(alg) on $(m.name)" μ = mean(Array(sample(m, alg, nsamples, args...; kwargs...))) @test μ ≈ 8.0 atol=atol rtol=rtol From 2cc253b33a6bfa1a3e6d6d5511a492f1d554065f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 8 Jun 2021 20:04:55 +0100 Subject: [PATCH 16/26] removed redundant testset block --- test/modes/ModeEstimation.jl | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/test/modes/ModeEstimation.jl b/test/modes/ModeEstimation.jl index fd32e17407..a10a938d62 100644 --- a/test/modes/ModeEstimation.jl +++ b/test/modes/ModeEstimation.jl @@ -97,14 +97,12 @@ @test isapprox(map1.values.array, map2.values.array) end - @testset "Mean of mean models" begin - @testset "MAP on $(m.name)" for m in gdemo_models - result = optimize(m, MAP()) - @test mean(result.values) ≈ 8.0 rtol=0.05 - end - @testset "MLE on $(m.name)" for m in gdemo_models - result = optimize(m, MLE()) - @test mean(result.values) ≈ 10.0 rtol=0.05 - end + @testset "MAP on $(m.name)" for m in gdemo_models + result = optimize(m, MAP()) + @test mean(result.values) ≈ 8.0 rtol=0.05 + end + @testset "MLE on $(m.name)" for m in gdemo_models + result = optimize(m, MLE()) + @test mean(result.values) ≈ 10.0 rtol=0.05 end end From 2b5c5e1c6ab18209cd93fc0a131d8865d94fd706 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 08:18:24 +0100 Subject: [PATCH 17/26] upper-bound compat entries for Libtask while we wait for bugfix --- Project.toml | 2 +- test/Project.toml | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d773a7b6d7..824a22ee91 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,7 @@ DocStringExtensions = "0.8" DynamicPPL = "0.11.0" EllipticalSliceSampling = "0.4" ForwardDiff = "0.10.3" -Libtask = "0.4, 0.5" +Libtask = "0.4 - 0.5.1" MCMCChains = "4" NamedArrays = "0.9" Reexport = "0.2, 1" diff --git a/test/Project.toml b/test/Project.toml index c1e3fe044b..593a29c133 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,6 +11,7 @@ DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @@ -40,6 +41,7 @@ DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.11.0" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12" +Libtask = "< 0.5.2" MCMCChains = "4.0.4" Memoization = "0.1.4" NamedArrays = "0.9.4" From 5912c78c1f8400f87fe4c029856261616833fc7e Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 08:21:40 +0100 Subject: [PATCH 18/26] compat entries with hyphens arent supported on Julia v1.3 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 824a22ee91..103dc2caef 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,7 @@ DocStringExtensions = "0.8" DynamicPPL = "0.11.0" EllipticalSliceSampling = "0.4" ForwardDiff = "0.10.3" -Libtask = "0.4 - 0.5.1" +Libtask = "< 0.5.2" MCMCChains = "4" NamedArrays = "0.9" Reexport = "0.2, 1" From cab751a46e00d7dc05584acaaae4fdd70a571d36 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 08:22:14 +0100 Subject: [PATCH 19/26] compat entries with hyphens not supported on Julia 1.3 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 824a22ee91..103dc2caef 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,7 @@ DocStringExtensions = "0.8" DynamicPPL = "0.11.0" EllipticalSliceSampling = "0.4" ForwardDiff = "0.10.3" -Libtask = "0.4 - 0.5.1" +Libtask = "< 0.5.2" MCMCChains = "4" NamedArrays = "0.9" Reexport = "0.2, 1" From 20daa3e0324b8ba88874fec3f8403cddf51ec21f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 08:37:32 +0100 Subject: [PATCH 20/26] also test models with literal observe --- test/test_utils/models.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/test/test_utils/models.jl b/test/test_utils/models.jl index 9dc792aa75..cf3b3f25e7 100644 --- a/test/test_utils/models.jl +++ b/test/test_utils/models.jl @@ -93,11 +93,11 @@ end x .~ Normal(m, 0.5) end -# @model function gdemo6(::Type{TV} = Vector{Float64}) where {TV} -# # `assume` and literal `observe` -# m ~ MvNormal(length(x), 1.0) -# [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) -# end +@model function gdemo6() + # `assume` and literal `observe` + m ~ MvNormal(2, 1.0) + [10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2)) +end @model function gdemo7(::Type{TV} = Vector{Float64}) where {TV} # `dot_assume` and literal `observe` with indexing @@ -108,11 +108,11 @@ end end end -# @model function gdemo8(::Type{TV} = Vector{Float64}) where {TV} -# # `assume` and literal `dot_observe` -# m ~ Normal() -# [10.0, ] .~ Normal(m, 0.5) -# end +@model function gdemo8() + # `assume` and literal `dot_observe` + m ~ Normal() + [10.0, ] .~ Normal(m, 0.5) +end @model function _prior_dot_assume(::Type{TV} = Vector{Float64}) where {TV} m = TV(undef, 2) @@ -141,4 +141,4 @@ end @submodel _likelihood_dot_observe(m, x) end -const gdemo_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo7(), gdemo9(), gdemo10()) +const gdemo_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo6(), gdemo7(), gdemo8(), gdemo9(), gdemo10()) From 169a014419484a62690b731d73ec67aa672696d6 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 09:08:53 +0100 Subject: [PATCH 21/26] Update Project.toml Co-authored-by: David Widmann --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 103dc2caef..d678f4e9cc 100644 --- a/Project.toml +++ b/Project.toml @@ -47,7 +47,7 @@ DocStringExtensions = "0.8" DynamicPPL = "0.11.0" EllipticalSliceSampling = "0.4" ForwardDiff = "0.10.3" -Libtask = "< 0.5.2" +Libtask = "= 0.4.0, = 0.4.1, = 0.4.2, = 0.5.0, = 0.5.1" MCMCChains = "4" NamedArrays = "0.9" Reexport = "0.2, 1" From cb9aec565cd8ba158b513ffe236b2aaf06663837 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 09:16:27 +0100 Subject: [PATCH 22/26] forgot to bump DPPL version --- Project.toml | 2 +- test/Project.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 103dc2caef..3a5260ad0b 100644 --- a/Project.toml +++ b/Project.toml @@ -44,7 +44,7 @@ DataStructures = "0.18" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8" -DynamicPPL = "0.11.0" +DynamicPPL = "0.12" EllipticalSliceSampling = "0.4" ForwardDiff = "0.10.3" Libtask = "< 0.5.2" diff --git a/test/Project.toml b/test/Project.toml index 593a29c133..2c79a85906 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -38,7 +38,7 @@ CmdStan = "6.0.8" Distributions = "0.23.8, 0.24, 0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.11.0" +DynamicPPL = "0.12" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12" Libtask = "< 0.5.2" From 64a816aeade182a6dd776753f5dbf90ebb258ed7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 11:22:45 +0100 Subject: [PATCH 23/26] Apply suggestions from code review --- test/Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index 593a29c133..c1e3fe044b 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -11,7 +11,6 @@ DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73" @@ -41,7 +40,6 @@ DynamicHMC = "2.1.6, 3.0" DynamicPPL = "0.11.0" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12" -Libtask = "< 0.5.2" MCMCChains = "4.0.4" Memoization = "0.1.4" NamedArrays = "0.9.4" From d90ed39d88930d3ceb6c87fe88ff348b23607fc1 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 13:13:07 +0100 Subject: [PATCH 24/26] bump DPPL patch version to fix AdvancedPS samplers --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 2cd2a234f7..a6b318762c 100644 --- a/Project.toml +++ b/Project.toml @@ -44,7 +44,7 @@ DataStructures = "0.18" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8" -DynamicPPL = "0.12" +DynamicPPL = "0.12.1" EllipticalSliceSampling = "0.4" ForwardDiff = "0.10.3" Libtask = "= 0.4.0, = 0.4.1, = 0.4.2, = 0.5.0, = 0.5.1" From 46a00a8260d69457f3fd91e31eaa4b02976e370b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Thu, 10 Jun 2021 15:32:11 +0100 Subject: [PATCH 25/26] bump patch version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a6b318762c..b950ff8ff7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.16.1" +version = "0.16.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" From 0bd5228267015a4fad8795ccf2199d041d4b368b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 14 Aug 2021 00:59:45 +0100 Subject: [PATCH 26/26] updated OptimizationContext to work with the new version of DPPL --- Project.toml | 4 ++-- src/modes/ModeEstimation.jl | 34 ++++------------------------------ 2 files changed, 6 insertions(+), 32 deletions(-) diff --git a/Project.toml b/Project.toml index e9b2fca278..5b75d5b427 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.17.0" +version = "0.17.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -44,7 +44,7 @@ DataStructures = "0.18" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8" -DynamicPPL = "0.12.1, 0.13" +DynamicPPL = "0.14" EllipticalSliceSampling = "0.4" ForwardDiff = "0.10.3" Libtask = "0.4, 0.5.3" diff --git a/src/modes/ModeEstimation.jl b/src/modes/ModeEstimation.jl index 92564770c1..9b2448a439 100644 --- a/src/modes/ModeEstimation.jl +++ b/src/modes/ModeEstimation.jl @@ -28,6 +28,10 @@ struct OptimizationContext{C<:AbstractContext} <: AbstractContext context::C end +DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(context::OptimizationContext) = context.context +DynamicPPL.setchildcontext(::OptimizationContext, child) = OptimizationContext(child) + # assume function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationContext, spl, dist, vn, inds, vi) return DynamicPPL.tilde_assume(ctx, spl, dist, vn, inds, vi) @@ -43,16 +47,6 @@ function DynamicPPL.tilde_assume(ctx::OptimizationContext, spl, dist, vn, inds, return r, Distributions.logpdf(dist, r) end - -# observe -function DynamicPPL.tilde_observe(ctx::OptimizationContext, sampler, right, left, vi) - return DynamicPPL.observe(right, left, vi) -end - -function DynamicPPL.tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi) - return 0 -end - # dot assume function DynamicPPL.dot_tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationContext, sampler, right, left, vns, inds, vi) return DynamicPPL.dot_tilde_assume(ctx, sampler, right, left, vns, inds, vi) @@ -72,26 +66,6 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, sampler::SampleFr return r, loglikelihood(right, r) end -# dot observe -function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vn, _, vi) - return 0 -end - -function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi) - return 0 -end - -function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext, sampler::SampleFromPrior, right, left, vns, _, vi) - # Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't - # affect anything. - r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler) - return loglikelihood(right, r) -end - -function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext, sampler, dists, value, vi) - return sum(Distributions.logpdf.(dists, value)) -end - """ OptimLogDensity{M<:Model,C<:Context,V<:VarInfo}