Description
ForwardDiff doesn't seem to support FFT or complex types, but Zygote.gradient does. Per this discussion,
https://discourse.julialang.org/t/forwarddiff-and-zygote-cannot-automatically-differentiate-ad-function-from-c-n-to-r-that-uses-fft/52440/4?u=tholdem
I was made aware that Zygote.hessian uses ForwardDiff. This prevents Zygote from computing the Hessian of my objective function, which involves FFT and complex numbers. I don't have much experience with these things and I'm new to Julia, so I don't really know if I can pull off writing a Hessian function myself using Zygote only. I was wondering:
1. Why would Zygote call ForwardDiff instead of its own functions?
2. Is there more guidance or any resource on how to write a Hessian AD function that calls only Zygote and supports complex numbers and FFT?
Thank you so much for your help.
For example, the code below doesn't work because Zygote doesn't support mutating arrays, and it's unclear whether it handles complex numbers correctly.
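    import Zygote

    # Build the Jacobian of f at x row by row from the pullback.
    function jacobian(f, x)
        y, back = Zygote.pullback(f, x)
        k = length(y)
        n = length(x)
        J = Matrix{eltype(y)}(undef, k, n)
        e_i = fill!(similar(y), 0)          # seed vector passed to the pullback
        @inbounds for i = 1:k
            e_i[i] = oneunit(eltype(y))     # select output i
            J[i, :] = back(e_i)[1]          # row i of the Jacobian
            e_i[i] = zero(eltype(y))        # reset the seed
        end
        (J,)
    end

    # Hessian as the Jacobian of the gradient (returns a 1-tuple, like jacobian)
    hessian(f, x) = jacobian(x -> Zygote.gradient(f, x)[1], x)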
There seems to be a workaround for mutating arrays, https://github.com/rakeshvar/Zygote-Mutating-Arrays-WorkAround.jl, but it is much harder to work out the gradient of the mutating steps in the code above, so I would find it very daunting to come up with a workaround myself.
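One possible fallback, sketched here under stated assumptions, is to avoid differentiating twice altogether: take first-order gradients with Zygote (which is the part reported above to handle complex numbers) and approximate the Hessian by central finite differences of those gradients. This is a numerical approximation, not a Zygote-only AD Hessian. The helper name `fd_hessian`, the step size `h`, and the test objective `f` are all hypothetical and are not part of Zygote. The sketch assumes a real-valued objective of a real vector whose intermediates may be complex.

```julia
using Zygote

# Hypothetical helper (not part of Zygote): approximate the Hessian of a
# real-valued f at a real vector x by central differences of Zygote gradients.
function fd_hessian(f, x::AbstractVector{<:Real}; h = cbrt(eps(Float64)))
    n = length(x)
    # First-order gradient from Zygote; real.() keeps only the part that is
    # meaningful for a real input, in case a complex array comes back.
    grad(z) = real.(Zygote.gradient(f, z)[1])
    cols = map(1:n) do j
        dx = [i == j ? h : 0.0 for i in 1:n]          # step along coordinate j
        (grad(x .+ dx) .- grad(x .- dx)) ./ (2h)      # column j of the Hessian
    end
    H = reduce(hcat, cols)
    return (H .+ H') ./ 2                             # symmetrize away round-off
end

# Illustrative objective with a complex intermediate: f(x) = 5 * sum(x .^ 2)
f(x) = sum(abs2, (1 + 2im) .* x)

x0 = [0.3, -1.2, 2.0]
H = fd_hessian(f, x0)   # analytically, the Hessian is 10 times the identity
```

The cost is two gradient evaluations per input dimension, and the accuracy depends on the choice of `h`, so this is better suited to checking results or to small problems than to replacing an exact second-order AD pass.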