Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
6c39a79
removed unnecessary exports
torfjelde Jun 8, 2021
a158e73
updated OptimizationContext
torfjelde Jun 8, 2021
a2673c5
updated ESS smapler
torfjelde Jun 8, 2021
48b8463
fixed #1633
torfjelde Jun 8, 2021
ca81eb0
fixed bug where ESS didnt support dot_observe
torfjelde Jun 8, 2021
df8bb42
added some additional models to test against
torfjelde Jun 8, 2021
9898865
added test for ESS on the mean-of-mean models
torfjelde Jun 8, 2021
7351562
patch version bump
torfjelde Jun 8, 2021
7bdc4ee
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 8, 2021
20267ce
added tests on mean_of_mean_models for optimization methods too
torfjelde Jun 8, 2021
4a931c1
fixed bug in bijector after recent update to Bijectors.jl
torfjelde Jun 8, 2021
48030eb
use exact value in check_mean_of_mean_models
torfjelde Jun 8, 2021
044da91
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 8, 2021
d3c51d9
fixed bug in OptimizationContext
torfjelde Jun 8, 2021
92cabdf
just use MvNormal instead of TuringDiagMvNormal in test models
torfjelde Jun 8, 2021
06dd7da
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 8, 2021
ce4d8dd
renamed the mean_of_mean models used tests
torfjelde Jun 8, 2021
966b724
renamed the mean_of_mean_models in tests to gdemo_models
torfjelde Jun 8, 2021
544ddc9
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 8, 2021
2cc253b
removed redundant testset block
torfjelde Jun 8, 2021
6f3a102
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 8, 2021
2b5c5e1
upper-bound compat entries for Libtask while we wait for bugfix
torfjelde Jun 10, 2021
84fdceb
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 10, 2021
5912c78
compat entries with hyphens arent supported on Julia v1.3
torfjelde Jun 10, 2021
cab751a
compat entries with hyphens not supported on Julia 1.3
torfjelde Jun 10, 2021
61fd07d
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 10, 2021
20daa3e
also test models with literal observe
torfjelde Jun 10, 2021
687a97d
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 10, 2021
169a014
Update Project.toml
torfjelde Jun 10, 2021
cb9aec5
forgot to bump DPPL version
torfjelde Jun 10, 2021
f761163
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 10, 2021
64a816a
Apply suggestions from code review
torfjelde Jun 10, 2021
f330d21
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 10, 2021
d4065da
Merge branch 'master' into tor/dppl-update
torfjelde Jun 10, 2021
d90ed39
bump DPPL patch version to fix AdvancedPS samplers
torfjelde Jun 10, 2021
46a00a8
bump patch version
torfjelde Jun 10, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.16.1"
version = "0.16.2"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -44,7 +44,7 @@ DataStructures = "0.18"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8"
DynamicPPL = "0.11.0"
DynamicPPL = "0.12.1"
EllipticalSliceSampling = "0.4"
ForwardDiff = "0.10.3"
Libtask = "= 0.4.0, = 0.4.1, = 0.4.2, = 0.5.0, = 0.5.1"
Expand Down
23 changes: 12 additions & 11 deletions src/inference/ess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,26 +137,27 @@ function (ℓ::ESSLogLikelihood)(f)
return getlogp(varinfo)
end

function DynamicPPL.tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn::VarName, inds, vi)
function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn, inds, vi)
if inspace(vn, sampler)
return DynamicPPL.tilde(rng, LikelihoodContext(), SampleFromPrior(), right, vn, inds, vi)
return DynamicPPL.tilde_assume(rng, LikelihoodContext(), SampleFromPrior(), right, vn, inds, vi)
else
return DynamicPPL.tilde(rng, ctx, SampleFromPrior(), right, vn, inds, vi)
return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, inds, vi)
end
end

function DynamicPPL.tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
return DynamicPPL.tilde(ctx, SampleFromPrior(), right, left, vi)
function DynamicPPL.tilde_observe(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi)
end

function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vn::VarName, inds, vi)
if inspace(vn, sampler)
return DynamicPPL.dot_tilde(rng, LikelihoodContext(), SampleFromPrior(), right, left, vn, inds, vi)
function DynamicPPL.dot_tilde_assume(rng::Random.AbstractRNG, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vns, inds, vi)
# TODO: Or should we do `all(Base.Fix2(inspace, sampler), vns)`?
if inspace(first(vns), sampler)
return DynamicPPL.dot_tilde_assume(rng, LikelihoodContext(), SampleFromPrior(), right, left, vns, inds, vi)
else
return DynamicPPL.dot_tilde(rng, ctx, SampleFromPrior(), right, left, vn, inds, vi)
return DynamicPPL.dot_tilde_assume(rng, ctx, SampleFromPrior(), right, left, vns, inds, vi)
end
end

