2 changes: 2 additions & 0 deletions HISTORY.md
@@ -4,6 +4,8 @@ This update adds new variational inference algorithms in light of the flexibilit
Specifically, the following measure-space optimization algorithms have been added:

- `KLMinWassFwdBwd`
- `KLMinNaturalGradDescent`
- `KLMinSqrtNaturalGradDescent`

# Release 0.5

7 changes: 5 additions & 2 deletions src/AdvancedVI.jl
@@ -352,10 +352,13 @@ include("algorithms/common.jl")

export KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent, ADVI, BBVI

# Other Algorithms
# Natural and Wasserstein gradient descent algorithms

include("algorithms/gauss_expected_grad_hess.jl")
include("algorithms/klminwassfwdbwd.jl")
include("algorithms/klminsqrtnaturalgraddescent.jl")
include("algorithms/klminnaturalgraddescent.jl")

export KLMinWassFwdBwd
export KLMinWassFwdBwd, KLMinSqrtNaturalGradDescent, KLMinNaturalGradDescent

end
55 changes: 55 additions & 0 deletions src/algorithms/gauss_expected_grad_hess.jl
@@ -0,0 +1,55 @@

"""
gaussian_expectation_gradient_and_hessian!(rng, q, n_samples, grad_buf, hess_buf, prob)

Estimate the expectations of the gradient and Hessians of the log-density of `prob` taken over the Gaussian `q`. For estimating the expectation of the Hessian, if `prob` has second-order differentiation capability, this function uses the sample average of the Hessian. Otherwise, it uses Stein's identity.
**Member** (on the docstring above): maybe break this into multiple lines?


# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `q::MvLocationScale{<:LowerTriangular,<:Normal,L}`: Gaussian to take expectation over.
- `n_samples::Int`: Number of samples used for estimation.
- `grad_buf::AbstractVector`: Buffer for the gradient estimate.
- `hess_buf::AbstractMatrix`: Buffer for the Hessian estimate.
- `prob`: `LogDensityProblem` whose log-density gradient and Hessian are averaged.
"""
function gaussian_expectation_gradient_and_hessian!(
    rng::Random.AbstractRNG,
    q::MvLocationScale{<:LowerTriangular,<:Normal,L},
    n_samples::Int,
    grad_buf::AbstractVector{T},
    hess_buf::AbstractMatrix{T},
    prob,
) where {T<:Real,L}
    logπ_avg = zero(T)
    fill!(grad_buf, zero(T))
    fill!(hess_buf, zero(T))

    if LogDensityProblems.capabilities(typeof(prob)) ≤
        LogDensityProblems.LogDensityOrder{1}()
        # Use Stein's identity
        d = LogDensityProblems.dimension(prob)
        u = randn(rng, T, d, n_samples)
        z = q.scale*u .+ q.location
        for b in 1:n_samples
            zb, ub = view(z, :, b), view(u, :, b)
            logπ, ∇logπ = LogDensityProblems.logdensity_and_gradient(prob, zb)
            logπ_avg += logπ/n_samples
            grad_buf += ∇logπ/n_samples
            hess_buf += ub*(∇logπ/n_samples)'
        end
        return logπ_avg, grad_buf, hess_buf
**Member** (comment on lines +29 to +40):

Suggested change (replace the Stein's identity branch above with):

        # First-order-only: use Stein/Price identity.
        # Draw u ~ N(0, I), z = m + L u with L = q.scale (lower-triangular).
        # Accumulate A = E[ ∇ log π(z) uᵀ ], then map back: H = A / L.
        d = LogDensityProblems.dimension(prob)
        u = randn(rng, T, d, n_samples)
        z = q.scale * u .+ q.location
        for b in 1:n_samples
            zb, ub = view(z, :, b), view(u, :, b)
            logπ, ∇logπ = LogDensityProblems.logdensity_and_gradient(prob, zb)
            logπ_avg += logπ / n_samples
            grad_buf .+= ∇logπ / n_samples
            @inbounds hess_buf .+= (∇logπ / n_samples) * ub'
        end
        # Right triangular solve by L to obtain the Hessian in z-coordinates; symmetrize.
        A_div = hess_buf / LowerTriangular(q.scale)
        @inbounds hess_buf .= (A_div + A_div') / 2
        return logπ_avg, grad_buf, hess_buf

I am not super fluent in Stein's identity, so here is what an LLM says:

**Hessian Estimation in Whitened Coordinates**

The current fallback accumulates a cross-moment in whitened coordinates and returns it as the "Hessian." With $z = m + L u$ (where $u \sim \mathcal{N}(0, I)$), the code forms something like $\mathbb{E}[u g(z)^\top]$ (or its transpose), but this is not $\mathbb{E}[\nabla^2_z \log \pi(z)]$.

**The Correct Relationship**

By Stein/Price identity, the correct relationship is:

$$ \mathbb{E}[\nabla^2_z \log \pi(z)] = \mathbb{E}[\nabla_z \log \pi(z) u^\top] L^{-1} $$

(or equivalently $L^{-\top} \mathbb{E}[u g(z)^\top]$ if you accumulate the transposed product).

Without the $L^{-1}$ (or $L^{-\top}$) mapping, you stay in $u$-space and feed the optimizer curvature with the wrong scale/rotation.

**The Fix**

Keep the same samples $u$ and $z = m + L u$, compute $g = \nabla_z \log \pi(z)$, average

$$ A = \frac{1}{B} \sum g u^\top $$

over samples, and then map back to $z$-coordinates with a single right triangular solve:

$$ \hat{H} = A L^{-1} $$

Symmetrizing $\hat{H}$ by $\frac{1}{2}(\hat{H} + \hat{H}^\top)$ reduces Monte Carlo asymmetry.

This yields an unbiased estimator for $\mathbb{E}_q[\nabla^2_z \log \pi(z)]$ in the correct coordinates.

    else
        # Use sample average of the Hessian.
        z = rand(rng, q, n_samples)
        for b in 1:n_samples
            zb = view(z, :, b)
            logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian(
                prob, zb
            )
            logπ_avg += logπ/n_samples
            grad_buf += ∇logπ/n_samples
            hess_buf += ∇2logπ/n_samples
        end
        return logπ_avg, grad_buf, hess_buf
    end
end
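As a quick numerical sanity check of the Stein/Price mapping discussed in the review comment above, here is a small standalone sketch (not part of the PR; `Λ`, `∇logπ`, `m`, `L`, and the sample size are made-up illustration values). It accumulates the whitened cross-moment A ≈ E[∇ log π(z) uᵀ] and verifies that the right triangular solve A / L recovers E[∇² log π(z)], which equals -Λ for a quadratic log-density.

```julia
using LinearAlgebra, Random, Statistics

# Hypothetical quadratic target: log π(z) = -zᵀ Λ z / 2, so ∇log π(z) = -Λ z and ∇²log π(z) = -Λ.
Λ = [2.0 0.3; 0.3 1.0]
∇logπ(z) = -Λ * z

rng = Xoshiro(1)
m = [0.5, -1.0]                          # variational location
L = LowerTriangular([1.0 0.0; 0.2 0.8])  # variational scale: z = m + L*u with u ~ N(0, I)
n_samples = 200_000

# A ≈ E[∇log π(z) uᵀ], accumulated in whitened coordinates
A = mean(1:n_samples) do _
    u = randn(rng, 2)
    z = m + L * u
    ∇logπ(z) * u'
end

H = A / L                    # map back to z-coordinates: E[∇²log π] ≈ A L⁻¹
H_sym = (H + H') / 2         # symmetrize to reduce Monte Carlo asymmetry
@show H_sym                  # should be close to -Λ
@show maximum(abs, H_sym + Λ)  # Monte Carlo error
```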
173 changes: 173 additions & 0 deletions src/algorithms/klminnaturalgraddescent.jl
@@ -0,0 +1,173 @@

"""
KLMinNaturalGradDescent(stepsize, ensure_posdef, n_samples, subsampling)
KLMinNaturalGradDescent(; stepsize, ensure_posdef, n_samples, subsampling)

KL divergence minimization by running natural gradient descent[^KL2017][^KR2023], also called variational online Newton.
This algorithm can be viewed as an instantiation of mirror descent, where the Bregman divergence is chosen to be the KL divergence.

If the `ensure_posdef` argument is true, the algorithm applies the technique by Lin *et al.*[^LSK2020], where the precision matrix update includes an additional term that guarantees positive definiteness.
This, however, involves an additional set of matrix-matrix system solves that could be costly.

Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variational approximation as \$\$q\$\$, the original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$. If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. If the target has only first-order capability, we use Stein's identity.

# (Keyword) Arguments
- `stepsize::Float64`: Step size.
- `ensure_posdef::Bool`: Ensure that the updated precision preserves positive definiteness. (default: `true`)
- `n_samples::Int`: Number of samples used to estimate the natural gradient. (default: `1`)
- `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. (default: `nothing`)

!!! note
The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `KLMinNaturalGradDescent` does not support amortization or structured variational families.

# Output
- `q`: The last iterate of the algorithm.

# Callback Signature
The `callback` function supplied to `optimize` needs to have the following signature:

callback(; rng, iteration, q, info)

The keyword arguments are as follows:
- `rng`: Random number generator internally used by the algorithm.
- `iteration`: The index of the current iteration.
- `q`: Current variational approximation.
- `info`: `NamedTuple` containing the information generated during the current iteration.

# Requirements
- The variational family is [`FullRankGaussian`](@ref FullRankGaussian).
- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$).
- The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability.
"""
@kwdef struct KLMinNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <:
              AbstractVariationalAlgorithm
    stepsize::Float64
    ensure_posdef::Bool = true
    n_samples::Int = 1
    subsampling::Sub = nothing
end

struct KLMinNaturalGradDescentState{Q,P,S,Prec,GradBuf,HessBuf}
    q::Q
    prob::P
    prec::Prec
    iteration::Int
    sub_st::S
    grad_buf::GradBuf
    hess_buf::HessBuf
end

function init(
    rng::Random.AbstractRNG,
    alg::KLMinNaturalGradDescent,
    q_init::MvLocationScale{<:LowerTriangular,<:Normal,L},
    prob,
) where {L}
    sub = alg.subsampling
    n_dims = LogDensityProblems.dimension(prob)
    capability = LogDensityProblems.capabilities(typeof(prob))
    if capability < LogDensityProblems.LogDensityOrder{1}()
        throw(
            ArgumentError(
                "`KLMinNaturalGradDescent` requires at least first-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability).",
            ),
        )
    end
    sub_st = isnothing(sub) ? nothing : init(rng, sub)
    grad_buf = Vector{eltype(q_init.location)}(undef, n_dims)
    hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims)
    return KLMinNaturalGradDescentState(
        q_init, prob, cov(q_init), 0, sub_st, grad_buf, hess_buf
    )
end

**Member** (on `cov(q_init)` in the `KLMinNaturalGradDescentState` constructor above): should we invert `cov(q_init)` to get the precision matrix?

output(::KLMinNaturalGradDescent, state) = state.q

function step(
    rng::Random.AbstractRNG,
    alg::KLMinNaturalGradDescent,
    state,
    callback,
    objargs...;
    kwargs...,
)
    (; ensure_posdef, n_samples, stepsize, subsampling) = alg
    (; q, prob, prec, iteration, sub_st, grad_buf, hess_buf) = state

    m = mean(q)
    S = prec
    η = convert(eltype(m), stepsize)
    iteration += 1

    # Maybe apply subsampling
    prob_sub, sub_st′, sub_inf = if isnothing(subsampling)
        prob, sub_st, NamedTuple()
    else
        batch, sub_st′, sub_inf = step(rng, subsampling, sub_st)
        prob_sub = subsample(prob, batch)
        prob_sub, sub_st′, sub_inf
    end

    logπ_avg, grad_buf, hess_buf = gaussian_expectation_gradient_and_hessian!(
        rng, q, n_samples, grad_buf, hess_buf, prob_sub
    )

    S′ = Hermitian(((1 - η) * S + η * (-hess_buf)))
    if ensure_posdef
        G_hat = S - (-hess_buf)
        S′ += η^2 / 2 * Hermitian(G_hat * (S′ \ G_hat))
    end
    m′ = m - η * (S′ \ (-grad_buf))

    q′ = MvLocationScale(m′, inv(cholesky(S′).L), q.dist)

    state = KLMinNaturalGradDescentState(
        q′, prob, S′, iteration, sub_st′, grad_buf, hess_buf
    )
    elbo = logπ_avg + entropy(q′)
    info = merge((elbo=elbo,), sub_inf)

    if !isnothing(callback)
        info′ = callback(; rng, iteration, q, info)
        info = !isnothing(info′) ? merge(info′, info) : info
    end
    state, false, info
end
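For reference, written out from the `step` implementation above (m is the mean, S the precision carried in the state, η the stepsize, and expectations are taken over the current iterate q), the update is

$$ S_{t+1} = (1 - \eta)\, S_t + \eta\, \mathbb{E}_q[-\nabla^2 \log \pi], \qquad m_{t+1} = m_t + \eta\, S_{t+1}^{-1}\, \mathbb{E}_q[\nabla \log \pi], $$

and, when `ensure_posdef` is set, the precision additionally receives the correction term

$$ S_{t+1} \leftarrow S_{t+1} + \frac{\eta^2}{2}\, \hat{G}\, S_{t+1}^{-1}\, \hat{G}, \qquad \hat{G} = S_t - \mathbb{E}_q[-\nabla^2 \log \pi], $$

before the mean update. The new scale is then recovered from the updated precision as `inv(cholesky(S′).L)`.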

"""
estimate_objective([rng,] alg, q, prob; n_samples)

Estimate the ELBO of the variational approximation `q` against the target log-density `prob`.

# Arguments
- `rng::Random.AbstractRNG`: Random number generator.
- `alg::KLMinNaturalGradDescent`: Variational inference algorithm.
- `q::MvLocationScale{<:Any,<:Normal,<:Any}`: Gaussian variational approximation.
- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.

# Keyword Arguments
- `n_samples::Int`: Number of Monte Carlo samples for estimating the objective. (default: Same as the number of samples used for estimating the gradient during optimization.)

# Returns
- `obj_est`: Estimate of the objective value.
"""
function estimate_objective(
    rng::Random.AbstractRNG,
    alg::KLMinNaturalGradDescent,
    q::MvLocationScale{S,<:Normal,L},
    prob;
    n_samples::Int=alg.n_samples,
) where {S,L}
    obj = RepGradELBO(n_samples; entropy=MonteCarloEntropy())
    if isnothing(alg.subsampling)
        return estimate_objective(rng, obj, q, prob)
    else
        sub = alg.subsampling
        sub_st = init(rng, sub)
        return mapreduce(+, 1:length(sub)) do _
            batch, sub_st, _ = step(rng, sub, sub_st)
            prob_sub = subsample(prob, batch)
            estimate_objective(rng, obj, q, prob_sub) / length(sub)
        end
    end
end
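For orientation, here is a minimal end-to-end usage sketch of the new algorithm. It is a sketch under assumptions, not part of the PR: the toy `StdNormalTarget` type is hypothetical, and the argument order and return values of `optimize` are assumed for illustration (consult the AdvancedVI documentation for the exact interface); `FullRankGaussian`, `estimate_objective`, and the `LogDensityProblems` interface methods are existing APIs.

```julia
using AdvancedVI, LinearAlgebra, LogDensityProblems, Random

# Hypothetical standard-normal target with first-order capability,
# as required by `KLMinNaturalGradDescent`.
struct StdNormalTarget
    dim::Int
end
LogDensityProblems.logdensity(p::StdNormalTarget, z) = -sum(abs2, z) / 2
LogDensityProblems.dimension(p::StdNormalTarget) = p.dim
function LogDensityProblems.capabilities(::Type{StdNormalTarget})
    return LogDensityProblems.LogDensityOrder{1}()
end
function LogDensityProblems.logdensity_and_gradient(p::StdNormalTarget, z)
    return -sum(abs2, z) / 2, -z
end

d = 5
prob = StdNormalTarget(d)
q_init = FullRankGaussian(zeros(d), LowerTriangular(Matrix(1.0I, d, d)))
alg = KLMinNaturalGradDescent(; stepsize=0.1, n_samples=4)

rng = Xoshiro(1)
# NOTE: the `optimize` call below (argument order and return values) is an assumption.
q_opt, info, _ = optimize(rng, alg, 1_000, prob, q_init)

# ELBO estimate for the fitted approximation, using the method defined in this file.
elbo = estimate_objective(rng, alg, q_opt, prob; n_samples=256)
```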