Skip to content

Commit 3be742b

Browse files
Merge pull request #221 from vpuri3/ad
multivalue autodiff
2 parents 06dac8f + b824cb0 commit 3be742b

File tree

7 files changed

+64
-104
lines changed

7 files changed

+64
-104
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <cont
44
version = "2.0.0"
55

66
[deps]
7+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
78
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
89
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
910
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

ext/SparseDiffToolsZygote.jl

Lines changed: 2 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,13 @@ module SparseDiffToolsZygote
33
if isdefined(Base, :get_extension)
44
import Zygote
55
using LinearAlgebra
6-
using SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
6+
using SparseDiffTools: SparseDiffTools, DeivVecTag
77
using ForwardDiff: ForwardDiff, Dual, partials
8-
using SciMLOperators: FunctionOperator
9-
using Tricks: static_hasmethod
108
else
119
import ..Zygote
1210
using ..LinearAlgebra
13-
using ..SparseDiffTools: SparseDiffTools, DeivVecTag, FwdModeAutoDiffVecProd, VecJac
11+
using ..SparseDiffTools: SparseDiffTools, DeivVecTag
1412
using ..ForwardDiff: ForwardDiff, Dual, partials
15-
using ..SciMLOperators: FunctionOperator
16-
using ..Tricks: static_hasmethod
1713
end
1814

1915
### Jac, Hes products
@@ -69,40 +65,6 @@ function SparseDiffTools.autoback_hesvec(f, x, v)
6965
ForwardDiff.partials.(g(y), 1)
7066
end
7167

72-
# Operator Forms
73-
74-
function SparseDiffTools.ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
75-
76-
if autodiff
77-
cache1 = Dual{
78-
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
79-
}.(u, ForwardDiff.Partials.(tuple.(u)))
80-
cache2 = copy(u)
81-
else
82-
cache1 = similar(u)
83-
cache2 = similar(u)
84-
end
85-
86-
cache = (cache1, cache2,)
87-
88-
vecprod = autodiff ? SparseDiffTools.autoback_hesvec : SparseDiffTools.numback_hesvec
89-
vecprod! = autodiff ? SparseDiffTools.autoback_hesvec! : SparseDiffTools.numback_hesvec!
90-
91-
outofplace = static_hasmethod(f, typeof((u,)))
92-
isinplace = static_hasmethod(f, typeof((u,)))
93-
94-
if !(isinplace) & !(outofplace)
95-
error("$f must have signature f(u).")
96-
end
97-
98-
L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)
99-
100-
FunctionOperator(L, u, u;
101-
isinplace = isinplace, outofplace = outofplace,
102-
p = p, t = t, islinear = true,
103-
)
104-
end
105-
10668
## VecJac products
10769

10870
function SparseDiffTools.auto_vecjac!(du, f, x, v, cache1 = nothing, cache2 = nothing)
@@ -115,8 +77,4 @@ function SparseDiffTools.auto_vecjac(f, x, v)
11577
return vec(back(reshape(v, size(vv)))[1])
11678
end
11779

118-
function SparseDiffTools.ZygoteVecJac(args...; autodiff = true, kwargs...)
119-
VecJac(args...; autodiff = autodiff, kwargs...)
120-
end
121-
12280
end # module

src/SparseDiffTools.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ using Graphs
77
using Graphs: SimpleGraph
88
using VertexSafeGraphs
99
using Adapt
10+
using Reexport
11+
@reexport using ADTypes
1012

1113
using LinearAlgebra
1214
using SparseArrays, ArrayInterface
@@ -69,30 +71,26 @@ parameterless_type(x) = parameterless_type(typeof(x))
6971
parameterless_type(x::Type) = __parameterless_type(x)
7072

7173
import Requires
72-
import Reexport
7374

7475
function numback_hesvec end
7576
function numback_hesvec! end
7677
function autoback_hesvec end
7778
function autoback_hesvec! end
7879
function auto_vecjac end
7980
function auto_vecjac! end
80-
function ZygoteVecJac end
81-
function ZygoteHesVec end
8281

