Skip to content

Conversation

@Red-Portal
Copy link
Member

@Red-Portal Red-Portal commented Nov 7, 2025

This adds the following natural gradient VI algorithms:

  • Variational online Newton12
  • Square-root variational online Newton3456

Natural gradient VI (NGVI) is a family of algorithms that correspond to mirror descent under the Bregman divergence. Since the pseudo-metric is a divergence between distributions, the algorithm can be thought of as a measure-space algorithm. Therefore, empirically, NGVI tends to converge faster than BBVI/ADVI. However, the algorithm also involves quantities defined in terms of variational parameters, so it is not a fully measure-space algorithm. As such, design decisions related to parametrizations and update rules result in different implementations (hence two algorithms in this PR). Furthermore, NGVI is restricted to (mixtures) exponential variational families. The PR only implements the Gaussian variational family variant. Another downside is that the update rules tend to involve operations that are costly ($\mathrm{O}(d^3)$ for a $d$-dimensional target) and sensitive to numerical errors.

This addresses #1

Footnotes

  1. Khan, M., & Lin, W. (2017, April). Conjugate-computation variational inference: Converting variational inference in non-conjugate models to inferences in conjugate models. AISTATS.

  2. Khan, M. E., & Rue, H. (2023). The Bayesian learning rule. Journal of Machine Learning Research, 24(281), 1-46.

  3. Kumar, N., Möllenhoff, T., Khan, M. E., & Lucchi, A. (2025). Optimization Guarantees for Square-Root Natural-Gradient Variational Inference. TMLR.

  4. Lin, W., Dangel, F., Eschenhagen, R., Bae, J., Turner, R. E., & Makhzani, A. (2024, July). Can We Remove the Square-Root in Adaptive Gradient Methods? A Second-Order Perspective. ICML.

  5. Lin, W., Duruisseaux, V., Leok, M., Nielsen, F., Khan, M. E., & Schmidt, M. (2023, July). Simplifying momentum-based positive-definite submanifold optimization with applications to deep learning. ICML.

  6. Tan, L. S. (2025). Analytic natural gradient updates for Cholesky factor in Gaussian variational approximation. JRSS:B.

subsampling::Sub = nothing
end

"""
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This portion has been moved to a separate file algorithms/gauss_expected_grad_hess.jl.

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
@github-actions
Copy link
Contributor

github-actions bot commented Nov 7, 2025

AdvancedVI.jl documentation for PR #211 is available at:
https://TuringLang.github.io/AdvancedVI.jl/previews/PR211/

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark Results

Benchmark suite Current: 3ba8401 Previous: f3790c3 Ratio
normal/RepGradELBO + STL/meanfield/Zygote 3873826750 ns 3947991767 ns 0.98
normal/RepGradELBO + STL/meanfield/ReverseDiff 1128485842 ns 1158667552 ns 0.97
normal/RepGradELBO + STL/meanfield/Mooncake 1204523642 ns 1195024786 ns 1.01
normal/RepGradELBO + STL/fullrank/Zygote 3867767080.5 ns 3913524681.5 ns 0.99
normal/RepGradELBO + STL/fullrank/ReverseDiff 1631515828.5 ns 1687251928 ns 0.97
normal/RepGradELBO + STL/fullrank/Mooncake 1242130714 ns 1259040441.5 ns 0.99
normal/RepGradELBO/meanfield/Zygote 2753786960 ns 2748544197 ns 1.00
normal/RepGradELBO/meanfield/ReverseDiff 781925812 ns 802111316 ns 0.97
normal/RepGradELBO/meanfield/Mooncake 1064810388 ns 1075175373 ns 0.99
normal/RepGradELBO/fullrank/Zygote 2725854614.5 ns 2747780958 ns 0.99
normal/RepGradELBO/fullrank/ReverseDiff 957663728.5 ns 986883626 ns 0.97
normal/RepGradELBO/fullrank/Mooncake 1113456897 ns 1110406036 ns 1.00
normal + bijector/RepGradELBO + STL/meanfield/Zygote 5470822166 ns 5438616109 ns 1.01
normal + bijector/RepGradELBO + STL/meanfield/ReverseDiff 2379772766 ns 2408900504 ns 0.99
normal + bijector/RepGradELBO + STL/meanfield/Mooncake 4061683853 ns 3976898614 ns 1.02
normal + bijector/RepGradELBO + STL/fullrank/Zygote 5522150676 ns 5609235006 ns 0.98
normal + bijector/RepGradELBO + STL/fullrank/ReverseDiff 3030590749 ns 3075689171 ns 0.99
normal + bijector/RepGradELBO + STL/fullrank/Mooncake 4170474373.5 ns 4143109561 ns 1.01
normal + bijector/RepGradELBO/meanfield/Zygote 4243544018 ns 4228898619.5 ns 1.00
normal + bijector/RepGradELBO/meanfield/ReverseDiff 2026830614 ns 2037504059 ns 0.99
normal + bijector/RepGradELBO/meanfield/Mooncake 3900253917.5 ns 3867072625 ns 1.01
normal + bijector/RepGradELBO/fullrank/Zygote 4345868079.5 ns 4369346900.5 ns 0.99
normal + bijector/RepGradELBO/fullrank/ReverseDiff 2280632962 ns 2272673566 ns 1.00
normal + bijector/RepGradELBO/fullrank/Mooncake 3988284784.5 ns 4075940041 ns 0.98

This comment was automatically generated by workflow using github-action-benchmark.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants