Skip to content

[WIP] Move from implicitly mapped measures and kernels to data tagged as mapped #155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion ext/MeasureBaseChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,16 @@ ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checke

# = return type inference ====================================================

using MeasureBase: logdensityof_rt
using MeasureBase: logdensityof_rt, strict_logdensityof_rt

_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent())
function ChainRulesCore.rrule(::typeof(logdensityof_rt), target, v)
logdensityof_rt(target, v), _logdensityof_rt_pullback
end

_strict_logdensityof_rt_pullback(::Any) = (NoTangent(), NoTangent(), ZeroTangent())
function ChainRulesCore.rrule(::typeof(strict_logdensityof_rt), target, v)
strict_logdensityof_rt(target, v), _strict_logdensityof_rt_pullback
end

end # module MeasureBaseChainRulesCoreExt
4 changes: 2 additions & 2 deletions src/combinators/half.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ function Base.rand(rng::AbstractRNG, ::Type{T}, μ::Half) where {T}
return abs(rand(rng, T, unhalf(μ)))
end

function logdensityof(μ::Half, x)
ld = logdensityof(unhalf(μ), x) - loghalf
function strict_logdensityof(μ::Half, x)
ld = strict_logdensityof(unhalf(μ), x) - loghalf
Comment on lines +22 to +23
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to have another look at this - should it just be logdensity_def?

If we need to merge before this is resolved, let's add a #TODO comment

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this is part of specializing logdensityof (e242314) to avoid the multi-step logdensity_def machinery if users don't need logdensity_rel (until we revamp that machinery to make it more type-stable and Zygote-friendly).

return x ≥ 0 ? ld : oftype(ld, -Inf)
end

Expand Down
11 changes: 7 additions & 4 deletions src/combinators/likelihood.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
export AbstractLikelihood, Likelihood

abstract type AbstractLikelihood end
(lik::AbstractLikelihood)(p) = exp(ULogarithmic, logdensityof(lik.k(p), lik.x))

DensityInterface.DensityKind(::AbstractLikelihood) = IsDensity()

Base.:∘(::typeof(log), lik::AbstractLikelihood) = logdensityof(lik)

# @inline function logdensityof(ℓ::AbstractLikelihood, p)
# t() = dynamic(unsafe_logdensityof(ℓ, p))
Expand All @@ -11,6 +16,7 @@ abstract type AbstractLikelihood end
# insupport(ℓ::AbstractLikelihood, p) = insupport(ℓ.k(p), ℓ.x)

