-
Notifications
You must be signed in to change notification settings - Fork 19
Add natural gradient variational inference algorithms #211
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1980f93
c84b453
48daaa0
3483e8d
8267a98
f3790c3
3ba8401
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,55 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gaussian_expectation_gradient_and_hessian!(rng, q, n_samples, grad_buf, hess_buf, prob) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Estimate the expectations of the gradient and Hessians of the log-density of `prob` taken over the Gaussian `q`. For estimating the expectation of the Hessian, if `prob` has second-order differentiation capability, this function uses the sample average of the Hessian. Otherwise, it uses Stein's identity. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Arguments | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - `rng::Random.AbstractRNG`: Random number generator. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - `q::MvLocationScale{<:LowerTriangular,<:Normal,L}`: Gaussian to take expectation over. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - `n_samples::Int`: Number of samples used for estimation. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - `grad_buf::AbstractVector`: Buffer for the gradient estimate. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - `hess_buf::AbstractMatrix`: Buffer for the Hessian estimate. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| - `prob`: `LogDensityProblem` associated with the log-density gradient and Hessian subject to expectation. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| function gaussian_expectation_gradient_and_hessian!( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| rng::Random.AbstractRNG, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| q::MvLocationScale{<:LowerTriangular,<:Normal,L}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| n_samples::Int, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| grad_buf::AbstractVector{T}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hess_buf::AbstractMatrix{T}, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prob, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) where {T<:Real,L} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logπ_avg = zero(T) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fill!(grad_buf, zero(T)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| fill!(hess_buf, zero(T)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if LogDensityProblems.capabilities(typeof(prob)) ≤ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| LogDensityProblems.LogDensityOrder{1}() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Use Stein's identity | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| d = LogDensityProblems.dimension(prob) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| u = randn(rng, T, d, n_samples) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| z = q.scale*u .+ q.location | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for b in 1:n_samples | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| zb, ub = view(z, :, b), view(u, :, b) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logπ, ∇logπ = LogDensityProblems.logdensity_and_gradient(prob, zb) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logπ_avg += logπ/n_samples | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| grad_buf += ∇logπ/n_samples | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hess_buf += ub*(∇logπ/n_samples)' | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return logπ_avg, grad_buf, hess_buf | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+29
to
+40
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I am super fluent in stein's identity. LLM says
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Use sample average of the Hessian. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| z = rand(rng, q, n_samples) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for b in 1:n_samples | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| zb = view(z, :, b) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prob, zb | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logπ_avg += logπ/n_samples | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| grad_buf += ∇logπ/n_samples | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| hess_buf += ∇2logπ/n_samples | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return logπ_avg, grad_buf, hess_buf | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| end | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,173 @@ | ||
|
|
||
| """ | ||
| KLMinNaturalGradDescent(stepsize, ensure_posdef, n_samples, subsampling) | ||
| KLMinNaturalGradDescent(; stepsize, ensure_posdef, n_samples, subsampling) | ||
|
|
||
| KL divergence minimization by running natural gradient descent[^KL2017][^KR2023], also called variational online Newton. | ||
| This algorithm can be viewed as an instantiation of mirror descent, where the Bregman divergence is chosen to be the KL divergence. | ||
|
|
||
| If the `ensure_posdef` argument is true, the algorithm applies the technique by Lin *et al.*[^LSK2020], where the precision matrix update includes an additional term that guarantees positive definiteness. | ||
| This, however, involves an additional set of matrix-matrix system solves that could be costly. | ||
|
|
||
| Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variational approximation as \$\$q\$\$, the original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$. If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. If the target has only first-order capability, we use Stein's identity. | ||
|
|
||
| # (Keyword) Arguments | ||
| - `stepsize::Float64`: Step size. | ||
| - `ensure_posdef::Bool`: Ensure that the updated precision preserves positive definiteness. (default: `true`) | ||
| - `n_samples::Int`: Number of samples used to estimate the natural gradient. (default: `1`) | ||
| - `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. | ||
|
|
||
| !!! note | ||
| The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `KLMinNaturalGradDescent` does not support amortization or structured variational families. | ||
|
|
||
| # Output | ||
| - `q`: The last iterate of the algorithm. | ||
|
|
||
| # Callback Signature | ||
| The `callback` function supplied to `optimize` needs to have the following signature: | ||
|
|
||
| callback(; rng, iteration, q, info) | ||
|
|
||
| The keyword arguments are as follows: | ||
| - `rng`: Random number generator internally used by the algorithm. | ||
| - `iteration`: The index of the current iteration. | ||
| - `q`: Current variational approximation. | ||
| - `info`: `NamedTuple` containing the information generated during the current iteration. | ||
|
|
||
| # Requirements | ||
| - The variational family is [`FullRankGaussian`](@ref FullRankGaussian). | ||
| - The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$). | ||
| - The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability. | ||
| """ | ||
| @kwdef struct KLMinNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <: | ||
| AbstractVariationalAlgorithm | ||
| stepsize::Float64 | ||
| ensure_posdef::Bool = true | ||
| n_samples::Int = 1 | ||
| subsampling::Sub = nothing | ||
| end | ||
|
|
||
| struct KLMinNaturalGradDescentState{Q,P,S,Prec,GradBuf,HessBuf} | ||
| q::Q | ||
| prob::P | ||
| prec::Prec | ||
| iteration::Int | ||
| sub_st::S | ||
| grad_buf::GradBuf | ||
| hess_buf::HessBuf | ||
| end | ||
|
|
||
| function init( | ||
| rng::Random.AbstractRNG, | ||
| alg::KLMinNaturalGradDescent, | ||
| q_init::MvLocationScale{<:LowerTriangular,<:Normal,L}, | ||
| prob, | ||
| ) where {L} | ||
| sub = alg.subsampling | ||
| n_dims = LogDensityProblems.dimension(prob) | ||
| capability = LogDensityProblems.capabilities(typeof(prob)) | ||
| if capability < LogDensityProblems.LogDensityOrder{1}() | ||
| throw( | ||
| ArgumentError( | ||
| "`KLMinNaturalGradDescent` requires at least first-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability).", | ||
| ), | ||
| ) | ||
| end | ||
| sub_st = isnothing(sub) ? nothing : init(rng, sub) | ||
| grad_buf = Vector{eltype(q_init.location)}(undef, n_dims) | ||
| hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims) | ||
| return KLMinNaturalGradDescentState( | ||
| q_init, prob, cov(q_init), 0, sub_st, grad_buf, hess_buf | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should we inverse |
||
| ) | ||
| end | ||
|
|
||
| output(::KLMinNaturalGradDescent, state) = state.q | ||
|
|
||
| function step( | ||
| rng::Random.AbstractRNG, | ||
| alg::KLMinNaturalGradDescent, | ||
| state, | ||
| callback, | ||
| objargs...; | ||
| kwargs..., | ||
| ) | ||
| (; ensure_posdef, n_samples, stepsize, subsampling) = alg | ||
| (; q, prob, prec, iteration, sub_st, grad_buf, hess_buf) = state | ||
|
|
||
| m = mean(q) | ||
| S = prec | ||
| η = convert(eltype(m), stepsize) | ||
| iteration += 1 | ||
|
|
||
| # Maybe apply subsampling | ||
| prob_sub, sub_st′, sub_inf = if isnothing(subsampling) | ||
| prob, sub_st, NamedTuple() | ||
| else | ||
| batch, sub_st′, sub_inf = step(rng, subsampling, sub_st) | ||
| prob_sub = subsample(prob, batch) | ||
| prob_sub, sub_st′, sub_inf | ||
| end | ||
|
|
||
| logπ_avg, grad_buf, hess_buf = gaussian_expectation_gradient_and_hessian!( | ||
| rng, q, n_samples, grad_buf, hess_buf, prob_sub | ||
| ) | ||
|
|
||
| S′ = Hermitian(((1 - η) * S + η * (-hess_buf))) | ||
| if ensure_posdef | ||
| G_hat = S - (-hess_buf) | ||
| S′ += η^2 / 2 * Hermitian(G_hat * (S′ \ G_hat)) | ||
| end | ||
| m′ = m - η * (S′ \ (-grad_buf)) | ||
|
|
||
| q′ = MvLocationScale(m′, inv(cholesky(S′).L), q.dist) | ||
|
|
||
| state = KLMinNaturalGradDescentState( | ||
| q′, prob, S′, iteration, sub_st′, grad_buf, hess_buf | ||
| ) | ||
| elbo = logπ_avg + entropy(q′) | ||
| info = merge((elbo=elbo,), sub_inf) | ||
|
|
||
| if !isnothing(callback) | ||
| info′ = callback(; rng, iteration, q, info) | ||
| info = !isnothing(info′) ? merge(info′, info) : info | ||
| end | ||
| state, false, info | ||
| end | ||
|
|
||
| """ | ||
| estimate_objective([rng,] alg, q, prob; n_samples) | ||
|
|
||
| Estimate the ELBO of the variational approximation `q` against the target log-density `prob`. | ||
|
|
||
| # Arguments | ||
| - `rng::Random.AbstractRNG`: Random number generator. | ||
| - `alg::KLMinNaturalGradDescent`: Variational inference algorithm. | ||
| - `q::MvLocationScale{<:Any,<:Normal,<:Any}`: Gaussian variational approximation. | ||
| - `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. | ||
|
|
||
| # Keyword Arguments | ||
| - `n_samples::Int`: Number of Monte Carlo samples for estimating the objective. (default: Same as the the number of samples used for estimating the gradient during optimization.) | ||
|
|
||
| # Returns | ||
| - `obj_est`: Estimate of the objective value. | ||
| """ | ||
| function estimate_objective( | ||
| rng::Random.AbstractRNG, | ||
| alg::KLMinNaturalGradDescent, | ||
| q::MvLocationScale{S,<:Normal,L}, | ||
| prob; | ||
| n_samples::Int=alg.n_samples, | ||
| ) where {S,L} | ||
| obj = RepGradELBO(n_samples; entropy=MonteCarloEntropy()) | ||
| if isnothing(alg.subsampling) | ||
| return estimate_objective(rng, obj, q, prob) | ||
| else | ||
| sub = alg.subsampling | ||
| sub_st = init(rng, sub) | ||
| return mapreduce(+, 1:length(sub)) do _ | ||
| batch, sub_st, _ = step(rng, sub, sub_st) | ||
| prob_sub = subsample(prob, batch) | ||
| estimate_objective(rng, obj, q, prob_sub) / length(sub) | ||
| end | ||
| end | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe break this into multiple lines?