function DynamicPPL.dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
return DynamicPPL.dot_tilde(ctx, SampleFromPrior(), right, left, vi)
function DynamicPPL.dot_tilde_observe(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
return DynamicPPL.dot_tilde_observe(ctx, SampleFromPrior(), right, left, vi)
end
71 changes: 27 additions & 44 deletions src/modes/ModeEstimation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import ..AbstractMCMC: AbstractSampler
import ..DynamicPPL
import ..DynamicPPL: Model, AbstractContext, VarInfo, AbstractContext, VarName,
_getindex, getsym, getfield, settrans!, setorder!,
get_and_set_val!, istrans, tilde, dot_tilde, get_vns_and_dist
get_and_set_val!, istrans
import .Optim
import .Optim: optimize
import ..ForwardDiff
Expand All @@ -29,86 +29,69 @@ struct OptimizationContext{C<:AbstractContext} <: AbstractContext
end

# assume
function DynamicPPL.tilde(rng, ctx::OptimizationContext, spl, dist, vn::VarName, inds, vi)
return DynamicPPL.tilde(ctx, spl, dist, vn, inds, vi)
function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationContext, spl, dist, vn, inds, vi)
return DynamicPPL.tilde_assume(ctx, spl, dist, vn, inds, vi)
end

function DynamicPPL.tilde(ctx::OptimizationContext{<:LikelihoodContext}, spl, dist, vn::VarName, inds, vi)
function DynamicPPL.tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, spl, dist, vn, inds, vi)
r = vi[vn]
return r, 0
end

function DynamicPPL.tilde(ctx::OptimizationContext, spl, dist, vn::VarName, inds, vi)
function DynamicPPL.tilde_assume(ctx::OptimizationContext, spl, dist, vn, inds, vi)
r = vi[vn]
return r, Distributions.logpdf(dist, r)
end


# observe
function DynamicPPL.tilde(rng, ctx::OptimizationContext, sampler, right, left, vi)
return DynamicPPL.tilde(ctx, sampler, right, left, vi)
function DynamicPPL.tilde_observe(ctx::OptimizationContext, sampler, right, left, vi)
return DynamicPPL.observe(right, left, vi)
end

function DynamicPPL.tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi)
function DynamicPPL.tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi)
return 0
end

function DynamicPPL.tilde(ctx::OptimizationContext, sampler, dist, value, vi)
return Distributions.logpdf(dist, value)
end

# dot assume
function DynamicPPL.dot_tilde(rng, ctx::OptimizationContext, sampler, right, left, vn::VarName, inds, vi)
return DynamicPPL.dot_tilde(ctx, sampler, right, left, vn, inds, vi)
function DynamicPPL.dot_tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationContext, sampler, right, left, vns, inds, vi)
return DynamicPPL.dot_tilde_assume(ctx, sampler, right, left, vns, inds, vi)
end

function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:LikelihoodContext}, sampler, right, left, vn::VarName, _, vi)
vns, dist = get_vns_and_dist(right, left, vn)
r = getval(vi, vns)
function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, sampler::SampleFromPrior, right, left, vns, _, vi)
# Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
# affect anything.
r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler)
return r, 0
end

function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, right, left, vn::VarName, _, vi)
vns, dist = get_vns_and_dist(right, left, vn)
r = getval(vi, vns)
return r, loglikelihood(dist, r)
function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, sampler::SampleFromPrior, right, left, vns, _, vi)
# Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
# affect anything.
r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler)
return r, loglikelihood(right, r)
end

# dot observe
function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vn, _, vi)
function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vn, _, vi)
return 0
end

function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi)
function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi)
return 0
end

function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, right, left, vn, _, vi)
vns, dist = get_vns_and_dist(right, left, vn)
r = getval(vi, vns)
return loglikelihood(dist, r)
function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext, sampler::SampleFromPrior, right, left, vns, _, vi)
# Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
# affect anything.
r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler)
return loglikelihood(right, r)
end

function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, dists, value, vi)
function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext, sampler, dists, value, vi)
return sum(Distributions.logpdf.(dists, value))
end

function getval(
vi,
vns::AbstractVector{<:VarName},
)
r = vi[vns]
return r
end

function getval(
vi,
vns::AbstractArray{<:VarName},
)
r = reshape(vi[vec(vns)], size(vns))
return r
end

"""
OptimLogDensity{M<:Model,C<:Context,V<:VarInfo}

Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ CmdStan = "6.0.8"
Distributions = "0.23.8, 0.24, 0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.11.0"
DynamicPPL = "0.12"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12"
MCMCChains = "4.0.4"
Expand Down