@doc raw"""
Likelihood(k::Base.Callable, x)
Likelihood(k::AbstractTransitionKernel, x)

"Observe" a value `x`, yielding a function from the parameters to ℝ.
Expand Down Expand Up @@ -117,14 +123,11 @@ struct Likelihood{K,X} <: AbstractLikelihood
x::X

Likelihood(k::K, x::X) where {K<:AbstractTransitionKernel,X} = new{K,X}(k, x)
Likelihood(::Type{K}, x::X) where {K,X} = new{Type{K},X}(K, x)
Likelihood(k::K, x::X) where {K<:Function,X} = new{K,X}(k, x)
Likelihood(μ, x) = Likelihood(kernel(μ), x)
end

(lik::AbstractLikelihood)(p) = exp(ULogarithmic, logdensityof(lik.k(p), lik.x))

DensityInterface.DensityKind(::AbstractLikelihood) = IsDensity()

function Pretty.quoteof(ℓ::Likelihood)
k = Pretty.quoteof(ℓ.k)
x = Pretty.quoteof(ℓ.x)
Expand Down
2 changes: 1 addition & 1 deletion src/combinators/power.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ params(d::PowerMeasure) = params(first(marginals(d)))
basemeasure(d.parent)^d.axes
end

for func in [:logdensityof, :logdensity_def]
for func in [:strict_logdensityof, :logdensity_def]
@eval @inline function $func(d::PowerMeasure{M}, x) where {M}
parent = d.parent
sum(x) do xj
Expand Down
6 changes: 3 additions & 3 deletions src/combinators/product.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function _rand_product(
end |> collect
end

for func in [:logdensityof, :logdensity_def]
for func in [:strict_logdensityof, :logdensity_def]
@eval @inline function $func(d::AbstractProductMeasure, x)
mapreduce($func, +, marginals(d), x)
end
Expand All @@ -82,7 +82,7 @@ struct ProductMeasure{M} <: AbstractProductMeasure
marginals::M
end

@inline function logdensity_rel(μ::ProductMeasure, ν::ProductMeasure, x)
@inline function strict_logdensity_rel(μ::ProductMeasure, ν::ProductMeasure, x)
mapreduce(logdensity_rel, +, marginals(μ), marginals(ν), x)
end

Expand All @@ -109,7 +109,7 @@ end
return q
end

for func in [:logdensityof, :logdensity_def]
for func in [:strict_logdensityof, :logdensity_def]
# For tuples, `mapreduce` has trouble with type inference
@eval @inline function $func(d::ProductMeasure{T}, x) where {T<:Tuple}
ℓs = map($func, marginals(d), x)
Expand Down
2 changes: 1 addition & 1 deletion src/combinators/spikemixture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ end
SpikeMixture(basemeasure(μ.m), static(1.0), static(1.0))
end

for func in [:logdensityof, :logdensity_def]
for func in [:strict_logdensityof, :logdensity_def]
@eval @inline function $func(μ::SpikeMixture, x)
# NOTE: We could instead write this as
# R1 = typeof(log(one(μ.s)))
Expand Down
67 changes: 51 additions & 16 deletions src/combinators/transformedmeasure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@ function Pretty.tile(ν::PushforwardMeasure)
end

# TODO: THIS IS ALMOST CERTAINLY WRONG
# @inline function logdensity_rel(
# @inline function strict_logdensity_rel(
# ν::PushforwardMeasure{FF1,IF1,M1,<:AdaptRootMeasure},
# β::PushforwardMeasure{FF2,IF2,M2,<:AdaptRootMeasure},
# y,
# ) where {FF1,IF1,M1,FF2,IF2,M2}
# x = β.inv_f(y)
# f = ν.inv_f ∘ β.f
# inv_f = β.inv_f ∘ ν.f
# logdensity_rel(pushfwd(f, inv_f, ν.origin, AdaptRootMeasure()), β.origin, x)
# strict_logdensity_rel(pushfwd(f, inv_f, ν.origin, AdaptRootMeasure()), β.origin, x)
# end

# TODO: Would profit from custom pullback:
Expand All @@ -132,7 +132,7 @@ function _combine_logd_with_ladj(logd_orig::Real, ladj::Real)
end
end

function logdensityof(
function strict_logdensityof(
@nospecialize(μ::_NonBijectivePusfwdMeasure{M,<:PushfwdRootMeasure}),
@nospecialize(v::Any)
) where {M}
Expand All @@ -143,7 +143,7 @@ function logdensityof(
)
end

function logdensityof(
function strict_logdensityof(
@nospecialize(μ::_NonBijectivePusfwdMeasure{M,<:AdaptRootMeasure}),
@nospecialize(v::Any)
) where {M}
Expand All @@ -154,7 +154,7 @@ function logdensityof(
)
end

for func in [:logdensityof, :logdensity_def]
for func in [:strict_logdensityof, :logdensity_def]
@eval function $func(ν::PushforwardMeasure{F,I,M,<:AdaptRootMeasure}, y) where {F,I,M}
f_inv = unwrap(ν.finv)
x, inv_ladj = with_logabsdet_jacobian(f_inv, y)
Expand Down Expand Up @@ -222,25 +222,52 @@ To manually specify an inverse, call
function pushfwd end
export pushfwd

@inline pushfwd(f, μ) = _pushfwd_impl(f, μ, AdaptRootMeasure())
@inline pushfwd(f, μ, style::AdaptRootMeasure) = _pushfwd_impl(f, μ, style)
@inline pushfwd(f, μ, style::PushfwdRootMeasure) = _pushfwd_impl(f, μ, style)
@inline pushfwd(f, μ) = _pushfwd_impl1(f, μ, AdaptRootMeasure())
@inline pushfwd(f, μ, style::AdaptRootMeasure) = _pushfwd_impl1(f, μ, style)
@inline pushfwd(f, μ, style::PushfwdRootMeasure) = _pushfwd_impl1(f, μ, style)

_pushfwd_impl(f, μ, style) = PushforwardMeasure(f, inverse(f), μ, style)
_pushfwd_impl1(f, μ, style::PushFwdStyle) = _pushfwd_impl2(f, inverse(f), μ, style)
_pushfwd_impl1(::typeof(identity), μ, ::AdaptRootMeasure) = μ
_pushfwd_impl1(::typeof(identity), μ, ::PushfwdRootMeasure) = μ

function _pushfwd_impl(
_pushfwd_impl2(f, finv, μ, style::PushFwdStyle) = PushforwardMeasure(f, finv, μ, style)

function _pushfwd_impl2(
f,
finv,
μ::PushforwardMeasure{F,I,M,S},
style::S,
) where {F,I,M,S<:PushFwdStyle}
orig_μ = μ.origin
new_f = fcomp(f, μ.f)
new_f_inv = fcomp(μ.finv, inverse(f))
new_f_inv = fcomp(μ.finv, finv)
PushforwardMeasure(new_f, new_f_inv, orig_μ, style)
end

_pushfwd_impl(::typeof(identity), μ, ::AdaptRootMeasure) = μ
_pushfwd_impl(::typeof(identity), μ, ::PushfwdRootMeasure) = μ
struct _CurriedPushfwd{F,I,S<:PushFwdStyle} <: Function
f::F
finv::I
style::S

function _CurriedPushfwd{F,I,S}(f::F, finv::I, style::S) where {F,I,S<:PushFwdStyle}
new{F,I,S}(f, finv, style)
end

function _CurriedPushfwd(f, finv, style::S) where {S<:PushFwdStyle}
new{Core.Typeof(f),Core.Typeof(finv),S}(f, finv, style)
end
end

@inline (cf::_CurriedPushfwd{F,FI})(μ) where {F,FI} =
_pushfwd_impl2(cf.f, cf.finv, μ, cf.style)

@inline pushfwd(f) = _curried_pushfwd_impl(f, AdaptRootMeasure())
@inline pushfwd(f, style::AdaptRootMeasure) = _curried_pushfwd_impl(f, style)
@inline pushfwd(f, style::PushfwdRootMeasure) = _curried_pushfwd_impl(f, style)

_curried_pushfwd_impl(f, style::PushFwdStyle) = _CurriedPushfwd(f, inverse(f), style)
@inline _curried_pushfwd_impl(::typeof(identity), ::AdaptRootMeasure) = identity
@inline _curried_pushfwd_impl(::typeof(identity), ::PushfwdRootMeasure) = identity

###############################################################################
# pullback
Expand All @@ -267,8 +294,16 @@ export pullbck
@inline pullbck(f, μ, style::AdaptRootMeasure) = _pullback_impl(f, μ, style)
@inline pullbck(f, μ, style::PushfwdRootMeasure) = _pullback_impl(f, μ, style)

function _pullback_impl(f, μ, style = AdaptRootMeasure())
pushfwd(inverse(f), μ, style)
end
_pullback_impl(f, μ, style::PushFwdStyle) = _pushfwd_impl2(inverse(f), f, μ, style)
_pullback_impl(::typeof(identity), μ, ::AdaptRootMeasure) = μ
_pullback_impl(::typeof(identity), μ, ::PushfwdRootMeasure) = μ

@inline pullbck(f) = _curried_pullbck_impl(f, AdaptRootMeasure())
@inline pullbck(f, style::AdaptRootMeasure) = _curried_pullbck_impl(f, style)
@inline pullbck(f, style::PushfwdRootMeasure) = _curried_pullbck_impl(f, style)

_curried_pullbck_impl(f, style::PushFwdStyle) = _CurriedPushfwd(inverse(f), f, style)
@inline _curried_pullbck_impl(::typeof(identity), ::AdaptRootMeasure) = identity
@inline _curried_pullbck_impl(::typeof(identity), ::PushfwdRootMeasure) = identity

@deprecate pullback(f, μ, style::PushFwdStyle = AdaptRootMeasure()) pullbck(f, μ, style)
53 changes: 44 additions & 9 deletions src/density-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,15 @@ To compute log-density relative to `basemeasure(m)` or *define* a log-density
`logdensity_def`.

To compute a log-density relative to a specific base-measure, see
`logdensity_rel`.
`logdensity_rel`.

# Implementation

Do not specialize `logdensityof` directly for subtypes of `AbstractMeasure`,
specialize `MeasureBase.logdensity_def` and `MeasureBase.strict_logdensityof` instead.
"""
@inline function logdensityof(μ::AbstractMeasure, x)
result = dynamic(unsafe_logdensityof(μ, x))
_checksupport(insupport(μ, x), result)
@inline function logdensityof(μ::AbstractMeasure, x) #!!!!!!!!!!!!!!!!!
strict_logdensityof(μ, x)
end

