VI typically computes the value and gradients of a given function over and over, and for that workload Zygote seems to give about a 3x speedup over ReverseDiff. The reason is apparently that Zygote does some kind of caching.
Here is a quick benchmark:
using BenchmarkTools
using Distributions
using Zygote
using ReverseDiff
d = MvNormal(rand(50), rand(50, 50) |> x -> x * x')
f(x) = logpdf(d, x)
X = rand(d, 40)
@btime ReverseDiff.gradient($X) do x
    sum(f, eachcol(x))
end
# 15.941 ms (534899 allocations: 23.11 MiB)
@btime Zygote.gradient($X) do x
    sum(f, eachcol(x))
end
# 5.405 ms (22475 allocations: 4.51 MiB)
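
Since VI usually needs the value as well as the gradient, here is a minimal sketch of getting both from a single Zygote call using `Zygote.pullback`. The small 3-dimensional `MvNormal`, the name `loss`, and the sample size are made up for illustration and are not part of the benchmark above; I have not timed this variant.

```julia
using Distributions
using Zygote

# Illustrative setup (assumption): a small, fixed, positive-definite covariance.
d = MvNormal(zeros(3), [2.0 0.5 0.0; 0.5 1.0 0.0; 0.0 0.0 1.5])
X = rand(d, 4)

# Sum of log-densities over the columns, same shape as the benchmarked function.
loss(x) = sum(c -> logpdf(d, c), eachcol(x))

# One forward pass gives the value and a closure for the reverse pass.
val, back = Zygote.pullback(loss, X)

# Seeding the reverse pass with 1 gives the gradient with respect to X.
grad = only(back(one(val)))
```

`val` is the summed log-density and `grad` has the same shape as `X`, so both can be handed to an optimiser step without evaluating the function a second time.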