8382
@static if !isdefined(Base, :get_extension)
8483
function __init__()
8584
Requires.@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin
8685
include("../ext/SparseDiffToolsZygote.jl")
87-
Reexport.@reexport using .SparseDiffToolsZygote
86+
@reexport using .SparseDiffToolsZygote
8887
end
8988
end
9089
end
9190

9291
export
9392
numback_hesvec, numback_hesvec!,
9493
autoback_hesvec, autoback_hesvec!,
95-
auto_vecjac, auto_vecjac!,
96-
ZygoteVecJac, ZygoteHesVec
94+
auto_vecjac, auto_vecjac!
9795

9896
end # module

src/differentiation/jaches_products.jl

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -223,24 +223,25 @@ function (L::FwdModeAutoDiffVecProd)(dv, v, p, t)
223223
L.vecprod!(dv, L.f, L.u, v, L.cache...)
224224
end
225225

226-
function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
226+
function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff())
227227

228-
if autodiff
228+
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
229+
cache1 = similar(u)
230+
cache2 = similar(u)
231+
232+
(cache1, cache2), num_jacvec, num_jacvec!
233+
elseif autodiff isa AutoForwardDiff
229234
cache1 = Dual{
230235
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
231236
}.(u, ForwardDiff.Partials.(tuple.(u)))
232237

233238
cache2 = copy(cache1)
239+
240+
(cache1, cache2), auto_jacvec, auto_jacvec!
234241
else
235-
cache1 = similar(u)
236-
cache2 = similar(u)
242+
@error("Set autodiff to either AutoForwardDiff(), or AutoFiniteDiff()")
237243
end
238244

239-
cache = (cache1, cache2,)
240-
241-
vecprod = autodiff ? auto_jacvec : num_jacvec
242-
vecprod! = autodiff ? auto_jacvec! : num_jacvec!
243-
244245
outofplace = static_hasmethod(f, typeof((u,)))
245246
isinplace = static_hasmethod(f, typeof((u, u,)))
246247

@@ -256,22 +257,32 @@ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
256257
)
257258
end
258259

259-
function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
260+
function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff())
260261

261-
if autodiff
262-
cache1 = ForwardDiff.GradientConfig(f, u)
262+
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
263+
cache1 = similar(u)
263264
cache2 = similar(u)
264265
cache3 = similar(u)
265-
else
266-
cache1 = similar(u)
266+
267+
(cache1, cache2, cache3), num_hesvec, num_hesvec!
268+
elseif autodiff isa AutoForwardDiff
269+
cache1 = ForwardDiff.GradientConfig(f, u)
267270
cache2 = similar(u)
268271
cache3 = similar(u)
269-
end
270272

271-
cache = (cache1, cache2, cache3,)
273+
(cache1, cache2, cache3), numauto_hesvec, numauto_hesvec!
274+
elseif autodiff isa AutoZygote
275+
@assert static_hasmethod(autoback_hesvec, typeof((f, u, u))) "To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"
272276

273-
vecprod = autodiff ? numauto_hesvec : num_hesvec
274-
vecprod! = autodiff ? numauto_hesvec! : num_hesvec!
277+
cache1 = Dual{
278+
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
279+
}.(u, ForwardDiff.Partials.(tuple.(u)))
280+
cache2 = copy(u)
281+
282+
(cache1, cache2), autoback_hesvec, autoback_hesvec!
283+
else
284+
@error("Set autodiff to either AutoForwardDiff(), AutoZygote(), or AutoFiniteDiff()")
285+
end
275286

276287
outofplace = static_hasmethod(f, typeof((u,)))
277288
isinplace = static_hasmethod(f, typeof((u,)))
@@ -288,24 +299,24 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
288299
)
289300
end
290301