@inline function logdensityof_rt(::T, ::U) where {T,U}
Expand All @@ -41,6 +45,24 @@ _checksupport(cond, result) = ifelse(cond == true, result, oftype(result, -Inf))

export unsafe_logdensityof

"""
MeasureBase.strict_logdensityof(μ, x)

Compute the log-density of the measure `μ` at `x` relative to `rootmeasure(m)`.
In contrast to [`logdensityof(μ, x)`](@ref), this will not take implicit pushforwards
of `μ` (depending on the type of `x`) into account.
"""
function strict_logdensityof end

@inline function strict_logdensityof(μ, x)
result = dynamic(unsafe_logdensityof(μ, x))
_checksupport(insupport(μ, x), result)
end

@inline function strict_logdensityof_rt(::T, ::U) where {T,U}
Core.Compiler.return_type(strict_logdensityof, Tuple{T,U})
end

# https://discourse.julialang.org/t/counting-iterations-to-a-type-fixpoint/75876/10?u=cscherrer
"""
unsafe_logdensityof(m, x)
Expand Down Expand Up @@ -68,14 +90,27 @@ See also `logdensityof`.
end

"""
logdensity_rel(m1, m2, x)
logdensity_rel(μ, ν, x)

Compute the log-density of `m1` relative to `m2` at `x`. This function checks
whether `x` is in the support of `m1` or `m2` (or both, or neither). If `x` is
Compute the log-density of `μ` relative to `ν` at `x`. This function checks
whether `x` is in the support of `μ` or `ν` (or both, or neither). If `x` is
known to be in the support of both, it can be more efficient to call
`unsafe_logdensity_rel`.
`unsafe_logdensity_rel`.
"""
function logdensity_rel(μ, ν, x)
strict_logdensity_rel(μ, ν, x)
end

"""
@inline function logdensity_rel(μ::M, ν::N, x::X) where {M,N,X}
MeasureBase.strict_logdensity_rel(μ, ν, x)

