Skip to content

multivalue autodiff #221

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <cont
version = "2.0.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
46 changes: 2 additions & 44 deletions ext/SparseDiffToolsZygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,13 @@ module SparseDiffToolsZygote
if isdefined(Base, :get_extension)
import Zygote
using LinearAlgebra
using SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
using SparseDiffTools: SparseDiffTools, DeivVecTag
using ForwardDiff: ForwardDiff, Dual, partials
using SciMLOperators: FunctionOperator
using Tricks: static_hasmethod
else
import ..Zygote
using ..LinearAlgebra
using ..SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
using ..SparseDiffTools: SparseDiffTools, DeivVecTag
using ..ForwardDiff: ForwardDiff, Dual, partials
using ..SciMLOperators: FunctionOperator
using ..Tricks: static_hasmethod
end

### Jac, Hes products
Expand Down Expand Up @@ -69,40 +65,6 @@ function SparseDiffTools.autoback_hesvec(f, x, v)
ForwardDiff.partials.(g(y), 1)
end

# Operator Forms

function SparseDiffTools.ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)

if autodiff
cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
}.(u, ForwardDiff.Partials.(tuple.(u)))
cache2 = copy(u)
else
cache1 = similar(u)
cache2 = similar(u)
end

cache = (cache1, cache2,)

vecprod = autodiff ? SparseDiffTools.autoback_hesvec : SparseDiffTools.numback_hesvec
vecprod! = autodiff ? SparseDiffTools.autoback_hesvec! : SparseDiffTools.numback_hesvec!

outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u,)))

if !(isinplace) & !(outofplace)
error("$f must have signature f(u).")
end

L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)

FunctionOperator(L, u, u;
isinplace = isinplace, outofplace = outofplace,
p = p, t = t, islinear = true,
)
end

## VecJac products

function SparseDiffTools.auto_vecjac!(du, f, x, v, cache1 = nothing, cache2 = nothing)
Expand All @@ -115,8 +77,4 @@ function SparseDiffTools.auto_vecjac(f, x, v)
return vec(back(reshape(v, size(vv)))[1])
end

function SparseDiffTools.ZygoteVecJac(args...; autodiff = true, kwargs...)
VecJac(args...; autodiff = autodiff, kwargs...)
end

end # module
10 changes: 4 additions & 6 deletions src/SparseDiffTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ using Graphs
using Graphs: SimpleGraph
using VertexSafeGraphs
using Adapt
using Reexport
@reexport using ADTypes

using LinearAlgebra
using SparseArrays, ArrayInterface
Expand Down Expand Up @@ -69,30 +71,26 @@ parameterless_type(x) = parameterless_type(typeof(x))
parameterless_type(x::Type) = __parameterless_type(x)

import Requires
import Reexport

function numback_hesvec end
function numback_hesvec! end
function autoback_hesvec end
function autoback_hesvec! end
function auto_vecjac end
function auto_vecjac! end
function ZygoteVecJac end
function ZygoteHesVec end

@static if !isdefined(Base, :get_extension)
function __init__()
Requires.@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
include("../ext/SparseDiffToolsZygote.jl")
Reexport.@reexport using .SparseDiffToolsZygote
@reexport using .SparseDiffToolsZygote
end
end
end

export
numback_hesvec, numback_hesvec!,
autoback_hesvec, autoback_hesvec!,
auto_vecjac, auto_vecjac!,
ZygoteVecJac, ZygoteHesVec
auto_vecjac, auto_vecjac!

end # module
67 changes: 39 additions & 28 deletions src/differentiation/jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,24 +223,25 @@ function (L::FwdModeAutoDiffVecProd)(dv, v, p, t)
L.vecprod!(dv, L.f, L.u, v, L.cache...)
end

function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff())

if autodiff
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
cache1 = similar(u)
cache2 = similar(u)

(cache1, cache2), num_jacvec, num_jacvec!
elseif autodiff isa AutoForwardDiff
cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
}.(u, ForwardDiff.Partials.(tuple.(u)))

cache2 = copy(cache1)

(cache1, cache2), auto_jacvec, auto_jacvec!
else
cache1 = similar(u)
cache2 = similar(u)
@error("Set autodiff to either AutoForwardDiff(), or AutoFiniteDiff()")
end

