Merged
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.16.0"
version = "0.16.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -47,7 +47,7 @@
DocStringExtensions = "0.8"
DynamicPPL = "0.11.0"
EllipticalSliceSampling = "0.4"
ForwardDiff = "0.10.3"
Libtask = "0.4, 0.5"
Libtask = "= 0.4.0, = 0.4.1, = 0.4.2, = 0.5.0, = 0.5.1"
MCMCChains = "4"
NamedArrays = "0.9"
Reexport = "0.2, 1"
8 changes: 5 additions & 3 deletions src/inference/ess.jl
@@ -112,7 +112,9 @@ function Base.rand(rng::Random.AbstractRNG, p::ESSPrior)
sampler = p.sampler
varinfo = p.varinfo
vns = _getvns(varinfo, sampler)
set_flag!(varinfo, vns[1][1], "del")
for vn in Iterators.flatten(values(vns))
set_flag!(varinfo, vn, "del")
end
Comment on lines +115 to +117
Member
We should add a context that does this 😛 At least it would be more convenient than dealing with DynamicPPL internals here. I remember that I was very confused and uncertain if I did it correctly when I implemented this. It seemed to work 😬

Member Author
It works sometimes because in assume we only check whether the flag is set for vns[1], but that only covers "sub-symbols". This is the line I'm referring to: https://github.com/TuringLang/DynamicPPL.jl/blob/9083299db3f623136895cae80ef5f10d7fcf8d2c/src/context_implementations.jl#L268. But this won't work if we have more than one key in vi.metadata, or if tilde_assume and friends are called with a subsumed varname, e.g. m[2].

And this won't be an issue once we have a clear separation between sampling and evaluation. These sorts of bugs show up soooo often (and I agree, it's super-confusing), so I'm looking forward to not having them :)
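
A minimal sketch of the failure mode (hypothetical model, names are illustrative and not from this PR): with two distinct symbols, _getvns returns one entry per symbol, so flagging only vns[1][1] marks m for resampling but silently skips s.

```julia
using Turing

# Hypothetical two-symbol model: vi.metadata has keys `m` and `s`, so the
# old `set_flag!(varinfo, vns[1][1], "del")` only flagged `m`. The loop
# over `Iterators.flatten(values(vns))` flags both.
@model function two_symbols(x)
    m ~ Normal()
    s ~ Normal()
    x ~ Normal(m + s, 1.0)
end
```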

p.model(rng, varinfo, sampler)
return varinfo[sampler]
end
@@ -155,6 +157,6 @@ function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS},
end
end

function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
return DynamicPPL.dot_tilde(rng, ctx, SampleFromPrior(), right, left, vi)
function DynamicPPL.dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
return DynamicPPL.dot_tilde(ctx, SampleFromPrior(), right, left, vi)
Comment on lines +160 to +161
Member
Was this a bug, or is it needed because of recent changes in DynamicPPL?

Member Author
This was a bug. The method was never actually hit, and given the arguments I'm assuming it was intended to be a dot_tilde_observe rather than a dot_tilde_assume (passing rng makes it a dot_tilde_assume, but because the signature didn't match the other dot_tilde_assume methods, it was never dispatched to). The result was that ESS didn't work for dotted observations.
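
For illustration, a sketch of the kind of model this fixes (hypothetical example; gdemo5 in the new test models exercises the same path):

```julia
using Turing

# A dotted observation (`x .~ ...`); before this fix, ESS routed it to
# the dead `dot_tilde` method above, so dotted observations failed.
@model function dotted_obs(x)
    m ~ Normal()
    x .~ Normal(m, 0.5)
end

chain = sample(dotted_obs(10 * ones(3)), ESS(), 1_000)
```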

end
5 changes: 3 additions & 2 deletions src/variational/advi.jl
@@ -34,14 +34,15 @@ function Bijectors.bijector(
end

bs = Bijectors.bijector.(tuple(dists...))
rs = tuple(ranges...)
Member
I guess this is not related to ESS?

Member Author
Correct; this is just to ensure that we have the same behavior as we had before the most recent release of Bijectors.jl. I just noticed it when trying to figure out why the tests weren't passing.
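
As a guess at why the container matters (my reading, not stated in the thread): Stacked pairs each bijector with its range, and since the bijectors are already a Tuple, passing the ranges as a Tuple keeps both sides in the same container family, e.g.:

```julia
using Bijectors, Distributions

# Sketch (assumed behavior, not verified against this Bijectors release):
# pair a Tuple of bijectors with a Tuple of ranges rather than a Vector.
bs = (Bijectors.bijector(Normal()), Bijectors.bijector(InverseGamma(2, 3)))
rs = (1:1, 2:2)
sb = Bijectors.Stacked(bs, rs)
```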


if sym2ranges
return (
Bijectors.Stacked(bs, ranges),
Bijectors.Stacked(bs, rs),
(; collect(zip(keys(sym_lookup), values(sym_lookup)))...),
)
else
return Bijectors.Stacked(bs, ranges)
return Bijectors.Stacked(bs, rs)
end
end

4 changes: 4 additions & 0 deletions test/inference/ess.jl
@@ -54,5 +54,9 @@
ESS(:mu1), ESS(:mu2))
chain = sample(MoGtest_default, alg, 6000)
check_MoGtest_default(chain, atol = 0.1)

# Different "equivalent" models.
Random.seed!(125)
check_gdemo_models(ESS(), 1_000)
end
end
9 changes: 9 additions & 0 deletions test/modes/ModeEstimation.jl
@@ -96,4 +96,13 @@
@test isapprox(mle1.values.array, mle2.values.array)
@test isapprox(map1.values.array, map2.values.array)
end

@testset "MAP on $(m.name)" for m in gdemo_models
result = optimize(m, MAP())
@test mean(result.values) ≈ 8.0 rtol=0.05
end
@testset "MLE on $(m.name)" for m in gdemo_models
result = optimize(m, MLE())
@test mean(result.values) ≈ 10.0 rtol=0.05
end
end
91 changes: 91 additions & 0 deletions test/test_utils/models.jl
@@ -51,3 +51,94 @@ MoGtest_default = MoGtest([1.0 1.0 4.0 4.0])

# Declare empty model to make the Sampler constructor work.
@model empty_model() = begin x = 1; end

# A collection of models for which the mean-of-means of the posterior should
# be the same.
Comment on lines +55 to +56
Member
Sorry, what exactly do you mean by mean-of-means? And is the value the same as the prior? Or between the models? And only with the default arguments, or in general?

Member Author
I want to have a collection of models that tries out all the combinations of *_tilde_*, but this means that we'll sometimes have univariate latent variables rather than multivariate ones (e.g. gdemo5 below). Therefore I compare the mean of the means of the latent variables rather than the variables directly.

Buuuuut now that we're comparing to the true mean rather than pitting the different models against each other, I guess we don't need to do that 😅
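
As a sanity check on the 8.0 target used in the tests (my arithmetic, assuming Normal(μ, σ) is parameterized by standard deviation as in Distributions.jl): each latent component has prior N(0, 1) and a single observation of 10.0 with likelihood N(m, 0.5²), so the conjugate posterior mean is the same for every component, making the mean-of-means dimension-independent.

```julia
# Conjugate-normal posterior mean for one observation of 10.0:
σ0², σ² = 1.0, 0.5^2                            # prior and likelihood variances
x, n = 10.0, 1                                  # observed value and count
post_mean = (x * n / σ²) / (1 / σ0² + n / σ²)   # = 40 / 5 = 8.0
```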

@model function gdemo1(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV}
# `dot_assume` and `observe`
m = TV(undef, length(x))
m .~ Normal()
x ~ MvNormal(m, 0.5 * ones(length(x)))
end

@model function gdemo2(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV}
# `assume` with indexing and `observe`
m = TV(undef, length(x))
for i in eachindex(m)
m[i] ~ Normal()
end
x ~ MvNormal(m, 0.5 * ones(length(x)))
end

@model function gdemo3(x = 10 * ones(2))
# Multivariate `assume` and `observe`
m ~ MvNormal(length(x), 1.0)
x ~ MvNormal(m, 0.5 * ones(length(x)))
end

@model function gdemo4(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV}
# `dot_assume` and `observe` with indexing
m = TV(undef, length(x))
m .~ Normal()
for i in eachindex(x)
x[i] ~ Normal(m[i], 0.5)
end
end

# Using a vector of length 1 here so the posterior of `m` is the same
# as for the others.
@model function gdemo5(x = 10 * ones(1))
# `assume` and `dot_observe`
m ~ Normal()
x .~ Normal(m, 0.5)
end

@model function gdemo6()
# `assume` and literal `observe`
m ~ MvNormal(2, 1.0)
[10.0, 10.0] ~ MvNormal(m, 0.5 * ones(2))
end

@model function gdemo7(::Type{TV} = Vector{Float64}) where {TV}
# `dot_assume` and literal `observe` with indexing
m = TV(undef, 2)
m .~ Normal()
for i in eachindex(m)
10.0 ~ Normal(m[i], 0.5)
end
end

@model function gdemo8()
# `assume` and literal `dot_observe`
m ~ Normal()
[10.0, ] .~ Normal(m, 0.5)
end

@model function _prior_dot_assume(::Type{TV} = Vector{Float64}) where {TV}
m = TV(undef, 2)
m .~ Normal()

return m
end

@model function gdemo9()
# Submodel prior
m = @submodel _prior_dot_assume()
for i in eachindex(m)
10.0 ~ Normal(m[i], 0.5)
end
end

@model function _likelihood_dot_observe(m, x)
x ~ MvNormal(m, 0.5 * ones(length(m)))
end

@model function gdemo10(x = 10 * ones(2), ::Type{TV} = Vector{Float64}) where {TV}
m = TV(undef, length(x))
m .~ Normal()

# Submodel likelihood
@submodel _likelihood_dot_observe(m, x)
end

const gdemo_models = (gdemo1(), gdemo2(), gdemo3(), gdemo4(), gdemo5(), gdemo6(), gdemo7(), gdemo8(), gdemo9(), gdemo10())
10 changes: 10 additions & 0 deletions test/test_utils/numerical_tests.jl
@@ -64,3 +64,13 @@ function check_MoGtest_default(chain; atol=0.2, rtol=0.0)
[1.0, 1.0, 2.0, 2.0, 1.0, 4.0],
atol=atol, rtol=rtol)
end

function check_gdemo_models(alg, nsamples, args...; atol=0.0, rtol=0.2, kwargs...)
@testset "$(alg) on $(m.name)" for m in gdemo_models
    # The testset name records the algorithm and model so that, if something
    # goes wrong, we can identify which combination failed.
μ = mean(Array(sample(m, alg, nsamples, args...; kwargs...)))

@test μ ≈ 8.0 atol=atol rtol=rtol
end
end