291-
function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
302+
function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoForwardDiff())
292303

293-
if autodiff
304+
cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff
305+
cache1 = similar(u)
306+
cache2 = similar(u)
307+
308+
(cache1, cache2), num_hesvecgrad, num_hesvecgrad!
309+
elseif autodiff isa AutoForwardDiff
294310
cache1 = Dual{
295311
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
296312
}.(u, ForwardDiff.Partials.(tuple.(u)))
297-
298313
cache2 = copy(cache1)
314+
315+
(cache1, cache2), auto_hesvecgrad, auto_hesvecgrad!
299316
else
300-
cache1 = similar(u)
301-
cache2 = similar(u)
317+
@error("Set autodiff to either AutoForwardDiff(), or AutoFiniteDiff()")
302318
end
303319

304-
cache = (cache1, cache2,)
305-
306-
vecprod = autodiff ? auto_hesvecgrad : num_hesvecgrad
307-
vecprod! = autodiff ? auto_hesvecgrad! : num_hesvecgrad!
308-
309320
outofplace = static_hasmethod(f, typeof((u,)))
310321
isinplace = static_hasmethod(f, typeof((u, u,)))
311322

src/differentiation/vecjac_products.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,13 @@ struct RevModeAutoDiffVecProd{ad,iip,oop,F,U,C,V,V!} <: AbstractAutoDiffVecProd
4444
vecprod::V
4545
vecprod!::V!
4646

47-
function RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!; autodiff = false,
47+
function RevModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!;
48+
autodiff = AutoFiniteDiff(),
4849
isinplace = false, outofplace = true)
4950
@assert isinplace || outofplace
5051

