Add natural gradient variational inference algorithms #211
```julia
    subsampling::Sub = nothing
end
```
This portion has been moved to a separate file algorithms/gauss_expected_grad_hess.jl.
AdvancedVI.jl documentation for PR #211 is available at:
Benchmark Results
| Benchmark suite | Current: 3ba8401 | Previous: f3790c3 | Ratio |
|---|---|---|---|
| `normal/RepGradELBO + STL/meanfield/Zygote` | 3873826750 ns | 3947991767 ns | 0.98 |
| `normal/RepGradELBO + STL/meanfield/ReverseDiff` | 1128485842 ns | 1158667552 ns | 0.97 |
| `normal/RepGradELBO + STL/meanfield/Mooncake` | 1204523642 ns | 1195024786 ns | 1.01 |
| `normal/RepGradELBO + STL/fullrank/Zygote` | 3867767080.5 ns | 3913524681.5 ns | 0.99 |
| `normal/RepGradELBO + STL/fullrank/ReverseDiff` | 1631515828.5 ns | 1687251928 ns | 0.97 |
| `normal/RepGradELBO + STL/fullrank/Mooncake` | 1242130714 ns | 1259040441.5 ns | 0.99 |
| `normal/RepGradELBO/meanfield/Zygote` | 2753786960 ns | 2748544197 ns | 1.00 |
| `normal/RepGradELBO/meanfield/ReverseDiff` | 781925812 ns | 802111316 ns | 0.97 |
| `normal/RepGradELBO/meanfield/Mooncake` | 1064810388 ns | 1075175373 ns | 0.99 |
| `normal/RepGradELBO/fullrank/Zygote` | 2725854614.5 ns | 2747780958 ns | 0.99 |
| `normal/RepGradELBO/fullrank/ReverseDiff` | 957663728.5 ns | 986883626 ns | 0.97 |
| `normal/RepGradELBO/fullrank/Mooncake` | 1113456897 ns | 1110406036 ns | 1.00 |
| `normal + bijector/RepGradELBO + STL/meanfield/Zygote` | 5470822166 ns | 5438616109 ns | 1.01 |
| `normal + bijector/RepGradELBO + STL/meanfield/ReverseDiff` | 2379772766 ns | 2408900504 ns | 0.99 |
| `normal + bijector/RepGradELBO + STL/meanfield/Mooncake` | 4061683853 ns | 3976898614 ns | 1.02 |
| `normal + bijector/RepGradELBO + STL/fullrank/Zygote` | 5522150676 ns | 5609235006 ns | 0.98 |
| `normal + bijector/RepGradELBO + STL/fullrank/ReverseDiff` | 3030590749 ns | 3075689171 ns | 0.99 |
| `normal + bijector/RepGradELBO + STL/fullrank/Mooncake` | 4170474373.5 ns | 4143109561 ns | 1.01 |
| `normal + bijector/RepGradELBO/meanfield/Zygote` | 4243544018 ns | 4228898619.5 ns | 1.00 |
| `normal + bijector/RepGradELBO/meanfield/ReverseDiff` | 2026830614 ns | 2037504059 ns | 0.99 |
| `normal + bijector/RepGradELBO/meanfield/Mooncake` | 3900253917.5 ns | 3867072625 ns | 1.01 |
| `normal + bijector/RepGradELBO/fullrank/Zygote` | 4345868079.5 ns | 4369346900.5 ns | 0.99 |
| `normal + bijector/RepGradELBO/fullrank/ReverseDiff` | 2280632962 ns | 2272673566 ns | 1.00 |
| `normal + bijector/RepGradELBO/fullrank/Mooncake` | 3988284784.5 ns | 4075940041 ns | 0.98 |
This comment was automatically generated by a workflow using github-action-benchmark.
This PR adds two natural gradient variational inference (VI) algorithms.
Natural gradient VI (NGVI) is a family of algorithms corresponding to mirror descent under a Bregman divergence. Since the resulting pseudo-metric is a divergence between distributions, the algorithm can be viewed as operating in measure space, and empirically, NGVI tends to converge faster than BBVI/ADVI. However, the updates also involve quantities defined in terms of the variational parameters, so it is not a fully measure-space algorithm. As a result, design decisions about parametrization and update rules lead to different implementations (hence the two algorithms in this PR). Furthermore, NGVI is restricted to (mixtures of) exponential-family variational distributions; this PR only implements the Gaussian variational family variant. Another downside is that the update rules tend to involve operations that are costly ($\mathrm{O}(d^3)$ for a $d$-dimensional target) and sensitive to numerical error.
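For context, the generic NGVI update preconditions the ELBO gradient with the inverse Fisher information matrix of the variational family,

$$
\lambda_{t+1} = \lambda_t + \rho_t \, F(\lambda_t)^{-1} \nabla_{\lambda} \mathcal{L}(\lambda_t),
\qquad
F(\lambda) = \mathbb{E}_{q_{\lambda}}\!\left[ \nabla_{\lambda} \log q_{\lambda}(z) \, \nabla_{\lambda} \log q_{\lambda}(z)^{\top} \right],
$$

and for an exponential family with natural parameter $\lambda$ and mean parameter $\mu = \nabla A(\lambda)$, one has $F(\lambda)^{-1} \nabla_{\lambda} \mathcal{L} = \nabla_{\mu} \mathcal{L}$, so the update is exactly a mirror descent step in $\mu$ under the KL-induced Bregman divergence.

As a rough illustration of the Gaussian-family case and of where the $\mathrm{O}(d^3)$ cost enters, here is a minimal sketch of one step in the mean/precision parametrization $q = \mathcal{N}(m, S^{-1})$, following the general form of the updates in the Bayesian learning rule literature. The function name `ngvi_step` and the inputs `grad`/`hess` (Monte Carlo estimates of $\mathbb{E}_q[\nabla \log \pi]$ and $\mathbb{E}_q[\nabla^2 \log \pi]$) are illustrative assumptions, not the API introduced by this PR:

```julia
using LinearAlgebra

# Hypothetical sketch of one Gaussian NGVI step with q = N(m, inv(S)),
# where S is the precision matrix. `grad` ≈ E_q[∇ log π(z)] and
# `hess` ≈ E_q[∇² log π(z)] are assumed to be supplied by the caller;
# neither this function nor its signature is part of the PR.
function ngvi_step(m::AbstractVector, S::AbstractMatrix,
                   grad::AbstractVector, hess::AbstractMatrix, ρ::Real)
    S_new = (1 - ρ) * S - ρ * hess  # precision update; -hess estimates E_q[-∇² log π]
    m_new = m + ρ * (S_new \ grad)  # mean update; the linear solve is the O(d³) step
    return m_new, S_new
end
```

If $\rho$ is too large, `S_new` can lose positive definiteness, which is one source of the numerical sensitivity noted above and part of what the square-root parametrizations in the references below are designed to avoid.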
This addresses #1
Footnotes
Khan, M., & Lin, W. (2017). Conjugate-computation variational inference: Converting variational inference in non-conjugate models to inferences in conjugate models. AISTATS.
Khan, M. E., & Rue, H. (2023). The Bayesian learning rule. Journal of Machine Learning Research, 24(281), 1-46.
Kumar, N., Möllenhoff, T., Khan, M. E., & Lucchi, A. (2025). Optimization guarantees for square-root natural-gradient variational inference. TMLR.
Lin, W., Dangel, F., Eschenhagen, R., Bae, J., Turner, R. E., & Makhzani, A. (2024). Can we remove the square-root in adaptive gradient methods? A second-order perspective. ICML.
Lin, W., Duruisseaux, V., Leok, M., Nielsen, F., Khan, M. E., & Schmidt, M. (2023). Simplifying momentum-based positive-definite submanifold optimization with applications to deep learning. ICML.
Tan, L. S. (2025). Analytic natural gradient updates for Cholesky factor in Gaussian variational approximation. JRSS:B.