diff --git a/Project.toml b/Project.toml index d678f4e9cc..b950ff8ff7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.16.1" +version = "0.16.2" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -44,7 +44,7 @@ DataStructures = "0.18" Distributions = "0.23.3, 0.24, 0.25" DistributionsAD = "0.6" DocStringExtensions = "0.8" -DynamicPPL = "0.11.0" +DynamicPPL = "0.12.1" EllipticalSliceSampling = "0.4" ForwardDiff = "0.10.3" Libtask = "= 0.4.0, = 0.4.1, = 0.4.2, = 0.5.0, = 0.5.1" diff --git a/src/inference/ess.jl b/src/inference/ess.jl index eeadbc40d2..38b4d22266 100644 --- a/src/inference/ess.jl +++ b/src/inference/ess.jl @@ -137,26 +137,27 @@ function (ℓ::ESSLogLikelihood)(f) return getlogp(varinfo) end -function DynamicPPL.tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn::VarName, inds, vi) +function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, vn, inds, vi) if inspace(vn, sampler) - return DynamicPPL.tilde(rng, LikelihoodContext(), SampleFromPrior(), right, vn, inds, vi) + return DynamicPPL.tilde_assume(rng, LikelihoodContext(), SampleFromPrior(), right, vn, inds, vi) else - return DynamicPPL.tilde(rng, ctx, SampleFromPrior(), right, vn, inds, vi) + return DynamicPPL.tilde_assume(rng, ctx, SampleFromPrior(), right, vn, inds, vi) end end -function DynamicPPL.tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) - return DynamicPPL.tilde(ctx, SampleFromPrior(), right, left, vi) +function DynamicPPL.tilde_observe(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) + return DynamicPPL.tilde_observe(ctx, SampleFromPrior(), right, left, vi) end -function DynamicPPL.dot_tilde(rng, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vn::VarName, inds, vi) - if inspace(vn, sampler) - return DynamicPPL.dot_tilde(rng, LikelihoodContext(), SampleFromPrior(), right, left, vn, inds, vi) +function DynamicPPL.dot_tilde_assume(rng::Random.AbstractRNG, ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vns, inds, vi) + # TODO: Or should we do `all(Base.Fix2(inspace, sampler), vns)`? + if inspace(first(vns), sampler) + return DynamicPPL.dot_tilde_assume(rng, LikelihoodContext(), SampleFromPrior(), right, left, vns, inds, vi) else - return DynamicPPL.dot_tilde(rng, ctx, SampleFromPrior(), right, left, vn, inds, vi) + return DynamicPPL.dot_tilde_assume(rng, ctx, SampleFromPrior(), right, left, vns, inds, vi) end end -function DynamicPPL.dot_tilde(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) - return DynamicPPL.dot_tilde(ctx, SampleFromPrior(), right, left, vi) +function DynamicPPL.dot_tilde_observe(ctx::DefaultContext, sampler::Sampler{<:ESS}, right, left, vi) + return DynamicPPL.dot_tilde_observe(ctx, SampleFromPrior(), right, left, vi) end diff --git a/src/modes/ModeEstimation.jl b/src/modes/ModeEstimation.jl index a4626b9667..92564770c1 100644 --- a/src/modes/ModeEstimation.jl +++ b/src/modes/ModeEstimation.jl @@ -6,7 +6,7 @@ import ..AbstractMCMC: AbstractSampler import ..DynamicPPL import ..DynamicPPL: Model, AbstractContext, VarInfo, AbstractContext, VarName, _getindex, getsym, getfield, settrans!, setorder!, - get_and_set_val!, istrans, tilde, dot_tilde, get_vns_and_dist + get_and_set_val!, istrans import .Optim import .Optim: optimize import ..ForwardDiff @@ -29,86 +29,69 @@ struct OptimizationContext{C<:AbstractContext} <: AbstractContext end # assume -function DynamicPPL.tilde(rng, ctx::OptimizationContext, spl, dist, vn::VarName, inds, vi) - return DynamicPPL.tilde(ctx, spl, dist, vn, inds, vi) +function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationContext, spl, dist, vn, inds, vi) + return DynamicPPL.tilde_assume(ctx, spl, dist, vn, inds, vi) end -function DynamicPPL.tilde(ctx::OptimizationContext{<:LikelihoodContext}, spl, dist, vn::VarName, inds, vi) +function DynamicPPL.tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, spl, dist, vn, inds, vi) r = vi[vn] return r, 0 end -function DynamicPPL.tilde(ctx::OptimizationContext, spl, dist, vn::VarName, inds, vi) +function DynamicPPL.tilde_assume(ctx::OptimizationContext, spl, dist, vn, inds, vi) r = vi[vn] return r, Distributions.logpdf(dist, r) end # observe -function DynamicPPL.tilde(rng, ctx::OptimizationContext, sampler, right, left, vi) - return DynamicPPL.tilde(ctx, sampler, right, left, vi) +function DynamicPPL.tilde_observe(ctx::OptimizationContext, sampler, right, left, vi) + return DynamicPPL.observe(right, left, vi) end -function DynamicPPL.tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi) +function DynamicPPL.tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi) return 0 end -function DynamicPPL.tilde(ctx::OptimizationContext, sampler, dist, value, vi) - return Distributions.logpdf(dist, value) -end - # dot assume -function DynamicPPL.dot_tilde(rng, ctx::OptimizationContext, sampler, right, left, vn::VarName, inds, vi) - return DynamicPPL.dot_tilde(ctx, sampler, right, left, vn, inds, vi) +function DynamicPPL.dot_tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationContext, sampler, right, left, vns, inds, vi) + return DynamicPPL.dot_tilde_assume(ctx, sampler, right, left, vns, inds, vi) end -function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:LikelihoodContext}, sampler, right, left, vn::VarName, _, vi) - vns, dist = get_vns_and_dist(right, left, vn) - r = getval(vi, vns) +function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, sampler::SampleFromPrior, right, left, vns, _, vi) + # Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't + # affect anything. + r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler) return r, 0 end -function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, right, left, vn::VarName, _, vi) - vns, dist = get_vns_and_dist(right, left, vn) - r = getval(vi, vns) - return r, loglikelihood(dist, r) +function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, sampler::SampleFromPrior, right, left, vns, _, vi) + # Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't + # affect anything. + r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler) + return r, loglikelihood(right, r) end # dot observe -function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vn, _, vi) +function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vn, _, vi) return 0 end -function DynamicPPL.dot_tilde(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi) +function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext{<:PriorContext}, sampler, right, left, vi) return 0 end -function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, right, left, vn, _, vi) - vns, dist = get_vns_and_dist(right, left, vn) - r = getval(vi, vns) - return loglikelihood(dist, r) +function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext, sampler::SampleFromPrior, right, left, vns, _, vi) + # Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't + # affect anything. + r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler) + return loglikelihood(right, r) end -function DynamicPPL.dot_tilde(ctx::OptimizationContext, sampler, dists, value, vi) +function DynamicPPL.dot_tilde_observe(ctx::OptimizationContext, sampler, dists, value, vi) return sum(Distributions.logpdf.(dists, value)) end -function getval( - vi, - vns::AbstractVector{<:VarName}, -) - r = vi[vns] - return r -end - -function getval( - vi, - vns::AbstractArray{<:VarName}, -) - r = reshape(vi[vec(vns)], size(vns)) - return r -end - """ OptimLogDensity{M<:Model,C<:Context,V<:VarInfo} diff --git a/test/Project.toml b/test/Project.toml index c1e3fe044b..afdee32170 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -37,7 +37,7 @@ CmdStan = "6.0.8" Distributions = "0.23.8, 0.24, 0.25" DistributionsAD = "0.6.3" DynamicHMC = "2.1.6, 3.0" -DynamicPPL = "0.11.0" +DynamicPPL = "0.12" FiniteDifferences = "0.10.8, 0.11, 0.12" ForwardDiff = "0.10.12" MCMCChains = "4.0.4"