Merged

Changes from all commits (24 commits)
affc71d  initial work on using less of sampler and more of context (torfjelde, Jan 13, 2023)
e5ff14f  fixed optim and ad (torfjelde, Jan 13, 2023)
cf295f1  fixed ESS and MH (torfjelde, Jan 13, 2023)
2024348  bump version (torfjelde, Jan 13, 2023)
4e38fd0  fixed typo (torfjelde, Jan 13, 2023)
c986828  moved all LogDensityProblems related to DPPL (torfjelde, Jan 14, 2023)
67fda5e  Apply suggestions from code review (torfjelde, Jan 14, 2023)
8833092  Merge branch 'master' into torfjelde/less-sampler-more-context (torfjelde, Jan 14, 2023)
d43d6fb  Merge branch 'torfjelde/less-sampler-more-context' of github.com:Turi… (torfjelde, Jan 14, 2023)
093cd40  Apply suggestions from code review (torfjelde, Jan 14, 2023)
260b498  make compat bounds correct (torfjelde, Jan 15, 2023)
3505c58  fixed bug in MH (torfjelde, Jan 15, 2023)
21e922b  fixed MH again (torfjelde, Jan 15, 2023)
65bb1e2  fixed Emcee (torfjelde, Jan 15, 2023)
edf172a  fixed emcee (torfjelde, Jan 16, 2023)
6e5d304  fixed broken tests and removed mentionings of sampler in optimization (torfjelde, Jan 16, 2023)
88d53d7  fixed bug in optim (torfjelde, Jan 16, 2023)
b19f0aa  added tests for demo models (torfjelde, Jan 16, 2023)
06875d5  fixed missing support for certain models in optim (torfjelde, Jan 16, 2023)
6cee95e  fixed type (torfjelde, Jan 16, 2023)
5ec0a0d  fix (torfjelde, Jan 16, 2023)
596b0b2  use LogDensityModel instead of wrapping in Base.Fix (torfjelde, Jan 16, 2023)
9ebdf77  Revert "use LogDensityModel instead of wrapping in Base.Fix" (torfjelde, Jan 16, 2023)
b8542e2  disable failing unsupported models for reverse mode AD frameworks (torfjelde, Jan 17, 2023)
Project.toml (2 changes: 1 addition & 1 deletion)

@@ -47,7 +47,7 @@ DataStructures = "0.18"
 Distributions = "0.23.3, 0.24, 0.25"
 DistributionsAD = "0.6"
 DocStringExtensions = "0.8, 0.9"
-DynamicPPL = "0.21"
+DynamicPPL = "0.21.5"
 EllipticalSliceSampling = "0.5, 1"
 ForwardDiff = "0.10.3"
 Libtask = "0.7, 0.8"
src/Turing.jl (22 changes: 2 additions & 20 deletions)

@@ -8,6 +8,7 @@
 using Libtask
 using Tracker: Tracker

 import AdvancedVI
+using DynamicPPL: DynamicPPL, LogDensityFunction
 import DynamicPPL: getspace, NoDist, NamedDist
 import LogDensityProblems
Review comment (Member): Maybe this can be removed now?

Suggested change:
-import LogDensityProblems

Reply (Author): Still need it for all the optimisation-related stuff.
 import Random

@@ -26,26 +27,6 @@ function setprogress!(progress::Bool)
     return progress
 end

-# Log density function
-struct LogDensityFunction{V,M,S,C}
-    varinfo::V
-    model::M
-    sampler::S
-    context::C
-end
-
-function (f::LogDensityFunction)(θ::AbstractVector)
-    vi_new = DynamicPPL.unflatten(f.varinfo, f.sampler, θ)
-    return getlogp(last(DynamicPPL.evaluate!!(f.model, vi_new, f.sampler, f.context)))
-end
-
-# LogDensityProblems interface
-LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) = f(θ)
-LogDensityProblems.dimension(f::LogDensityFunction) = length(f.varinfo[f.sampler])
-function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
-    return LogDensityProblems.LogDensityOrder{0}()
-end
-
 # Standard tag: Improves stacktraces
 # Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
 struct TuringTag end

@@ -154,6 +135,7 @@ export @model, # modelling
     generated_quantities,
     logprior,
     logjoint,
+    LogDensityFunction,

     constrained_space, # optimisation interface
     MAP,
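Note: since the struct and its LogDensityProblems methods now live in DynamicPPL, downstream usage is unchanged. A minimal sketch, assuming DynamicPPL 0.21.5 provides the one-argument convenience constructor and the model name `demo` as a placeholder:

    using Turing, LogDensityProblems

    @model function demo()
        x ~ Normal()
    end

    # `LogDensityFunction` is now DynamicPPL's type, re-exported by Turing.
    f = LogDensityFunction(demo())

    LogDensityProblems.dimension(f)          # 1
    LogDensityProblems.logdensity(f, [0.25]) # logpdf(Normal(), 0.25)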
src/essential/ad.jl (9 changes: 7 additions & 2 deletions)

@@ -77,13 +77,18 @@
 Find the autodifferentiation backend of the algorithm `alg`.
 """
 getADbackend(spl::Sampler) = getADbackend(spl.alg)
 getADbackend(::SampleFromPrior) = ADBackend()()
Review comment (Member, on lines 78 to 79): Are these still needed? Or could we remove them and specialize getADbackend(::SamplingContext{<:Sampler}) and getADbackend(::SamplingContext{SampleFromPrior})? I think mentally it's easier if getADbackend only operates on contexts.

Reply (Author): Yeah, I can try doing that! I was worried this would break too many tests and discourage me from completing the PR, but now that the tests are passing, I can see if it's worth it.

+getADbackend(ctx::DynamicPPL.SamplingContext) = getADbackend(ctx.sampler)
+getADbackend(ctx::DynamicPPL.AbstractContext) = getADbackend(DynamicPPL.NodeTrait(ctx), ctx)
+
+getADbackend(::DynamicPPL.IsLeaf, ctx::DynamicPPL.AbstractContext) = ADBackend()()
+getADbackend(::DynamicPPL.IsParent, ctx::DynamicPPL.AbstractContext) = getADbackend(DynamicPPL.childcontext(ctx))

Review comment (Member): Most of the code in this ad file could safely be moved to DynamicPPL too. One benefit is that users of DynamicPPL would get the gradient feature of LogDensityFunction without loading the Turing package.

Reply (Author): Yup, though I think I'll leave it here for the time being just because I want to keep the PRs simple.

 function LogDensityProblemsAD.ADgradient(ℓ::Turing.LogDensityFunction)
-    return LogDensityProblemsAD.ADgradient(getADbackend(ℓ.sampler), ℓ)
+    return LogDensityProblemsAD.ADgradient(getADbackend(ℓ.context), ℓ)
 end

 function LogDensityProblemsAD.ADgradient(ad::ForwardDiffAD, ℓ::Turing.LogDensityFunction)
-    θ = ℓ.varinfo[ℓ.sampler]
+    θ = DynamicPPL.getparams(ℓ)
     f = Base.Fix1(LogDensityProblems.logdensity, ℓ)

     # Define configuration for ForwardDiff.
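Note: the backend lookup now walks the context tree: a SamplingContext yields its sampler's backend, a parent context recurses into its child, and a leaf falls back to the default. A rough sketch of the dispatch, assuming the one-argument SamplingContext convenience constructor and `spl` as a stand-in for any Turing sampler:

    # Leaf context, no sampler anywhere: fall back to the default backend.
    getADbackend(DynamicPPL.DefaultContext())      # ADBackend()()

    # Sampling context: delegate to the wrapped sampler's backend.
    getADbackend(DynamicPPL.SamplingContext(spl))  # getADbackend(spl.alg)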
src/inference/emcee.jl (4 changes: 3 additions & 1 deletion)

@@ -74,7 +74,9 @@ function AbstractMCMC.step(
 )
     # Generate a log joint function.
     vi = state.vi
-    densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, SampleFromPrior(), DynamicPPL.DefaultContext()))
+    densitymodel = AMH.DensityModel(
+        Base.Fix1(LogDensityProblems.logdensity, Turing.LogDensityFunction(model, vi))
+    )

     # Compute the next states.
     states = last(AbstractMCMC.step(rng, densitymodel, spl.alg.ensemble, state.states))
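Note: AMH.DensityModel expects a plain callable, while the new LogDensityFunction is evaluated through LogDensityProblems.logdensity; Base.Fix1 bridges the two by fixing the first argument:

    # Base.Fix1(f, x) is a callable behaving like θ -> f(x, θ), so:
    g = Base.Fix1(LogDensityProblems.logdensity, ℓ)
    g(θ) == LogDensityProblems.logdensity(ℓ, θ)   # true for any θ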
src/inference/ess.jl (4 changes: 2 additions & 2 deletions)

@@ -124,10 +124,10 @@
 Distributions.mean(p::ESSPrior) = p.μ

 # Evaluate log-likelihood of proposals
-const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext()}
+const ESSLogLikelihood{M<:Model,S<:Sampler{<:ESS},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,<:DynamicPPL.SamplingContext{<:S}}

 function (ℓ::ESSLogLikelihood)(f::AbstractVector)
-    sampler = ℓ.sampler
+    sampler = DynamicPPL.getsampler(ℓ)
     varinfo = setindex!!(ℓ.varinfo, f, sampler)
     varinfo = last(DynamicPPL.evaluate!!(ℓ.model, varinfo, sampler))
     return getlogp(varinfo)
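Note: with the sampler no longer stored on the struct, the alias pins the context type parameter instead, and the sampler is recovered from the wrapped SamplingContext. A sketch of the relationship, assuming the field order (varinfo, model, context):

    ℓ = Turing.LogDensityFunction(vi, model, DynamicPPL.SamplingContext(rng, spl))
    ℓ isa ESSLogLikelihood             # true whenever spl isa Sampler{<:ESS}
    DynamicPPL.getsampler(ℓ) === spl   # recovered from ℓ.context, not a field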
src/inference/mh.jl (14 changes: 9 additions & 5 deletions)

@@ -246,11 +246,11 @@
 A log density function for the MH sampler.

 This variant uses the `set_namedtuple!` function to update the `VarInfo`.
 """
-const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,S,DynamicPPL.DefaultContext}
+const MHLogDensityFunction{M<:Model,S<:Sampler{<:MH},V<:AbstractVarInfo} = Turing.LogDensityFunction{V,M,<:DynamicPPL.SamplingContext{<:S}}

-function (f::MHLogDensityFunction)(x::NamedTuple)
+function LogDensityProblems.logdensity(f::MHLogDensityFunction, x::NamedTuple)
     # TODO: Make this work with immutable `f.varinfo` too.
-    sampler = f.sampler
+    sampler = DynamicPPL.getsampler(f)
     vi = f.varinfo

     x_old, lj_old = vi[sampler], getlogp(vi)
@@ -374,7 +374,9 @@ function propose!!(
     prev_trans = AMH.Transition(vt, getlogp(vi))

     # Make a new transition.
-    densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
+    densitymodel = AMH.DensityModel(
+        Base.Fix1(LogDensityProblems.logdensity, Turing.LogDensityFunction(vi, model, DynamicPPL.SamplingContext(rng, spl)))
+    )
     trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)

     # TODO: Make this compatible with immutable `VarInfo`.

Review comment (Member): Isn't it easier to use a LogDensityModel? Then you don't need Fix1.

Reply (Author): Oh, true!

Reply (Author): Wait, does that work, though? AFAIK we don't support wrapping something that implements the LogDensityProblems interface in AMH.DensityModel?

Reply (Member): We don't want to wrap it in a DensityModel, do we? step seems to support LogDensityModels: https://github.com/TuringLang/AdvancedMH.jl/blob/1638f068261f936141404e232dd7f099b2cdde95/src/mh-core.jl#L74-L113

Reply (Author, Jan 16, 2023): This didn't work, unfortunately. It seems we're hitting https://github.com/TuringLang/AdvancedMH.jl/blob/1638f068261f936141404e232dd7f099b2cdde95/src/AdvancedMH.jl#L69 instead of https://github.com/TuringLang/AdvancedMH.jl/blob/1638f068261f936141404e232dd7f099b2cdde95/src/AdvancedMH.jl#L72. Also, do you know the reasoning behind defining one for both DensityOrLogDensityModel and LogDensityModel? Why isn't it DensityModel and LogDensityModel separately?

EDIT: Looking at git-blame, it seems like I tried to do a replace and was a bit too aggressive 😅
@@ -400,7 +402,9 @@ function propose!!(
     prev_trans = AMH.Transition(vals, getlogp(vi))

     # Make a new transition.
-    densitymodel = AMH.DensityModel(Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext()))
+    densitymodel = AMH.DensityModel(
+        Base.Fix1(LogDensityProblems.logdensity, Turing.LogDensityFunction(vi, model, DynamicPPL.SamplingContext(rng, spl)))
+    )
     trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)

     return setlogp!!(DynamicPPL.unflatten(vi, spl, trans.params), trans.lp)
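Note: for reference, the alternative floated in the thread above would have looked roughly like this, assuming AdvancedMH's step accepted an AbstractMCMC.LogDensityModel directly (which, per the discussion, it did not at the time):

    # Hypothetical alternative from the review thread, not what was merged.
    ℓ = Turing.LogDensityFunction(vi, model, DynamicPPL.SamplingContext(rng, spl))
    densitymodel = AbstractMCMC.LogDensityModel(ℓ)   # no Base.Fix1 needed
    trans, _ = AbstractMCMC.step(rng, densitymodel, mh_sampler, prev_trans)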
src/modes/ModeEstimation.jl (95 changes: 44 additions & 51 deletions)

@@ -46,45 +46,42 @@
 DynamicPPL.childcontext(context::OptimizationContext) = context.context
 DynamicPPL.setchildcontext(::OptimizationContext, child) = OptimizationContext(child)

 # assume
-function DynamicPPL.tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationContext, spl, dist, vn, vi)
-    return DynamicPPL.tilde_assume(ctx, spl, dist, vn, vi)
-end
-
-function DynamicPPL.tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, spl, dist, vn, vi)
-    r = vi[vn]
+function DynamicPPL.tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, dist, vn, vi)
+    r = vi[vn, dist]
     return r, 0, vi
 end

-function DynamicPPL.tilde_assume(ctx::OptimizationContext, spl, dist, vn, vi)
-    r = vi[vn]
+function DynamicPPL.tilde_assume(ctx::OptimizationContext, dist, vn, vi)
+    r = vi[vn, dist]
     return r, Distributions.logpdf(dist, r), vi
 end

 # dot assume
-function DynamicPPL.dot_tilde_assume(rng::Random.AbstractRNG, ctx::OptimizationContext, sampler, right, left, vns, vi)
-    return DynamicPPL.dot_tilde_assume(ctx, sampler, right, left, vns, vi)
-end
-
-function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, sampler::SampleFromPrior, right, left, vns, vi)
+function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext{<:LikelihoodContext}, right, left, vns, vi)
     # Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
     # affect anything.
-    r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler)
+    # TODO: Stop using `get_and_set_val!`.
+    r = DynamicPPL.get_and_set_val!(Random.default_rng(), vi, vns, right, SampleFromPrior())
     return r, 0, vi
 end

-function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, sampler::SampleFromPrior, right, left, vns, vi)
+_loglikelihood(dist::Distribution, x) = loglikelihood(dist, x)
+_loglikelihood(dists::AbstractArray{<:Distribution}, x) = loglikelihood(arraydist(dists), x)
+
+function DynamicPPL.dot_tilde_assume(ctx::OptimizationContext, right, left, vns, vi)
     # Values should be set and we're using `SampleFromPrior`, hence the `rng` argument shouldn't
     # affect anything.
-    r = DynamicPPL.get_and_set_val!(Random.GLOBAL_RNG, vi, vns, right, sampler)
-    return r, loglikelihood(right, r), vi
+    # TODO: Stop using `get_and_set_val!`.
+    r = DynamicPPL.get_and_set_val!(Random.default_rng(), vi, vns, right, SampleFromPrior())
+    return r, _loglikelihood(right, r), vi
 end
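Note: these method pairs are what make mode estimation context-driven: under OptimizationContext{<:LikelihoodContext} the prior term contributes 0 (MLE), otherwise it contributes its logpdf (MAP). A hedged end-to-end sketch, assuming Turing's Optim.jl integration and a toy model named `coin`:

    using Turing, Optim

    @model function coin(y)
        p ~ Beta(2, 2)
        y .~ Bernoulli(p)
    end

    m = coin([1, 1, 0, 1])   # toy data: 3 heads, 1 tail
    optimize(m, MLE())       # p̂ ≈ 0.75: prior term contributes 0
    optimize(m, MAP())       # p̂ ≈ 0.667: mode of the Beta(5, 3) posterior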

"""
OptimLogDensity{M<:Model,C<:Context,V<:VarInfo}

A struct that stores the negative log density function of a `DynamicPPL` model.
"""
const OptimLogDensity{M<:Model,C<:OptimizationContext,V<:VarInfo} = Turing.LogDensityFunction{V,M,DynamicPPL.SampleFromPrior,C}
const OptimLogDensity{M<:Model,C<:OptimizationContext,V<:VarInfo} = Turing.LogDensityFunction{V,M,C}

"""
OptimLogDensity(model::Model, context::OptimizationContext)
Expand All @@ -93,21 +90,23 @@ Create a callable `OptimLogDensity` struct that evaluates a model using the give
"""
function OptimLogDensity(model::Model, context::OptimizationContext)
init = VarInfo(model)
return Turing.LogDensityFunction(init, model, DynamicPPL.SampleFromPrior(), context)
return Turing.LogDensityFunction(init, model, context)
end

"""
(f::OptimLogDensity)(z)
LogDensityProblems.logdensity(f::OptimLogDensity, z)

Evaluate the negative log joint (with `DefaultContext`) or log likelihood (with `LikelihoodContext`)
at the array `z`.
"""
function (f::OptimLogDensity)(z::AbstractVector)
sampler = f.sampler
varinfo = DynamicPPL.unflatten(f.varinfo, sampler, z)
return -getlogp(last(DynamicPPL.evaluate!!(f.model, varinfo, sampler, f.context)))
varinfo = DynamicPPL.unflatten(f.varinfo, z)
return -getlogp(last(DynamicPPL.evaluate!!(f.model, varinfo, f.context)))
end

# NOTE: This seems a bit weird IMO since this is the _negative_ log-likelihood.
LogDensityProblems.logdensity(f::OptimLogDensity, z::AbstractVector) = f(z)

function (f::OptimLogDensity)(F, G, z)
if G !== nothing
# Calculate negative log joint and its gradient.
Expand All @@ -127,7 +126,7 @@ function (f::OptimLogDensity)(F, G, z)

# Only negative log joint requested but no gradient.
if F !== nothing
return f(z)
return LogDensityProblems.logdensity(f, z)
end

return nothing
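Note: the (F, G, z) method exists to satisfy Optim.jl's fg! protocol, letting one call serve both the objective and its gradient. A sketch of the intended consumption; the wiring here (the `model`, `z0` names and the context choice) is an assumption, not code from this PR:

    using Optim

    # DefaultContext as the wrapped context gives MAP-style behaviour.
    f = OptimLogDensity(model, OptimizationContext(DynamicPPL.DefaultContext()))
    z0 = zeros(length(f.varinfo[:]))   # hypothetical starting point in flat space
    result = Optim.optimize(Optim.only_fg!(f), z0, Optim.LBFGS())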
@@ -140,50 +139,44 @@
 #################################################

 function transform!!(f::OptimLogDensity)
-    spl = f.sampler
-
     ## Check link status of vi in OptimLogDensity
-    linked = DynamicPPL.islinked(f.varinfo, spl)
+    linked = DynamicPPL.istrans(f.varinfo)

     ## transform into constrained or unconstrained space depending on current state of vi
     @set! f.varinfo = if !linked
-        DynamicPPL.link!!(f.varinfo, spl, f.model)
+        DynamicPPL.link!!(f.varinfo, f.model)
     else
-        DynamicPPL.invlink!!(f.varinfo, spl, f.model)
+        DynamicPPL.invlink!!(f.varinfo, f.model)
     end

     return f
 end

 function transform!!(p::AbstractArray, vi::DynamicPPL.VarInfo, model::DynamicPPL.Model, ::constrained_space{true})
-    spl = DynamicPPL.SampleFromPrior()
-
-    linked = DynamicPPL.islinked(vi, spl)
+    linked = DynamicPPL.istrans(vi)

     !linked && return identity(p) # TODO: why do we do `identity` here?
-    vi = DynamicPPL.setindex!!(vi, p, spl)
-    vi = DynamicPPL.invlink!!(vi, spl, model)
-    p .= vi[spl]
+    vi = DynamicPPL.unflatten(vi, p)
+    vi = DynamicPPL.invlink!!(vi, model)
+    p .= vi[:]

     # If linking mutated, we need to link once more.
-    linked && DynamicPPL.link!!(vi, spl, model)
+    linked && DynamicPPL.link!!(vi, model)

     return p
 end

 function transform!!(p::AbstractArray, vi::DynamicPPL.VarInfo, model::DynamicPPL.Model, ::constrained_space{false})
-    spl = DynamicPPL.SampleFromPrior()
-
-    linked = DynamicPPL.islinked(vi, spl)
+    linked = DynamicPPL.istrans(vi)
     if linked
-        vi = DynamicPPL.invlink!!(vi, spl, model)
+        vi = DynamicPPL.invlink!!(vi, model)
     end
-    vi = DynamicPPL.setindex!!(vi, p, spl)
-    vi = DynamicPPL.link!!(vi, spl, model)
-    p .= vi[spl]
+    vi = DynamicPPL.unflatten(vi, p)
+    vi = DynamicPPL.link!!(vi, model)
+    p .= vi[:]

     # If linking mutated, we need to link once more.
-    !linked && DynamicPPL.invlink!!(vi, spl, model)
+    !linked && DynamicPPL.invlink!!(vi, model)

     return p
 end
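Note: a sketch of the istrans/link!!/unflatten round-trip these rewrites rely on, using only calls that appear in this diff (`model` stands in for any DynamicPPL model):

    vi = DynamicPPL.VarInfo(model)
    DynamicPPL.istrans(vi)                # false: constrained (model) space
    vi = DynamicPPL.link!!(vi, model)     # map parameters to unconstrained space
    DynamicPPL.istrans(vi)                # true
    p = vi[:]                             # flattened unconstrained parameters
    vi = DynamicPPL.unflatten(vi, p)      # write a flat vector back into vi
    vi = DynamicPPL.invlink!!(vi, model)  # back to constrained space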
@@ -208,26 +201,26 @@
 function (t::AbstractTransform)(p::AbstractArray)
     return transform(p, t.vi, t.model, t.space)
 end

 function (t::Init)()
     return t.vi[DynamicPPL.SampleFromPrior()]
 end

 function get_parameter_bounds(model::DynamicPPL.Model)
     vi = DynamicPPL.VarInfo(model)
-    spl = DynamicPPL.SampleFromPrior()
-
     ## Check link status of vi
-    linked = DynamicPPL.islinked(vi, spl)
+    linked = DynamicPPL.istrans(vi)

     ## transform into unconstrained
     if !linked
-        vi = DynamicPPL.link!!(vi, spl, model)
+        vi = DynamicPPL.link!!(vi, model)
     end

-    lb = transform(fill(-Inf,length(vi[DynamicPPL.SampleFromPrior()])), vi, model, constrained_space{true}())
-    ub = transform(fill(Inf,length(vi[DynamicPPL.SampleFromPrior()])), vi, model, constrained_space{true}())
+    d = length(vi[:])
+    lb = transform(fill(-Inf, d), vi, model, constrained_space{true}())
+    ub = transform(fill(Inf, d), vi, model, constrained_space{true}())

    return lb, ub
 end
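Note: a sketch of how these bounds are meant to be consumed, e.g. by a box-constrained optimiser; the hand-off below (names `f` and `x0`, the Fminbox choice) is an assumption for illustration:

    lb, ub = get_parameter_bounds(model)

    # f: an OptimLogDensity; x0: a feasible starting point within [lb, ub].
    result = Optim.optimize(f, lb, ub, x0, Optim.Fminbox(Optim.LBFGS()))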