Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
6c39a79
removed unnecessary exports
torfjelde Jun 8, 2021
a158e73
updated OptimizationContext
torfjelde Jun 8, 2021
a2673c5
updated ESS smapler
torfjelde Jun 8, 2021
48b8463
fixed #1633
torfjelde Jun 8, 2021
ca81eb0
fixed bug where ESS didnt support dot_observe
torfjelde Jun 8, 2021
df8bb42
added some additional models to test against
torfjelde Jun 8, 2021
9898865
added test for ESS on the mean-of-mean models
torfjelde Jun 8, 2021
7351562
patch version bump
torfjelde Jun 8, 2021
7bdc4ee
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 8, 2021
20267ce
added tests on mean_of_mean_models for optimization methods too
torfjelde Jun 8, 2021
4a931c1
fixed bug in bijector after recent update to Bijectors.jl
torfjelde Jun 8, 2021
48030eb
use exact value in check_mean_of_mean_models
torfjelde Jun 8, 2021
044da91
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 8, 2021
d3c51d9
fixed bug in OptimizationContext
torfjelde Jun 8, 2021
92cabdf
just use MvNormal instead of TuringDiagMvNormal in test models
torfjelde Jun 8, 2021
06dd7da
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 8, 2021
ce4d8dd
renamed the mean_of_mean models used tests
torfjelde Jun 8, 2021
966b724
renamed the mean_of_mean_models in tests to gdemo_models
torfjelde Jun 8, 2021
544ddc9
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 8, 2021
2cc253b
removed redundant testset block
torfjelde Jun 8, 2021
6f3a102
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 8, 2021
2b5c5e1
upper-bound compat entries for Libtask while we wait for bugfix
torfjelde Jun 10, 2021
84fdceb
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 10, 2021
5912c78
compat entries with hyphens arent supported on Julia v1.3
torfjelde Jun 10, 2021
cab751a
compat entries with hyphens not supported on Julia 1.3
torfjelde Jun 10, 2021
61fd07d
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 10, 2021
20daa3e
also test models with literal observe
torfjelde Jun 10, 2021
687a97d
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 10, 2021
169a014
Update Project.toml
torfjelde Jun 10, 2021
cb9aec5
forgot to bump DPPL version
torfjelde Jun 10, 2021
f761163
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 10, 2021
64a816a
Apply suggestions from code review
torfjelde Jun 10, 2021
f330d21
Merge branch 'tor/fix-1633' into tor/dppl-update
torfjelde Jun 10, 2021
d4065da
Merge branch 'master' into tor/dppl-update
torfjelde Jun 10, 2021
d90ed39
bump DPPL patch version to fix AdvancedPS samplers
torfjelde Jun 10, 2021
46a00a8
bump patch version
torfjelde Jun 10, 2021
0bd5228
updated OptimizationContext to work with the new version of DPPL
torfjelde Aug 13, 2021
d17e6c9
Merge branch 'master' into tor/dppl-update
torfjelde Aug 14, 2021
3c78ff9
Merge branch 'tor/dppl-update' of github.com:TuringLang/Turing.jl int…
torfjelde Aug 14, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.17.0"
version = "0.17.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down Expand Up @@ -44,7 +44,7 @@ DataStructures = "0.18"
Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8"
DynamicPPL = "0.12.1, 0.13"
DynamicPPL = "0.14"
EllipticalSliceSampling = "0.4"
ForwardDiff = "0.10.3"
Libtask = "0.4, 0.5.3"
Expand Down
34 changes: 4 additions & 30 deletions src/modes/ModeEstimation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ struct OptimizationContext{C<:AbstractContext} <: AbstractContext
context::C
end

DynamicPPL.NodeTrait(::OptimizationContext) = DynamicPPL.IsParent()
DynamicPPL.childcontext(context::OptimizationContext) = context.context
DynamicPPL.setchildcontext(::OptimizationContext, child) = OptimizationContext(child)

# assume
function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationContext, spl, dist, vn, inds, vi)
return DynamicPPL.tilde_assume(ctx, spl, dist, vn, inds, vi)
Expand All @@ -43,16 +47,6 @@ function DynamicPPL.tilde_assume(ctx::OptimizationContext, spl, dist, vn, inds,
return r, Distributions.logpdf(dist, r)
end


# observe
function DynamicPPL.tilde_observe(ctx::OptimizationContext, sampler, right, left, vi)
return DynamicPPL.observe(right, left, vi)
end

function DynamicPPL.tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi)
return 0
end

# dot assume
function DynamicPPL.dot_tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationContext, sampler, right, left, vns, inds, vi)
return DynamicPPL.dot_tilde_assume(ctx, sampler, right, left, vns, inds, vi)
Expand All @@ -72,26 +66,6 @@ function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, sampler::SampleFr
return r, loglikelihood(right, r)
end

# dot observe
function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vn, _, vi)
return 0
end

function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi)
return 0
end

function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext, sampler::SampleFromPrior, right, left, vns, _, vi)
# Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
# affect anything.
r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler)
return loglikelihood(right, r)
end

function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext, sampler, dists, value, vi)
return sum(Distributions.logpdf.(dists, value))
end

"""
OptimLogDensity{M<:Model,C<:Context,V<:VarInfo}

Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ CmdStan = "6.0.8"
Distributions = "< 0.25.11"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.12, 0.13"
DynamicPPL = "0.14"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12"
MCMCChains = "4.13.0"
Expand Down