Compute the log-density of `μ` relative to `ν` at `x`. In contrast to
[`logdensity_rel(μ, ν, x)`](@ref), this will not take implicit pushforwards
of `μ` and `ν` (depending on the type of `x`) into account.
"""
function strict_logdensity_rel end

@inline function strict_logdensity_rel(μ::M, ν::N, x::X) where {M,N,X}
T = unstatic(
promote_type(
return_type(logdensity_def, (μ, x)),
Expand Down
4 changes: 2 additions & 2 deletions src/density.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,10 @@ logdensity_def(μ::DensityMeasure, x) = logdensityof(μ.f, x)

density_def(μ::DensityMeasure, x) = densityof(μ.f, x)

function logdensityof(μ::DensityMeasure, x::Any)
function strict_logdensityof(μ::DensityMeasure, x::Any)
integrand, μ_base = μ.f, μ.base

base_logval = logdensityof(μ_base, x)
base_logval = strict_logdensityof(μ_base, x)

T = typeof(base_logval)
U = logdensityof_rt(integrand, x)
Expand Down
1 change: 1 addition & 0 deletions src/kernel.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
export AbstractTransitionKernel,
GenericTransitionKernel, TypedTransitionKernel, ParameterizedTransitionKernel

# ToDo (breaking): A transition kernel should be a Function, not an AbstractMeasure.
abstract type AbstractTransitionKernel <: AbstractMeasure end

struct GenericTransitionKernel{F} <: AbstractTransitionKernel
Expand Down
4 changes: 2 additions & 2 deletions src/primitive.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ basemeasure(μ::PrimitiveMeasure) = μ

@inline basemeasure_depth(::PrimitiveMeasure) = static(0)

@inline logdensityof(::PrimitiveMeasure, x::Real) = zero(float(typeof(x)))
@inline logdensityof(::PrimitiveMeasure, x) = false
@inline strict_logdensityof(::PrimitiveMeasure, x::Real) = zero(float(typeof(x)))
@inline strict_logdensityof(::PrimitiveMeasure, x) = false

logdensity_def(::PrimitiveMeasure, x) = static(0.0)

Expand Down
6 changes: 3 additions & 3 deletions src/primitives/counting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ struct Counting{T} <: AbstractMeasure
Counting(supp) = new{Core.Typeof(supp)}(supp)
end

@inline function logdensityof(μ::Counting, x::Real)
@inline function strict_logdensityof(μ::Counting, x::Real)
R = float(typeof(x))
insupport(μ, x) ? zero(R) : R(-Inf)
end

@inline logdensityof(μ::Counting, x) = insupport(μ, x) ? 0.0 : -Inf
@inline strict_logdensityof(μ::Counting, x) = insupport(μ, x) ? 0.0 : -Inf

@inline logdensity_def(μ::Counting, x) = logdensityof(μ, x)
@inline logdensity_def(μ::Counting, x) = strict_logdensityof(μ, x)

basemeasure(::Counting) = CountingBase()

Expand Down
4 changes: 2 additions & 2 deletions src/primitives/dirac.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ basemeasure(d::Dirac) = CountingBase()

massof(::Dirac) = static(1.0)

function logdensityof(μ::Dirac, x::Real)
function strict_logdensityof(μ::Dirac, x::Real)
R = float(typeof(x))
insupport(μ, x) ? zero(R) : R(-Inf)
end

logdensityof(μ::Dirac, x) = insupport(μ, x) ? 0.0 : -Inf
strict_logdensityof(μ::Dirac, x) = insupport(μ, x) ? 0.0 : -Inf

logdensity_def(::Dirac, x::Real) = zero(float(typeof(x)))
logdensity_def(::Dirac, x) = 0.0
Expand Down
4 changes: 2 additions & 2 deletions src/primitives/lebesgue.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ insupport(μ::Lebesgue, x) = x ∈ μ.support

insupport(::Lebesgue{RealNumbers}, ::Real) = true

@inline function logdensityof(μ::Lebesgue, x::Real)
@inline function strict_logdensityof(μ::Lebesgue, x::Real)
R = float(typeof(x))
insupport(μ, x) ? zero(R) : R(-Inf)
end

@inline logdensityof(μ::Lebesgue, x) = insupport(μ, x) ? 0.0 : -Inf
@inline strict_logdensityof(μ::Lebesgue, x) = insupport(μ, x) ? 0.0 : -Inf

massof(::Lebesgue{RealNumbers}, s::Interval) = width(s)

Expand Down
Loading
Loading