Skip to content

Commit 57e0512

Browse files
torfjeldedevmotion
andauthored
Update to new DPPL version (#1636)
* removed unnecessary exports * updated OptimizationContext * updated ESS smapler * fixed #1633 * fixed bug where ESS didnt support dot_observe * added some additional models to test against * added test for ESS on the mean-of-mean models * patch version bump * added tests on mean_of_mean_models for optimization methods too * fixed bug in bijector after recent update to Bijectors.jl * use exact value in check_mean_of_mean_models * fixed bug in OptimizationContext * just use MvNormal instead of TuringDiagMvNormal in test models * renamed the mean_of_mean models used tests * renamed the mean_of_mean_models in tests to gdemo_models * removed redundant testset block * upper-bound compat entries for Libtask while we wait for bugfix * compat entries with hyphens arent supported on Julia v1.3 * compat entries with hyphens not supported on Julia 1.3 * also test models with literal observe * Update Project.toml Co-authored-by: David Widmann <[email protected]> * forgot to bump DPPL version * Apply suggestions from code review * bump DPPL patch version to fix AdvancedPS samplers * bump patch version Co-authored-by: David Widmann <[email protected]>
1 parent 9f52d75 commit 57e0512

File tree

4 files changed

+42
-58
lines changed

4 files changed

+42
-58
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.16.1"
3+
version = "0.16.2"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -44,7 +44,7 @@ DataStructures = "0.18"
4444
Distributions = "0.23.3, 0.24, 0.25"
4545
DistributionsAD = "0.6"
4646
DocStringExtensions = "0.8"
47-
DynamicPPL = "0.11.0"
47+
DynamicPPL = "0.12.1"
4848
EllipticalSliceSampling = "0.4"
4949
ForwardDiff = "0.10.3"
5050
Libtask = "= 0.4.0, = 0.4.1, = 0.4.2, = 0.5.0, = 0.5.1"

src/inference/ess.jl

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -137,26 +137,27 @@ function (ℓ::ESSLogLikelihood)(f)
137137
return getlogp(varinfo)
138138
end
139139

140-
function DynamicPPL.tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn::VarName, inds, vi)
140+
function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn, inds, vi)
141141
if inspace(vn, sampler)
142-
return DynamicPPL.tilde(rng, LikelihoodContext(), SampleFromPrior(), right, vn, inds, vi)
142+
return DynamicPPL.tilde_assume(rng, LikelihoodContext(), SampleFromPrior(), right, vn, inds, vi)
143143
else
144-
return DynamicPPL.tilde(rng, ctx, SampleFromPrior(), right, vn, inds, vi)
144+
return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, inds, vi)
145145
end
146146
end
147147

148-
function DynamicPPL.tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
149-
return DynamicPPL.tilde(ctx, SampleFromPrior(), right, left, vi)
148+
function DynamicPPL.tilde_observe(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
149+
return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi)
150150
end
151151

152-
function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vn::VarName, inds, vi)
153-
if inspace(vn, sampler)
154-
return DynamicPPL.dot_tilde(rng, LikelihoodContext(), SampleFromPrior(), right, left, vn, inds, vi)
152+
function DynamicPPL.dot_tilde_assume(rng::Random.AbstractRNG, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vns, inds, vi)
153+
# TODO: Or should we do `all(Base.Fix2(inspace, sampler), vns)`?
154+
if inspace(first(vns), sampler)
155+
return DynamicPPL.dot_tilde_assume(rng, LikelihoodContext(), SampleFromPrior(), right, left, vns, inds, vi)
155156
else
156-
return DynamicPPL.dot_tilde(rng, ctx, SampleFromPrior(), right, left, vn, inds, vi)
157+
return DynamicPPL.dot_tilde_assume(rng, ctx, SampleFromPrior(), right, left, vns, inds, vi)
157158
end
158159
end
159160

160-
function DynamicPPL.dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
161-
return DynamicPPL.dot_tilde(ctx, SampleFromPrior(), right, left, vi)
161+
function DynamicPPL.dot_tilde_observe(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi)
162+
return DynamicPPL.dot_tilde_observe(ctx, SampleFromPrior(), right, left, vi)
162163
end

src/modes/ModeEstimation.jl

Lines changed: 27 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import ..AbstractMCMC: AbstractSampler
66
import ..DynamicPPL
77
import ..DynamicPPL: Model, AbstractContext, VarInfo, AbstractContext, VarName,
88
_getindex, getsym, getfield, settrans!, setorder!,
9-
get_and_set_val!, istrans, tilde, dot_tilde, get_vns_and_dist
9+
get_and_set_val!, istrans
1010
import .Optim
1111
import .Optim: optimize
1212
import ..ForwardDiff
@@ -29,86 +29,69 @@ struct OptimizationContext{C<:AbstractContext} <: AbstractContext
2929
end
3030