cache = (cache1, cache2,)

vecprod = autodiff ? auto_jacvec : num_jacvec
vecprod! = autodiff ? auto_jacvec! : num_jacvec!

outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u, u,)))

Expand All @@ -256,22 +257,32 @@ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
)
end

function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff())

if autodiff
cache1 = ForwardDiff.GradientConfig(f, u)
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
cache1 = similar(u)
cache2 = similar(u)
cache3 = similar(u)
else
cache1 = similar(u)

(cache1, cache2, cache3), num_hesvec, num_hesvec!
elseif autodiff isa AutoForwardDiff
cache1 = ForwardDiff.GradientConfig(f, u)
cache2 = similar(u)
cache3 = similar(u)
end

cache = (cache1, cache2, cache3,)
(cache1, cache2, cache3), numauto_hesvec, numauto_hesvec!
elseif autodiff isa AutoZygote
@assert static_hasmethod(autoback_hesvec, typeof((f, u, u))) "To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"

vecprod = autodiff ? numauto_hesvec : num_hesvec
vecprod! = autodiff ? numauto_hesvec! : num_hesvec!
cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
}.(u, ForwardDiff.Partials.(tuple.(u)))
cache2 = copy(u)

(cache1, cache2), autoback_hesvec, autoback_hesvec!
else
@error("Set autodiff to either AutoForwardDiff(), AutoZygote(), or AutoFiniteDiff()")
end

outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u,)))
Expand All @@ -288,24 +299,24 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
)
end

function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff())

if autodiff
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
cache1 = similar(u)
cache2 = similar(u)

(cache1, cache2), num_hesvecgrad, num_hesvecgrad!
elseif autodiff isa AutoForwardDiff
cache1 = Dual{
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
}.(u, ForwardDiff.Partials.(tuple.(u)))

cache2 = copy(cache1)

(cache1, cache2), auto_hesvecgrad, auto_hesvecgrad!
else
cache1 = similar(u)
cache2 = similar(u)
@error("Set autodiff to either AutoForwardDiff(), or AutoFiniteDiff()")
end

cache = (cache1, cache2,)

vecprod = autodiff ? auto_hesvecgrad : num_hesvecgrad
vecprod! = autodiff ? auto_hesvecgrad! : num_hesvecgrad!

outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u, u,)))

Expand Down
18 changes: 10 additions & 8 deletions src/differentiation/vecjac_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@ struct RevModeAutoDiffVecProd{ad,iip,oop,F,U,C,V,V!} <: AbstractAutoDiffVecProd
vecprod::V
vecprod!::V!

function RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!; autodiff = false,
function RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!;
autodiff = AutoFiniteDiff(),
isinplace = false, outofplace = true)
@assert isinplace || outofplace

new{
autodiff,
typeof(autodiff),
isinplace,
outofplace,
typeof(f),
Expand Down Expand Up @@ -86,18 +87,19 @@ function (L::RevModeAutoDiffVecProd{ad,true,false})(dv, v, p, t) where{ad}
L.vecprod!(dv, (_du, _u) -> L.f(_du, _u, p, t), L.u, v, L.cache...)
end

function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = false,
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFiniteDiff(),
ishermitian = false, opnrom = true)

if autodiff
@assert isdefined(SparseDiffTools, :auto_vecjac) "Please load Zygote with `using Zygote`, or `import Zygote` to use VecJac with `autodiff = true`."
vecprod, vecprod! = if autodiff isa AutoFiniteDiff
num_vecjac, num_vecjac!
elseif autodiff isa AutoZygote
@assert static_hasmethod(auto_vecjac, typeof((f, u, u))) "To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"

auto_vecjac, auto_vecjac!
end

cache = (similar(u), similar(u),)

vecprod = autodiff ? auto_vecjac : num_vecjac
vecprod! = autodiff ? auto_vecjac! : num_vecjac!

outofplace = static_hasmethod(f, typeof((u, p, t)))
isinplace = static_hasmethod(f, typeof((u, u, p, t)))

Expand Down
22 changes: 6 additions & 16 deletions test/test_jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ update_coefficients!(L, v, nothing, 0.0)
@test mul!(dy, L, v) ≈ auto_jacvec(f, v, v)
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*auto_jacvec(f,x,v) + b*_dy