5152
new{
52-
autodiff,
53+
typeof(autodiff),
5354
isinplace,
5455
outofplace,
5556
typeof(f),
@@ -86,18 +87,19 @@ function (L::RevModeAutoDiffVecProd{ad,true,false})(dv, v, p, t) where{ad}
8687
L.vecprod!(dv, (_du, _u) -> L.f(_du, _u, p, t), L.u, v, L.cache...)
8788
end
8889

89-
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = false,
90+
function VecJac(f, u::AbstractArray, p = nothing, t = nothing; autodiff = AutoFiniteDiff(),
9091
ishermitian = false, opnrom = true)
9192

92-
if autodiff
93-
@assert isdefined(SparseDiffTools, :auto_vecjac) "Please load Zygote with `using Zygote`, or `import Zygote` to use VecJac with `autodiff = true`."
93+
vecprod, vecprod! = if autodiff isa AutoFiniteDiff
94+
num_vecjac, num_vecjac!
95+
elseif autodiff isa AutoZygote
96+
@assert static_hasmethod(auto_vecjac, typeof((f, u, u))) "To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"
97+
98+
auto_vecjac, auto_vecjac!
9499
end
95100

96101
cache = (similar(u), similar(u),)
97102

98-
vecprod = autodiff ? auto_vecjac : num_vecjac
99-
vecprod! = autodiff ? auto_vecjac! : num_vecjac!
100-
101103
outofplace = static_hasmethod(f, typeof((u, p, t)))
102104
isinplace = static_hasmethod(f, typeof((u, u, p, t)))
103105

test/test_jaches_products.jl

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ update_coefficients!(L, v, nothing, 0.0)
7676
@test mul!(dy, L, v) auto_jacvec(f, v, v)
7777
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*auto_jacvec(f,x,v) + b*_dy
7878

79-
L = JacVec(f, x, autodiff = false)
79+
L = JacVec(f, x, autodiff = AutoFiniteDiff())
8080
@test L * x num_jacvec(f, x, x)
8181
@test L * v num_jacvec(f, x, v)
8282
@test mul!(dy, L, v)num_jacvec(f, x, v) rtol=1e-6
@@ -92,7 +92,7 @@ gmres!(out, L, v)
9292

9393
x = rand(N)
9494
v = rand(N)
95-
L = HesVec(g, x, autodiff = false)
95+
L = HesVec(g, x, autodiff = AutoFiniteDiff())
9696
@test L * x num_hesvec(g, x, x)
9797
@test L * v num_hesvec(g, x, v)
9898
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
@@ -113,21 +113,12 @@ dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy r
113113
out = similar(v)
114114
gmres!(out, L, v)
115115

116-
@info "ZygoteHesVec"
117116
using Zygote
117+
118118
x = rand(N)
119119
v = rand(N)
120120

121-
L = ZygoteHesVec(g, x, autodiff = false)
122-
@test L * x numback_hesvec(g, x, x) rtol = 1e-2
123-
@test L * v numback_hesvec(g, x, v) rtol = 1e-2
124-
@test mul!(dy, L, v)numback_hesvec(g, x, v) rtol=1e-2
125-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*numback_hesvec(g,x,v) + b*_dy rtol=1e-2
126-
update_coefficients!(L, v, nothing, 0.0)
127-
@test mul!(dy, L, v)numback_hesvec(g, v, v) rtol=1e-2
128-
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*numback_hesvec(g,x,v) + b*_dy rtol=1e-2
129-
130-
L = ZygoteHesVec(g, x)
121+
L = HesVec(g, x, autodiff = AutoZygote())
131122
@test L * x autoback_hesvec(g, x, x)
132123
@test L * v autoback_hesvec(g, x, v)
133124
@test mul!(dy, L, v)autoback_hesvec(g, x, v) rtol=1e-8
@@ -139,12 +130,11 @@ dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*autoback_hesvec(g,x,v)+b*_dy
139130
out = similar(v)
140131
gmres!(out, L, v)
141132

142-
143133
@info "HesVecGrad"
144134

145135
x = rand(N)
146136
v = rand(N)
147-
L = HesVecGrad(h, x, autodiff = false)
137+
L = HesVecGrad(h, x, autodiff = AutoFiniteDiff())
148138
@test L * x num_hesvec(g, x, x)
149139
@test L * v num_hesvec(g, x, v)
150140
@test mul!(dy, L, v)num_hesvec(g, x, v) rtol=1e-2
@@ -153,7 +143,7 @@ update_coefficients!(L, v, nothing, 0.0)
153143
@test mul!(dy, L, v)num_hesvec(g, v, v) rtol=1e-2
154144
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*num_hesvec(g,x,v)+b*_dy rtol=1e-2
155145

156-
L = HesVecGrad(h, x, autodiff = true)
146+
L = HesVecGrad(h, x)
157147
@test L * x autonum_hesvec(g, x, x)
158148
@test L * v numauto_hesvec(g, x, v)
159149
@test mul!(dy, L, v)numauto_hesvec(g, x, v) rtol=1e-8

test/test_vecjac_products.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ L = VecJac(f, x)
1818
actual_vjp = Zygote.jacobian(x -> f(x, nothing, 0.0), x)[1]' * v
1919
update_coefficients!(L, v, nothing, 0.0)
2020
@test L * v actual_vjp
21-
L = VecJac(f, x; autodiff = false)
21+
L = VecJac(f, x; autodiff = AutoFiniteDiff())
2222
update_coefficients!(L, v, nothing, 0.0)
2323
@test L * v actual_vjp
2424

@@ -28,7 +28,7 @@ L = ZygoteVecJac(f, x)
2828
actual_vjp = Zygote.jacobian(x -> f(x, nothing, 0.0), x)[1]' * v
2929
update_coefficients!(L, v, nothing, 0.0)
3030
@test L * v actual_vjp
31-
L = ZygoteVecJac(f, x; autodiff = false)
31+
L = ZygoteVecJac(f, x; autodiff = AutoFiniteDiff())
3232
update_coefficients!(L, v, nothing, 0.0)
3333
@test L * v actual_vjp
3434
#

0 commit comments

Comments
 (0)