diff --git a/HISTORY.md b/HISTORY.md
index 131ae0e1b..13f2995e0 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -4,6 +4,8 @@ This update adds new variational inference algorithms in light of the flexibilit
Specifically, the following measure-space optimization algorithms have been added:
- `KLMinWassFwdBwd`
+ - `KLMinNaturalGradDescent`
+ - `KLMinSqrtNaturalGradDescent`
In addition, `KLMinRepGradDescent`, `KLMinRepGradProxDescent`, and `KLMinScoreGradDescent` will now throw a `RuntimeException` if the objective value estimated at any step turns out to be degenerate (`Inf` or `NaN`). Previously, the algorithms ran until `max_iter` even if the optimization run had failed.
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 2b82b7b79..7d57e32a0 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -352,10 +352,13 @@ include("algorithms/common.jl")
export KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent, ADVI, BBVI
-# Other Algorithms
+# Natural and Wasserstein gradient descent algorithms
+include("algorithms/gauss_expected_grad_hess.jl")
include("algorithms/klminwassfwdbwd.jl")
+include("algorithms/klminsqrtnaturalgraddescent.jl")
+include("algorithms/klminnaturalgraddescent.jl")
-export KLMinWassFwdBwd
+export KLMinWassFwdBwd, KLMinSqrtNaturalGradDescent, KLMinNaturalGradDescent
end
diff --git a/src/algorithms/gauss_expected_grad_hess.jl b/src/algorithms/gauss_expected_grad_hess.jl
new file mode 100644
index 000000000..457439d14
--- /dev/null
+++ b/src/algorithms/gauss_expected_grad_hess.jl
@@ -0,0 +1,80 @@
+
+"""
+    gaussian_expectation_gradient_and_hessian!(rng, q, n_samples, grad_buf, hess_buf, prob)
+
+Estimate the expectations of the gradient and Hessian of the log-density of `prob` taken over the Gaussian `q`.
+For estimating the expectation of the Hessian, if `prob` has second-order differentiation capability, this function uses the sample average of the Hessian.
+Otherwise, it uses Stein's identity.
+
+!!! warning
+    The resulting `hess_buf` may not be perfectly symmetric due to numerical issues. It is therefore useful to wrap it in a `Symmetric` before use.
+
+# Arguments
+- `rng::Random.AbstractRNG`: Random number generator.
+- `q::MvLocationScale{<:LowerTriangular,<:Normal,L}`: Gaussian to take the expectation over.
+- `n_samples::Int`: Number of samples used for estimation.
+- `grad_buf::AbstractVector`: Buffer for the gradient estimate.
+- `hess_buf::AbstractMatrix`: Buffer for the Hessian estimate.
+- `prob`: `LogDensityProblem` associated with the log-density gradient and Hessian subject to expectation.
+"""
+function gaussian_expectation_gradient_and_hessian!(
+    rng::Random.AbstractRNG,
+    q::MvLocationScale{<:LowerTriangular,<:Normal,L},
+    n_samples::Int,
+    grad_buf::AbstractVector{T},
+    hess_buf::AbstractMatrix{T},
+    prob,
+) where {T<:Real,L}
+    logπ_avg = zero(T)
+    fill!(grad_buf, zero(T))
+    fill!(hess_buf, zero(T))
+
+    if LogDensityProblems.capabilities(typeof(prob)) ≤
+        LogDensityProblems.LogDensityOrder{1}()
+        # First-order-only: use the Stein/Price identity for the Hessian
+        #
+        #   E_{z ~ N(m, CC')} ∇² log π(z)
+        #     = E_{z ~ N(m, CC')} (CC')^{-1} (z - m) ∇ log π(z)'
+        #     = E_{u ~ N(0, I)} C' \ (u ∇ log π(z)') .
+        #
+        # Algorithmically, draw u ~ N(0, I), z = C u + m, where C = q.scale.
+        # Accumulate A = E[ u ∇ log π(z)' ], then map back: H = C' \ A.
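+        #
+        # Illustrative check of the identity (a quadratic target is assumed here purely
+        # for verification, it is not required by the algorithm): if log π(z) = -z' Λ z / 2
+        # with Λ symmetric, then ∇ log π(z) = -Λ z, so A = E[ u (-Λ (C u + m))' ] = -C' Λ,
+        # and mapping back gives C' \ A = -Λ = E[ ∇² log π(z) ], as expected.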
+        d = LogDensityProblems.dimension(prob)
+        u = randn(rng, T, d, n_samples)
+        m, C = q.location, q.scale
+        z = C*u .+ m
+        for b in 1:n_samples
+            zb, ub = view(z, :, b), view(u, :, b)
+            logπ, ∇logπ = LogDensityProblems.logdensity_and_gradient(prob, zb)
+            logπ_avg += logπ/n_samples
+
+            rdiv!(∇logπ, n_samples)
+            ∇logπ_div_nsamples = ∇logπ
+
+            grad_buf[:] .+= ∇logπ_div_nsamples
+            hess_buf[:, :] .+= ub*∇logπ_div_nsamples'
+        end
+        hess_buf[:, :] .= C' \ hess_buf
+        return logπ_avg, grad_buf, hess_buf
+    else
+        # Second-order: use naive sample average
+        z = rand(rng, q, n_samples)
+        for b in 1:n_samples
+            zb = view(z, :, b)
+            logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian(
+                prob, zb
+            )
+
+            rdiv!(∇logπ, n_samples)
+            ∇logπ_div_nsamples = ∇logπ
+
+            rdiv!(∇2logπ, n_samples)
+            ∇2logπ_div_nsamples = ∇2logπ
+
+            logπ_avg += logπ/n_samples
+            grad_buf[:] .+= ∇logπ_div_nsamples
+            hess_buf[:, :] .+= ∇2logπ_div_nsamples
+        end
+        return logπ_avg, grad_buf, hess_buf
+    end
+end
diff --git a/src/algorithms/klminnaturalgraddescent.jl b/src/algorithms/klminnaturalgraddescent.jl
new file mode 100644
index 000000000..88dafc5fc
--- /dev/null
+++ b/src/algorithms/klminnaturalgraddescent.jl
@@ -0,0 +1,175 @@
+
+"""
+    KLMinNaturalGradDescent(stepsize, n_samples, ensure_posdef, subsampling)
+    KLMinNaturalGradDescent(; stepsize, n_samples, ensure_posdef, subsampling)
+
+KL divergence minimization by running natural gradient descent[^KL2017][^KR2023], also called variational online Newton.
+This algorithm can be viewed as an instantiation of mirror descent, where the Bregman divergence is chosen to be the KL divergence.
+
+If the `ensure_posdef` argument is `true`, the algorithm applies the technique by Lin *et al.*[^LSK2020], where the precision matrix update includes an additional term that guarantees positive definiteness.
+This, however, involves an additional matrix-matrix system solve, which can be costly.
+
+The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation.
+If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian.
+If the target has only first-order capability, we use Stein's identity.
+
+# (Keyword) Arguments
+- `stepsize::Float64`: Step size.
+- `n_samples::Int`: Number of samples used to estimate the natural gradient. (default: `1`)
+- `ensure_posdef::Bool`: Ensure that the updated precision preserves positive definiteness. (default: `true`)
+- `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy.
+
+!!! note
+    The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `KLMinNaturalGradDescent` does not support amortization or structured variational families.
+
+# Output
+- `q`: The last iterate of the algorithm.
+
+# Callback Signature
+The `callback` function supplied to `optimize` needs to have the following signature:
+
+    callback(; rng, iteration, q, info)
+
+The keyword arguments are as follows:
+- `rng`: Random number generator internally used by the algorithm.
+- `iteration`: The index of the current iteration.
+- `q`: Current variational approximation.
+- `info`: `NamedTuple` containing the information generated during the current iteration. + +# Requirements +- The variational family is [`FullRankGaussian`](@ref FullRankGaussian). +- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$). +- The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability. +""" +@kwdef struct KLMinNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <: + AbstractVariationalAlgorithm + stepsize::Float64 + n_samples::Int = 1 + ensure_posdef::Bool = true + subsampling::Sub = nothing +end + +struct KLMinNaturalGradDescentState{Q,P,S,Prec,GradBuf,HessBuf} + q::Q + prob::P + prec::Prec + iteration::Int + sub_st::S + grad_buf::GradBuf + hess_buf::HessBuf +end + +function init( + rng::Random.AbstractRNG, + alg::KLMinNaturalGradDescent, + q_init::MvLocationScale{<:LowerTriangular,<:Normal,L}, + prob, +) where {L} + sub = alg.subsampling + n_dims = LogDensityProblems.dimension(prob) + capability = LogDensityProblems.capabilities(typeof(prob)) + if capability < LogDensityProblems.LogDensityOrder{1}() + throw( + ArgumentError( + "`KLMinNaturalGradDescent` requires at least first-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability).", + ), + ) + end + sub_st = isnothing(sub) ? nothing : init(rng, sub) + grad_buf = Vector{eltype(q_init.location)}(undef, n_dims) + hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims) + return KLMinNaturalGradDescentState( + q_init, prob, inv(cov(q_init)), 0, sub_st, grad_buf, hess_buf + ) +end + +output(::KLMinNaturalGradDescent, state) = state.q + +function step( + rng::Random.AbstractRNG, + alg::KLMinNaturalGradDescent, + state, + callback, + objargs...; + kwargs..., +) + (; ensure_posdef, n_samples, stepsize, subsampling) = alg + (; q, prob, prec, iteration, sub_st, grad_buf, hess_buf) = state + + m = mean(q) + S = prec + η = convert(eltype(m), stepsize) + iteration += 1 + + # Maybe apply subsampling + prob_sub, sub_st′, sub_inf = if isnothing(subsampling) + prob, sub_st, NamedTuple() + else + batch, sub_st′, sub_inf = step(rng, subsampling, sub_st) + prob_sub = subsample(prob, batch) + prob_sub, sub_st′, sub_inf + end + + logπ_avg, grad_buf, hess_buf = gaussian_expectation_gradient_and_hessian!( + rng, q, n_samples, grad_buf, hess_buf, prob_sub + ) + + S′ = Hermitian(((1 - η) * S + η * Symmetric(-hess_buf))) + if ensure_posdef + G_hat = S - Symmetric(-hess_buf) + S′ += η^2 / 2 * Symmetric(G_hat * (S′ \ G_hat)) + end + m′ = m - η * (S′ \ (-grad_buf)) + + q′ = MvLocationScale(m′, inv(cholesky(S′).L), q.dist) + + state = KLMinNaturalGradDescentState( + q′, prob, S′, iteration, sub_st′, grad_buf, hess_buf + ) + elbo = logπ_avg + entropy(q′) + info = merge((elbo=elbo,), sub_inf) + + if !isnothing(callback) + info′ = callback(; rng, iteration, q, info) + info = !isnothing(info′) ? merge(info′, info) : info + end + state, false, info +end + +""" + estimate_objective([rng,] alg, q, prob; n_samples) + +Estimate the ELBO of the variational approximation `q` against the target log-density `prob`. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `alg::KLMinNaturalGradDescent`: Variational inference algorithm. +- `q::MvLocationScale{<:Any,<:Normal,<:Any}`: Gaussian variational approximation. +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. + +# Keyword Arguments +- `n_samples::Int`: Number of Monte Carlo samples for estimating the objective. 
(default: Same as the the number of samples used for estimating the gradient during optimization.) + +# Returns +- `obj_est`: Estimate of the objective value. +""" +function estimate_objective( + rng::Random.AbstractRNG, + alg::KLMinNaturalGradDescent, + q::MvLocationScale{S,<:Normal,L}, + prob; + n_samples::Int=alg.n_samples, +) where {S,L} + obj = RepGradELBO(n_samples; entropy=MonteCarloEntropy()) + if isnothing(alg.subsampling) + return estimate_objective(rng, obj, q, prob) + else + sub = alg.subsampling + sub_st = init(rng, sub) + return mapreduce(+, 1:length(sub)) do _ + batch, sub_st, _ = step(rng, sub, sub_st) + prob_sub = subsample(prob, batch) + estimate_objective(rng, obj, q, prob_sub) / length(sub) + end + end +end diff --git a/src/algorithms/klminsqrtnaturalgraddescent.jl b/src/algorithms/klminsqrtnaturalgraddescent.jl new file mode 100644 index 000000000..052e4124c --- /dev/null +++ b/src/algorithms/klminsqrtnaturalgraddescent.jl @@ -0,0 +1,165 @@ + +""" + KLMinSqrtNaturalGradDescent(stepsize, n_samples, subsampling) + KLMinSqrtNaturalGradDescent(; stepsize, n_samples, subsampling) + +KL divergence minimization algorithm obtained by discretizing the natural gradient flow (the Riemannian gradient flow with the Fisher information matrix as the metric tensor) under the square-root parameterization[^KMKL2025][^LDENKTM2024][^LDLNKS2023][^T2025]. + +The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation. +If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. +If the target has only first-order capability, we use Stein's identity. + +# (Keyword) Arguments +- `stepsize::Float64`: Step size. +- `n_samples::Int`: Number of samples used to estimate the natural gradient. (default: `1`) +- `subsampling::Union{Nothing,<:AbstractSubsampling}`: Optional subsampling strategy. + +!!! note + The `subsampling` strategy is only applied to the target `LogDensityProblem` but not to the variational approximation `q`. That is, `KLMinSqrtNaturalGradDescent` does not support amortization or structured variational families. + +# Output +- `q`: The last iterate of the algorithm. + +# Callback Signature +The `callback` function supplied to `optimize` needs to have the following signature: + + callback(; rng, iteration, q, info) + +The keyword arguments are as follows: +- `rng`: Random number generator internally used by the algorithm. +- `iteration`: The index of the current iteration. +- `q`: Current variational approximation. +- `info`: `NamedTuple` containing the information generated during the current iteration. + +# Requirements +- The variational family is [`FullRankGaussian`](@ref FullRankGaussian). +- The target distribution has unconstrained support (\$\$\\mathbb{R}^d\$\$). +- The target `LogDensityProblems.logdensity(prob, x)` has at least first-order differentiation capability. 
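+
+# Example
+A minimal sketch of a typical call (illustrative only; `prob` stands for a user-provided `LogDensityProblem` with at least first-order differentiation capability and unconstrained support):
+
+    using LinearAlgebra
+
+    d = 5   # dimension of the target `prob` (assumed)
+    q0 = FullRankGaussian(zeros(d), LowerTriangular(Matrix{Float64}(I, d, d)))
+    alg = KLMinSqrtNaturalGradDescent(; stepsize=1e-3, n_samples=10)
+    q, info, _ = optimize(alg, 1_000, prob, q0)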
+""" +@kwdef struct KLMinSqrtNaturalGradDescent{Sub<:Union{Nothing,<:AbstractSubsampling}} <: + AbstractVariationalAlgorithm + stepsize::Float64 + n_samples::Int = 1 + subsampling::Sub = nothing +end + +struct KLMinSqrtNaturalGradDescentState{Q,P,S,GradBuf,HessBuf} + q::Q + prob::P + iteration::Int + sub_st::S + grad_buf::GradBuf + hess_buf::HessBuf +end + +function init( + rng::Random.AbstractRNG, + alg::KLMinSqrtNaturalGradDescent, + q_init::MvLocationScale{<:LowerTriangular,<:Normal,L}, + prob, +) where {L} + sub = alg.subsampling + n_dims = LogDensityProblems.dimension(prob) + capability = LogDensityProblems.capabilities(typeof(prob)) + if capability < LogDensityProblems.LogDensityOrder{1}() + throw( + ArgumentError( + "`KLMinSqrtNaturalGradDescent` requires at least first-order differentiation capability. The capability of the supplied `LogDensityProblem` is $(capability).", + ), + ) + end + sub_st = isnothing(sub) ? nothing : init(rng, sub) + grad_buf = Vector{eltype(q_init.location)}(undef, n_dims) + hess_buf = Matrix{eltype(q_init.location)}(undef, n_dims, n_dims) + return KLMinSqrtNaturalGradDescentState(q_init, prob, 0, sub_st, grad_buf, hess_buf) +end + +output(::KLMinSqrtNaturalGradDescent, state) = state.q + +function step( + rng::Random.AbstractRNG, + alg::KLMinSqrtNaturalGradDescent, + state, + callback, + objargs...; + kwargs..., +) + (; n_samples, stepsize, subsampling) = alg + (; q, prob, iteration, sub_st, grad_buf, hess_buf) = state + + m = q.location + C = q.scale + η = convert(eltype(m), stepsize) + iteration += 1 + + # Maybe apply subsampling + prob_sub, sub_st′, sub_inf = if isnothing(subsampling) + prob, sub_st, NamedTuple() + else + batch, sub_st′, sub_inf = step(rng, subsampling, sub_st) + prob_sub = subsample(prob, batch) + prob_sub, sub_st′, sub_inf + end + + logπ_avg, grad_buf, hess_buf = gaussian_expectation_gradient_and_hessian!( + rng, q, n_samples, grad_buf, hess_buf, prob_sub + ) + + CtHCmI = C'*Symmetric(-hess_buf)*C - I + CtHCmI_tril = LowerTriangular(tril(CtHCmI) - Diagonal(diag(CtHCmI))/2) + + m′ = m - η * C * (C' * -grad_buf) + C′ = C - η * C * CtHCmI_tril + + q′ = MvLocationScale(m′, C′, q.dist) + + state = KLMinSqrtNaturalGradDescentState( + q′, prob, iteration, sub_st′, grad_buf, hess_buf + ) + elbo = logπ_avg + entropy(q′) + info = merge((elbo=elbo,), sub_inf) + + if !isnothing(callback) + info′ = callback(; rng, iteration, q, info) + info = !isnothing(info′) ? merge(info′, info) : info + end + state, false, info +end + +""" + estimate_objective([rng,] alg, q, prob; n_samples) + +Estimate the ELBO of the variational approximation `q` against the target log-density `prob`. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `alg::KLMinSqrtNaturalGradDescent`: Variational inference algorithm. +- `q::MvLocationScale{<:Any,<:Normal,<:Any}`: Gaussian variational approximation. +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. + +# Keyword Arguments +- `n_samples::Int`: Number of Monte Carlo samples for estimating the objective. (default: Same as the the number of samples used for estimating the gradient during optimization.) + +# Returns +- `obj_est`: Estimate of the objective value. 
+""" +function estimate_objective( + rng::Random.AbstractRNG, + alg::KLMinSqrtNaturalGradDescent, + q::MvLocationScale{S,<:Normal,L}, + prob; + n_samples::Int=alg.n_samples, +) where {S,L} + obj = RepGradELBO(n_samples; entropy=MonteCarloEntropy()) + if isnothing(alg.subsampling) + return estimate_objective(rng, obj, q, prob) + else + sub = alg.subsampling + sub_st = init(rng, sub) + return mapreduce(+, 1:length(sub)) do _ + batch, sub_st, _ = step(rng, sub, sub_st) + prob_sub = subsample(prob, batch) + estimate_objective(rng, obj, q, prob_sub) / length(sub) + end + end +end diff --git a/src/algorithms/klminwassfwdbwd.jl b/src/algorithms/klminwassfwdbwd.jl index f834b539a..14bd52dbe 100644 --- a/src/algorithms/klminwassfwdbwd.jl +++ b/src/algorithms/klminwassfwdbwd.jl @@ -5,7 +5,9 @@ KL divergence minimization by running stochastic proximal gradient descent (forward-backward splitting) in Wasserstein space[^DBCS2023]. -Denoting the target log-density as \$\$ \\log \\pi \$\$ and the current variational approximation as \$\$q\$\$, the original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$. If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. If the target has only first-order capability, we use Stein's identity. +The original algorithm requires estimating the quantity \$\$ \\mathbb{E}_q \\nabla^2 \\log \\pi \$\$, where \$\$ \\log \\pi \$\$ is the target log-density and \$\$q\$\$ is the current variational approximation. +If the target `LogDensityProblem` associated with \$\$ \\log \\pi \$\$ has second-order differentiation [capability](https://www.tamaspapp.eu/LogDensityProblems.jl/dev/#LogDensityProblems.capabilities), we use the sample average of the Hessian. +If the target has only first-order capability, we use Stein's identity. # (Keyword) Arguments - `n_samples::Int`: Number of samples used to estimate the Wasserstein gradient. (default: `1`) @@ -41,61 +43,6 @@ The keyword arguments are as follows: subsampling::Sub = nothing end -""" - gaussian_expectation_gradient_and_hessian!(rng, q, n_samples, grad_buf, hess_buf, prob) - -Estimate the expectations of the gradient and Hessians of the log-density of `prob` taken over the Gaussian `q`. For estimating the expectation of the Hessian, if `prob` has second-order differentiation capability, this function uses the sample average of the Hessian. Otherwise, it uses Stein's identity. - -# Arguments -- `rng::Random.AbstractRNG`: Random number generator. -- `q::MvLocationScale{<:LowerTriangular,<:Normal,L}`: Gaussian to take expectation over. -- `n_samples::Int`: Number of samples used for estimation. -- `grad_buf::AbstractVector`: Buffer for the gradient estimate. -- `hess_buf::AbstractMatrix`: Buffer for the Hessian estimate. -- `prob`: `LogDensityProblem` associated with the log-density gradient and Hessian subject to expectation. 
-""" -function gaussian_expectation_gradient_and_hessian!( - rng::Random.AbstractRNG, - q::MvLocationScale{<:LowerTriangular,<:Normal,L}, - n_samples::Int, - grad_buf::AbstractVector{T}, - hess_buf::AbstractMatrix{T}, - prob, -) where {T<:Real,L} - logπ_avg = zero(T) - fill!(grad_buf, zero(T)) - fill!(hess_buf, zero(T)) - - if LogDensityProblems.capabilities(typeof(prob)) ≤ - LogDensityProblems.LogDensityOrder{1}() - # Use Stein's identity - d = LogDensityProblems.dimension(prob) - u = randn(rng, T, d, n_samples) - z = q.scale*u .+ q.location - for b in 1:n_samples - zb, ub = view(z, :, b), view(u, :, b) - logπ, ∇logπ = LogDensityProblems.logdensity_and_gradient(prob, zb) - logπ_avg += logπ/n_samples - grad_buf += ∇logπ/n_samples - hess_buf += ub*(∇logπ/n_samples)' - end - return logπ_avg, grad_buf, hess_buf - else - # Use sample average of the Hessian. - z = rand(rng, q, n_samples) - for b in 1:n_samples - zb = view(z, :, b) - logπ, ∇logπ, ∇2logπ = LogDensityProblems.logdensity_gradient_and_hessian( - prob, zb - ) - logπ_avg += logπ/n_samples - grad_buf += ∇logπ/n_samples - hess_buf += ∇2logπ/n_samples - end - return logπ_avg, grad_buf, hess_buf - end -end - struct KLMinWassFwdBwdState{Q,P,S,Sigma,GradBuf,HessBuf} q::Q prob::P @@ -156,7 +103,7 @@ function step( ) m′ = m - η * (-grad_buf) - M = I - η*Hermitian(-hess_buf) + M = I - η*Symmetric(-hess_buf) Σ_half = Hermitian(M*Σ*M) # Compute the JKO proximal operator diff --git a/test/algorithms/klminnaturalgraddescent.jl b/test/algorithms/klminnaturalgraddescent.jl new file mode 100644 index 000000000..e5cd5b558 --- /dev/null +++ b/test/algorithms/klminnaturalgraddescent.jl @@ -0,0 +1,158 @@ + +@testset "KLMinNaturalGradDescent" begin + begin + modelstats = normal_meanfield(Random.default_rng(), Float64; capability=2) + (; model, n_dims, μ_true, L_true) = modelstats + + alg = KLMinNaturalGradDescent(; n_samples=10, stepsize=1e-3) + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + @testset "callback" begin + T = 10 + callback(; iteration, kwargs...) 
= (iteration_check=iteration,) + _, info, _ = optimize(alg, T, model, q0; callback, show_progress=PROGRESS) + @test [i.iteration_check for i in info] == 1:T + end + + @testset "estimate_objective" begin + q_true = FullRankGaussian(μ_true, LowerTriangular(Matrix(L_true))) + + obj_est = estimate_objective(alg, q_true, model) + @test isfinite(obj_est) + + obj_est = estimate_objective(alg, q_true, model; n_samples=10^5) + @test obj_est ≈ 0 atol=1e-2 + end + + @testset "determinism" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + T = 10 + + q_avg, _, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) + μ = q_avg.location + L = q_avg.scale + + rng_repl = StableRNG(seed) + q_avg, _, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) + μ_repl = q_avg.location + L_repl = q_avg.scale + @test μ == μ_repl + @test L == L_repl + end + end + + @testset "error low capability" begin + modelstats = normal_meanfield(Random.default_rng(), Float64; capability=0) + (; model, n_dims) = modelstats + + alg = KLMinNaturalGradDescent(; n_samples=10, stepsize=1.0) + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + @test_throws "first-order" optimize(alg, 1, model, q0) + end + + @testset "type stability type=$(realtype), capability=$(capability)" for realtype in [ + Float64, Float32 + ], + capability in [1, 2] + + modelstats = normal_meanfield(Random.default_rng(), realtype; capability) + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + + alg = KLMinNaturalGradDescent(; n_samples=10, stepsize=1e-3) + T = 10 + + L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(realtype, n_dims), L0) + + q, _, _ = optimize(alg, T, model, q0; show_progress=PROGRESS) + + @test eltype(q.location) == eltype(μ_true) + @test eltype(q.scale) == eltype(L_true) + end + + @testset "convergence capability=$(capability)" for capability in [1, 2] + modelstats = normal_meanfield(Random.default_rng(), Float64; capability) + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + + T = 1000 + alg = KLMinNaturalGradDescent(; n_samples=10, stepsize=1e-3) + + q_avg, _, _ = optimize(alg, T, model, q0; show_progress=PROGRESS) + + Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) + Δλ = sum(abs2, q_avg.location - μ_true) + sum(abs2, q_avg.scale - L_true) + + @test Δλ ≤ 0.1*Δλ0 + end + + @testset "subsampling" begin + n_data = 8 + + @testset "estimate_objective batchsize=$(batchsize)" for batchsize in [1, 3, 4] + modelstats = subsamplednormal(Random.default_rng(), n_data) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg = KLMinNaturalGradDescent(; n_samples=10, stepsize=1e-3) + alg_sub = KLMinNaturalGradDescent(; n_samples=10, stepsize=1e-3, subsampling) + + obj_full = estimate_objective(alg, q0, model; n_samples=10^5) + obj_sub = estimate_objective(alg_sub, q0, model; n_samples=10^5) + @test obj_full ≈ obj_sub rtol=0.1 + end + + @testset "determinism" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = subsamplednormal(Random.default_rng(), n_data) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + T = 10 + batchsize = 3 + 
subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg_sub = KLMinNaturalGradDescent(; n_samples=10, stepsize=1e-3, subsampling) + + q, _, _ = optimize(rng, alg_sub, T, model, q0; show_progress=PROGRESS) + μ = q.location + L = q.scale + + rng_repl = StableRNG(seed) + q, _, _ = optimize(rng_repl, alg_sub, T, model, q0; show_progress=PROGRESS) + μ_repl = q.location + L_repl = q.scale + @test μ == μ_repl + @test L == L_repl + end + + @testset "convergence capability=$(capability)" for capability in [1, 2] + modelstats = subsamplednormal(Random.default_rng(), n_data; capability) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + T = 1000 + batchsize = 1 + subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg_sub = KLMinNaturalGradDescent(; n_samples=10, stepsize=1e-2, subsampling) + + q, stats, _ = optimize(alg_sub, T, model, q0; show_progress=PROGRESS) + + Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) + Δλ = sum(abs2, q.location - μ_true) + sum(abs2, q.scale - L_true) + + @test Δλ ≤ 0.1*Δλ0 + end + end +end diff --git a/test/algorithms/klminsqrtnaturalgraddescent.jl b/test/algorithms/klminsqrtnaturalgraddescent.jl new file mode 100644 index 000000000..7841c6d2d --- /dev/null +++ b/test/algorithms/klminsqrtnaturalgraddescent.jl @@ -0,0 +1,164 @@ + +@testset "KLMinSqrtNaturalGradDescent" begin + begin + modelstats = normal_meanfield(Random.default_rng(), Float64; capability=2) + (; model, n_dims, μ_true, L_true) = modelstats + + alg = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1e-3) + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + @testset "callback" begin + T = 10 + callback(; iteration, kwargs...) 
= (iteration_check=iteration,) + _, info, _ = optimize(alg, T, model, q0; callback, show_progress=PROGRESS) + @test [i.iteration_check for i in info] == 1:T + end + + @testset "estimate_objective" begin + q_true = FullRankGaussian(μ_true, LowerTriangular(Matrix(L_true))) + + obj_est = estimate_objective(alg, q_true, model) + @test isfinite(obj_est) + + obj_est = estimate_objective(alg, q_true, model; n_samples=10^5) + @test obj_est ≈ 0 atol=1e-2 + end + + @testset "determinism" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + T = 10 + + q_avg, _, _ = optimize(rng, alg, T, model, q0; show_progress=PROGRESS) + μ = q_avg.location + L = q_avg.scale + + rng_repl = StableRNG(seed) + q_avg, _, _ = optimize(rng_repl, alg, T, model, q0; show_progress=PROGRESS) + μ_repl = q_avg.location + L_repl = q_avg.scale + @test μ == μ_repl + @test L == L_repl + end + end + + @testset "error low capability" begin + modelstats = normal_meanfield(Random.default_rng(), Float64; capability=0) + (; model, n_dims) = modelstats + + alg = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1.0) + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + @test_throws "first-order" optimize(alg, 1, model, q0) + end + + @testset "type stability type=$(realtype), capability=$(capability)" for realtype in [ + Float64, Float32 + ], + capability in [1, 2] + + modelstats = normal_meanfield(Random.default_rng(), realtype; capability) + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + + alg = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1e-3) + T = 10 + + L0 = LowerTriangular(Matrix{realtype}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(realtype, n_dims), L0) + + q, _, _ = optimize(alg, T, model, q0; show_progress=PROGRESS) + + @test eltype(q.location) == eltype(μ_true) + @test eltype(q.scale) == eltype(L_true) + end + + @testset "convergence capability=$(capability)" for capability in [1, 2] + modelstats = normal_meanfield(Random.default_rng(), Float64; capability) + (; model, μ_true, L_true, n_dims, strong_convexity, is_meanfield) = modelstats + + T = 1000 + alg = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1e-3) + + q_avg, _, _ = optimize(alg, T, model, q0; show_progress=PROGRESS) + + Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) + Δλ = sum(abs2, q_avg.location - μ_true) + sum(abs2, q_avg.scale - L_true) + + @test Δλ ≤ 0.1*Δλ0 + end + + @testset "subsampling" begin + n_data = 8 + + @testset "estimate_objective batchsize=$(batchsize)" for batchsize in [1, 3, 4] + modelstats = subsamplednormal(Random.default_rng(), n_data) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg = KLMinSqrtNaturalGradDescent(; n_samples=10, stepsize=1e-3) + alg_sub = KLMinSqrtNaturalGradDescent(; + n_samples=10, stepsize=1e-3, subsampling + ) + + obj_full = estimate_objective(alg, q0, model; n_samples=10^5) + obj_sub = estimate_objective(alg_sub, q0, model; n_samples=10^5) + @test obj_full ≈ obj_sub rtol=0.1 + end + + @testset "determinism" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = subsamplednormal(Random.default_rng(), n_data) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + T 
= 10 + batchsize = 3 + subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg_sub = KLMinSqrtNaturalGradDescent(; + n_samples=10, stepsize=1e-3, subsampling + ) + + q, _, _ = optimize(rng, alg_sub, T, model, q0; show_progress=PROGRESS) + μ = q.location + L = q.scale + + rng_repl = StableRNG(seed) + q, _, _ = optimize(rng_repl, alg_sub, T, model, q0; show_progress=PROGRESS) + μ_repl = q.location + L_repl = q.scale + @test μ == μ_repl + @test L == L_repl + end + + @testset "convergence capability=$(capability)" for capability in [1, 2] + modelstats = subsamplednormal(Random.default_rng(), n_data; capability) + (; model, n_dims, μ_true, L_true) = modelstats + + L0 = LowerTriangular(Matrix{Float64}(I, n_dims, n_dims)) + q0 = FullRankGaussian(zeros(Float64, n_dims), L0) + + T = 1000 + batchsize = 1 + subsampling = ReshufflingBatchSubsampling(1:n_data, batchsize) + alg_sub = KLMinSqrtNaturalGradDescent(; + n_samples=10, stepsize=1e-2, subsampling + ) + + q, stats, _ = optimize(alg_sub, T, model, q0; show_progress=PROGRESS) + + Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) + Δλ = sum(abs2, q.location - μ_true) + sum(abs2, q.scale - L_true) + + @test Δλ ≤ 0.1*Δλ0 + end + end +end diff --git a/test/general/gauss_expected_grad_hess.jl b/test/general/gauss_expected_grad_hess.jl new file mode 100644 index 000000000..ee254a6a9 --- /dev/null +++ b/test/general/gauss_expected_grad_hess.jl @@ -0,0 +1,54 @@ + +struct TestQuad{S,C} + Σ::S + cap::C +end + +function LogDensityProblems.logdensity(model::TestQuad, x) + Σ = model.Σ + return -x'*Σ*x/2 +end + +function LogDensityProblems.logdensity_and_gradient(model::TestQuad, x) + Σ = model.Σ + return (LogDensityProblems.logdensity(model, x), -Σ*x) +end + +function LogDensityProblems.logdensity_gradient_and_hessian(model::TestQuad, x) + Σ = model.Σ + ℓp, ∇ℓp = LogDensityProblems.logdensity_and_gradient(model, x) + return (ℓp, ∇ℓp, -Σ) +end + +function LogDensityProblems.dimension(model::TestQuad) + return size(model.Σ, 1) +end + +function LogDensityProblems.capabilities(::Type{TestQuad{S,C}}) where {S,C} + return C() +end + +@testset "gauss_expected_grad_hess" begin + n_samples = 10^6 + d = 2 + Σ = [2.0 -0.1; -0.1 2.0] + q = FullRankGaussian(ones(d), LowerTriangular(diagm(fill(0.1, d)))) + + # True expected gradient is E_{x ~ N(μ, 1)} -Σ x = -Σ μ + # True expected Hessian is E_{x ~ N(μ, 1)} -Σ = -Σ + E_∇ℓπ = -Σ*q.location + E_∇2ℓπ = -Σ + + @testset "$(cap)-order capability" for cap in [ + LogDensityProblems.LogDensityOrder{1}(), LogDensityProblems.LogDensityOrder{2}() + ] + grad_buf = zeros(d) + hess_buf = zeros(d, d) + prob = TestQuad(Σ, cap) + AdvancedVI.gaussian_expectation_gradient_and_hessian!( + Random.default_rng(), q, n_samples, grad_buf, hess_buf, prob + ) + @test grad_buf ≈ E_∇ℓπ atol=1e-1 + @test hess_buf ≈ E_∇2ℓπ atol=1e-1 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index ab67247b3..0d02d0168 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -65,7 +65,10 @@ if GROUP == "All" || GROUP == "GENERAL" include("families/location_scale.jl") include("families/location_scale_low_rank.jl") + include("general/gauss_expected_grad_hess.jl") include("algorithms/klminwassfwdbwd.jl") + include("algorithms/klminsqrtnaturalgraddescent.jl") + include("algorithms/klminnaturalgraddescent.jl") end if GROUP == "All" || GROUP == "AD"