L = JacVec(f, x, autodiff = false)
L = JacVec(f, x, autodiff = AutoFiniteDiff())
@test L * x ≈ num_jacvec(f, x, x)
@test L * v ≈ num_jacvec(f, x, v)
@test mul!(dy, L, v)≈num_jacvec(f, x, v) rtol=1e-6
Expand All @@ -92,7 +92,7 @@ gmres!(out, L, v)

x = rand(N)
v = rand(N)
L = HesVec(g, x, autodiff = false)
L = HesVec(g, x, autodiff = AutoFiniteDiff())
@test L * x ≈ num_hesvec(g, x, x)
@test L * v ≈ num_hesvec(g, x, v)
@test mul!(dy, L, v)≈num_hesvec(g, x, v) rtol=1e-2
Expand All @@ -113,21 +113,12 @@ dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy r
out = similar(v)
gmres!(out, L, v)

@info "ZygoteHesVec"
using Zygote

x = rand(N)
v = rand(N)

L = ZygoteHesVec(g, x, autodiff = false)
@test L * x ≈ numback_hesvec(g, x, x) rtol = 1e-2
@test L * v ≈ numback_hesvec(g, x, v) rtol = 1e-2
@test mul!(dy, L, v)≈numback_hesvec(g, x, v) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*numback_hesvec(g,x,v) + b*_dy rtol=1e-2
update_coefficients!(L, v, nothing, 0.0)
@test mul!(dy, L, v)≈numback_hesvec(g, v, v) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) ≈ a*numback_hesvec(g,x,v) + b*_dy rtol=1e-2

L = ZygoteHesVec(g, x)
L = HesVec(g, x, autodiff = AutoZygote())
@test L * x ≈ autoback_hesvec(g, x, x)
@test L * v ≈ autoback_hesvec(g, x, v)
@test mul!(dy, L, v)≈autoback_hesvec(g, x, v) rtol=1e-8
Expand All @@ -139,12 +130,11 @@ dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*autoback_hesvec(g,x,v)+b*_dy
out = similar(v)
gmres!(out, L, v)


@info "HesVecGrad"

x = rand(N)
v = rand(N)
L = HesVecGrad(h, x, autodiff = false)
L = HesVecGrad(h, x, autodiff = AutoFiniteDiff())
@test L * x ≈ num_hesvec(g, x, x)
@test L * v ≈ num_hesvec(g, x, v)
@test mul!(dy, L, v)≈num_hesvec(g, x, v) rtol=1e-2
Expand All @@ -153,7 +143,7 @@ update_coefficients!(L, v, nothing, 0.0)
@test mul!(dy, L, v)≈num_hesvec(g, v, v) rtol=1e-2
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*num_hesvec(g,x,v)+b*_dy rtol=1e-2

L = HesVecGrad(h, x, autodiff = true)
L = HesVecGrad(h, x)
@test L * x ≈ autonum_hesvec(g, x, x)
@test L * v ≈ numauto_hesvec(g, x, v)
@test mul!(dy, L, v)≈numauto_hesvec(g, x, v) rtol=1e-8
Expand Down
4 changes: 2 additions & 2 deletions test/test_vecjac_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ L = VecJac(f, x)
actual_vjp = Zygote.jacobian(x -> f(x, nothing, 0.0), x)[1]' * v
update_coefficients!(L, v, nothing, 0.0)
@test L * v ≈ actual_vjp
L = VecJac(f, x; autodiff = false)
L = VecJac(f, x; autodiff = AutoFiniteDiff())
update_coefficients!(L, v, nothing, 0.0)
@test L * v ≈ actual_vjp

Expand All @@ -28,7 +28,7 @@ L = ZygoteVecJac(f, x)
actual_vjp = Zygote.jacobian(x -> f(x, nothing, 0.0), x)[1]' * v
update_coefficients!(L, v, nothing, 0.0)
@test L * v ≈ actual_vjp
L = ZygoteVecJac(f, x; autodiff = false)
L = ZygoteVecJac(f, x; autodiff = AutoFiniteDiff())
update_coefficients!(L, v, nothing, 0.0)
@test L * v ≈ actual_vjp
#