3131
# assume
32-
function DynamicPPL.tilde(rng, ctx::OptimizationContext, spl, dist, vn::VarName, inds, vi)
33-
return DynamicPPL.tilde(ctx, spl, dist, vn, inds, vi)
32+
function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationContext, spl, dist, vn, inds, vi)
33+
return DynamicPPL.tilde_assume(ctx, spl, dist, vn, inds, vi)
3434
end
3535

36-
function DynamicPPL.tilde(ctx::OptimizationContext{<:LikelihoodContext}, spl, dist, vn::VarName, inds, vi)
36+
function DynamicPPL.tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, spl, dist, vn, inds, vi)
3737
r = vi[vn]
3838
return r, 0
3939
end
4040

41-
function DynamicPPL.tilde(ctx::OptimizationContext, spl, dist, vn::VarName, inds, vi)
41+
function DynamicPPL.tilde_assume(ctx::OptimizationContext, spl, dist, vn, inds, vi)
4242
r = vi[vn]
4343
return r, Distributions.logpdf(dist, r)
4444
end
4545

4646

4747
# observe
48-
function DynamicPPL.tilde(rng, ctx::OptimizationContext, sampler, right, left, vi)
49-
return DynamicPPL.tilde(ctx, sampler, right, left, vi)
48+
function DynamicPPL.tilde_observe(ctx::OptimizationContext, sampler, right, left, vi)
49+
return DynamicPPL.observe(right, left, vi)
5050
end
5151

52-
function DynamicPPL.tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi)
52+
function DynamicPPL.tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi)
5353
return 0
5454
end
5555

56-
function DynamicPPL.tilde(ctx::OptimizationContext, sampler, dist, value, vi)
57-
return Distributions.logpdf(dist, value)
58-
end
59-
6056
# dot assume
61-
function DynamicPPL.dot_tilde(rng, ctx::OptimizationContext, sampler, right, left, vn::VarName, inds, vi)
62-
return DynamicPPL.dot_tilde(ctx, sampler, right, left, vn, inds, vi)
57+
function DynamicPPL.dot_tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationContext, sampler, right, left, vns, inds, vi)
58+
return DynamicPPL.dot_tilde_assume(ctx, sampler, right, left, vns, inds, vi)
6359
end
6460

65-
function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:LikelihoodContext}, sampler, right, left, vn::VarName, _, vi)
66-
vns, dist = get_vns_and_dist(right, left, vn)
67-
r = getval(vi, vns)
61+
function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, sampler::SampleFromPrior, right, left, vns, _, vi)
62+
# Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
63+
# affect anything.
64+
r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler)
6865
return r, 0
6966
end
7067

71-
function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, right, left, vn::VarName, _, vi)
72-
vns, dist = get_vns_and_dist(right, left, vn)
73-
r = getval(vi, vns)
74-
return r, loglikelihood(dist, r)
68+
function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, sampler::SampleFromPrior, right, left, vns, _, vi)
69+
# Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
70+
# affect anything.
71+
r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler)
72+
return r, loglikelihood(right, r)
7573
end
7674

7775
# dot observe
78-
function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vn, _, vi)
76+
function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vn, _, vi)
7977
return 0
8078
end
8179

82-
function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi)
80+
function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi)
8381
return 0
8482
end
8583

86-
function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, right, left, vn, _, vi)
87-
vns, dist = get_vns_and_dist(right, left, vn)
88-
r = getval(vi, vns)
89-
return loglikelihood(dist, r)
84+
function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext, sampler::SampleFromPrior, right, left, vns, _, vi)
85+
# Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
86+
# affect anything.
87+
r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler)
88+
return loglikelihood(right, r)
9089
end
9190

92-
function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, dists, value, vi)
91+
function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext, sampler, dists, value, vi)
9392
return sum(Distributions.logpdf.(dists, value))
9493
end
9594

96-
function getval(
97-
vi,
98-
vns::AbstractVector{<:VarName},
99-
)
100-
r = vi[vns]
101-
return r
102-
end
103-
104-
function getval(
105-
vi,
106-
vns::AbstractArray{<:VarName},
107-
)
108-
r = reshape(vi[vec(vns)], size(vns))
109-
return r
110-
end
111-
11295
"""
11396
OptimLogDensity{M<:Model,C<:Context,V<:VarInfo}
11497

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ CmdStan = "6.0.8"
3737
Distributions = "0.23.8, 0.24, 0.25"
3838
DistributionsAD = "0.6.3"
3939
DynamicHMC = "2.1.6, 3.0"
40-
DynamicPPL = "0.11.0"
40+
DynamicPPL = "0.12"
4141
FiniteDifferences = "0.10.8, 0.11, 0.12"
4242
ForwardDiff = "0.10.12"
4343
MCMCChains = "4.0.4"

0 commit comments

Comments
 (0)