Skip to content

Commit d79a00f

Browse files
Merge pull request #212 from vpuri3/scimlops
make AD operators scimloperators
2 parents a91d1da + f322ec7 commit d79a00f

13 files changed

+341
-117
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
.DS_Store
2+
.*.swp
3+
.*.swo
24
Manifest.toml
35
/dev/
46

57
docs/build/
6-
docs/site/
8+
docs/site/

Project.toml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SparseDiffTools"
22
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
33
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
4-
version = "1.31.0"
4+
version = "2.00.0"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -13,8 +13,10 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1313
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
1414
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1515
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
16+
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
1617
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1718
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
19+
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
1820
VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
1921

2022
[compat]
@@ -25,8 +27,10 @@ DataStructures = "0.18"
2527
FiniteDiff = "2.8.1"
2628
ForwardDiff = "0.10"
2729
Graphs = "1"
28-
Requires = "1.0"
30+
Requires = "1"
31+
SciMLOperators = "0.1.19"
2932
StaticArrays = "1"
33+
Tricks = "0.1.6"
3034
VertexSafeGraphs = "0.2"
3135
julia = "1.6"
3236

src/SparseDiffTools.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,12 @@ using DataStructures: DisjointSets, find_root!, union!
1919

2020
using ArrayInterface: matrix_colors
2121

22+
using SciMLOperators
23+
import SciMLOperators: update_coefficients, update_coefficients!
24+
using Tricks: static_hasmethod
25+
26+
abstract type AbstractAutoDiffVecProd end
27+
2228
export contract_color,
2329
greedy_d1,
2430
greedy_star1_coloring,
@@ -42,7 +48,8 @@ export contract_color,
4248
autonum_hesvec, autonum_hesvec!,
4349
num_hesvecgrad, num_hesvecgrad!,
4450
auto_hesvecgrad, auto_hesvecgrad!,
45-
JacVec, HesVec, HesVecGrad,
51+
JacVec, HesVec, HesVecGrad, VecJac,
52+
update_coefficients, update_coefficients!,
4653
value!
4754

4855
include("coloring/high_level.jl")
@@ -64,8 +71,10 @@ parameterless_type(x::Type) = __parameterless_type(x)
6471

6572
function __init__()
6673
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
67-
export numback_hesvec, numback_hesvec!, autoback_hesvec, autoback_hesvec!,
68-
auto_vecjac, auto_vecjac!
74+
export numback_hesvec, numback_hesvec!,
75+
autoback_hesvec, autoback_hesvec!,
76+
auto_vecjac, auto_vecjac!,
77+
ZygoteVecJac, ZygoteHesVec
6978

7079
include("differentiation/vecjac_products_zygote.jl")
7180
include("differentiation/jaches_products_zygote.jl")

src/differentiation/jaches_products.jl

Lines changed: 94 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -198,112 +198,126 @@ end
198198

199199
### Operator Forms
200200

201-
struct JacVec{F, T1, T2, xType}
201+
struct FwdModeAutoDiffVecProd{F,U,C,V,V!} <: AbstractAutoDiffVecProd
202202
f::F
203-
cache1::T1
204-
cache2::T2
205-
x::xType
206-
autodiff::Bool
203+
u::U
204+
cache::C
205+
vecprod::V
206+
vecprod!::V!
207207
end
208208

209-
function JacVec(f, x::AbstractArray, tag = DeivVecTag(); autodiff = true)
210-
if autodiff
211-
cache1 = Dual{typeof(ForwardDiff.Tag(tag, eltype(x))), eltype(x), 1
212-
}.(x, ForwardDiff.Partials.(tuple.(x)))
213-
cache2 = Dual{typeof(ForwardDiff.Tag(tag, eltype(x))), eltype(x), 1
214-
}.(x, ForwardDiff.Partials.(tuple.(x)))
215-
else
216-
cache1 = similar(x)
217-
cache2 = similar(x)
218-
end
219-
JacVec(f, cache1, cache2, x, autodiff)
209+
function update_coefficients(L::FwdModeAutoDiffVecProd, u, p, t)
210+
FwdModeAutoDiffVecProd(L.f, u, L.vecprod, L.vecprod!, L.cache)
220211
end
221212

222-
Base.eltype(L::JacVec) = eltype(L.x)
223-
Base.size(L::JacVec) = (length(L.cache1), length(L.cache1))
224-
Base.size(L::JacVec, i::Int) = length(L.cache1)
225-
function Base.:*(L::JacVec, v::AbstractVector)
226-
L.autodiff ? auto_jacvec(_x -> L.f(_x), L.x, v) :
227-
num_jacvec(_x -> L.f(_x), L.x, v)
213+
function update_coefficients!(L::FwdModeAutoDiffVecProd, u, p, t)
214+
copy!(L.u, u)
215+
L
228216
end
229217

230-
function LinearAlgebra.mul!(dy::AbstractVector, L::JacVec, v::AbstractVector)
231-
if L.autodiff
232-
auto_jacvec!(dy, (_y, _x) -> L.f(_y, _x), L.x, v, L.cache1, L.cache2)
233-
else
234-
num_jacvec!(dy, (_y, _x) -> L.f(_y, _x), L.x, v, L.cache1, L.cache2)
235-
end
218+
function (L::FwdModeAutoDiffVecProd)(v, p, t)
219+
L.vecprod(L.f, L.u, v)
236220
end
237221

238-
struct HesVec{F, T1, T2, xType}
239-
f::F
240-
cache1::T1
241-
cache2::T2
242-
cache3::T2
243-
x::xType
244-
autodiff::Bool
222+
function (L::FwdModeAutoDiffVecProd)(dv, v, p, t)
223+
L.vecprod!(dv, L.f, L.u, v, L.cache...)
245224
end
246225

247-
function HesVec(f, x::AbstractArray; autodiff = true)
226+
function JacVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
227+
248228
if autodiff
249-
cache1 = ForwardDiff.GradientConfig(f, x)
250-
cache2 = similar(x)
251-
cache3 = similar(x)
229+
cache1 = Dual{
230+
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
231+
}.(u, ForwardDiff.Partials.(tuple.(u)))
232+
233+
cache2 = copy(cache1)
252234
else
253-
cache1 = similar(x)
254-
cache2 = similar(x)
255-
cache3 = similar(x)
235+
cache1 = similar(u)
236+
cache2 = similar(u)
256237
end
257-
HesVec(f, cache1, cache2, cache3, x, autodiff)
258-
end
259238

260-
Base.size(L::HesVec) = (length(L.cache2), length(L.cache2))
261-
Base.size(L::HesVec, i::Int) = length(L.cache2)
262-
function Base.:*(L::HesVec, v::AbstractVector)
263-
L.autodiff ? numauto_hesvec(L.f, L.x, v) : num_hesvec(L.f, L.x, v)
264-
end
239+
cache = (cache1, cache2,)
265240

266-
function LinearAlgebra.mul!(dy::AbstractVector, L::HesVec, v::AbstractVector)
267-
if L.autodiff
268-
numauto_hesvec!(dy, L.f, L.x, v, L.cache1, L.cache2, L.cache3)
269-
else
270-
num_hesvec!(dy, L.f, L.x, v, L.cache1, L.cache2, L.cache3)
241+
vecprod = autodiff ? auto_jacvec : num_jacvec
242+
vecprod! = autodiff ? auto_jacvec! : num_jacvec!
243+
244+
outofplace = static_hasmethod(f, typeof((u,)))
245+
isinplace = static_hasmethod(f, typeof((u, u,)))
246+
247+
if !(isinplace) & !(outofplace)
248+
error("$f must have signature f(u), or f(du, u).")
271249
end
272-
end
273250

274-
struct HesVecGrad{G, T1, T2, uType}
275-
g::G
276-
cache1::T1
277-
cache2::T2
278-
x::uType
279-
autodiff::Bool
251+
L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)
252+
253+
FunctionOperator(L, u, u;
254+
isinplace = isinplace, outofplace = outofplace,
255+
p = p, t = t, islinear = true,
256+
)
280257
end
281258

282-
function HesVecGrad(g, x::AbstractArray, tag = DeivVecTag(); autodiff = false)
259+
function HesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
260+
283261
if autodiff
284-
cache1 = Dual{typeof(ForwardDiff.Tag(tag, eltype(x))), eltype(x), 1
285-
}.(x, ForwardDiff.Partials.(tuple.(x)))
286-
cache2 = Dual{typeof(ForwardDiff.Tag(tag, eltype(x))), eltype(x), 1
287-
}.(x, ForwardDiff.Partials.(tuple.(x)))
262+
cache1 = ForwardDiff.GradientConfig(f, u)
263+
cache2 = similar(u)
264+
cache3 = similar(u)
288265
else
289-
cache1 = similar(x)
290-
cache2 = similar(x)
266+
cache1 = similar(u)
267+
cache2 = similar(u)
268+
cache3 = similar(u)
291269
end
292-
HesVecGrad(g, cache1, cache2, x, autodiff)
293-
end
294270

295-
Base.size(L::HesVecGrad) = (length(L.cache2), length(L.cache2))
296-
Base.size(L::HesVecGrad, i::Int) = length(L.cache2)
297-
function Base.:*(L::HesVecGrad, v::AbstractVector)
298-
L.autodiff ? auto_hesvecgrad(L.g, L.x, v) : num_hesvecgrad(L.g, L.x, v)
271+
cache = (cache1, cache2, cache3,)
272+
273+
vecprod = autodiff ? numauto_hesvec : num_hesvec
274+
vecprod! = autodiff ? numauto_hesvec! : num_hesvec!
275+
276+
outofplace = static_hasmethod(f, typeof((u,)))
277+
isinplace = static_hasmethod(f, typeof((u,)))
278+
279+
if !(isinplace) & !(outofplace)
280+
error("$f must have signature f(u).")
281+
end
282+
283+
L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)
284+
285+
FunctionOperator(L, u, u;
286+
isinplace = isinplace, outofplace = outofplace,
287+
p = p, t = t, islinear = true,
288+
)
299289
end
300290

301-
function LinearAlgebra.mul!(dy::AbstractVector,
302-
L::HesVecGrad,
303-
v::AbstractVector)
304-
if L.autodiff
305-
auto_hesvecgrad!(dy, L.g, L.x, v, L.cache1, L.cache2)
291+
function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
292+
293+
if autodiff
294+
cache1 = Dual{
295+
typeof(ForwardDiff.Tag(DeivVecTag(), eltype(u))), eltype(u), 1
296+
}.(u, ForwardDiff.Partials.(tuple.(u)))
297+
298+
cache2 = copy(cache1)
306299
else
307-
num_hesvecgrad!(dy, L.g, L.x, v, L.cache1, L.cache2)
300+
cache1 = similar(u)
301+
cache2 = similar(u)
302+
end
303+
304+
cache = (cache1, cache2,)
305+
306+
vecprod = autodiff ? auto_hesvecgrad : num_hesvecgrad
307+
vecprod! = autodiff ? auto_hesvecgrad! : num_hesvecgrad!
308+
309+
outofplace = static_hasmethod(f, typeof((u,)))
310+
isinplace = static_hasmethod(f, typeof((u, u,)))
311+
312+
if !(isinplace) & !(outofplace)
313+
error("$f must have signature f(u), or f(du, u).")
308314
end
315+
316+
L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)
317+
318+
FunctionOperator(L, u, u;
319+
isinplace = isinplace, outofplace = outofplace,
320+
p = p, t = t, islinear = true,
321+
)
309322
end
323+
#

src/differentiation/jaches_products_zygote.jl

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,21 @@ function numback_hesvec(f, x, v)
2525
end
2626

2727
function autoback_hesvec!(dy, f, x, v,
28-
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
28+
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
2929
eltype(x), 1
3030
}.(x,
3131
ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
32-
cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
32+
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
3333
eltype(x), 1
3434
}.(x,
3535
ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))))
3636
g = let f = f
3737
(dx, x) -> dx .= first(Zygote.gradient(f, x))
3838
end
39-
cache2 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1
39+
cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1
4040
}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
41-
g(cache3, cache2)
42-
dy .= partials.(cache3, 1)
41+
g(cache2, cache1)
42+
dy .= partials.(cache2, 1)
4343
end
4444

4545
function autoback_hesvec(f, x, v)
@@ -48,3 +48,38 @@ function autoback_hesvec(f, x, v)
4848
}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
4949
ForwardDiff.partials.(g(y), 1)
5050
end
51+
52+
### Operator Forms
53+
54+
function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
55+
56+
if autodiff
57+
cache1 = Dual{
58+
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
59+
}.(u, ForwardDiff.Partials.(tuple.(u)))
60+
cache2 = copy(u)
61+
else
62+
cache1 = similar(u)
63+
cache2 = similar(u)
64+
end
65+
66+
cache = (cache1, cache2,)
67+
68+
vecprod = autodiff ? autoback_hesvec : numback_hesvec
69+
vecprod! = autodiff ? autoback_hesvec! : numback_hesvec!
70+
71+
outofplace = static_hasmethod(f, typeof((u,)))
72+
isinplace = static_hasmethod(f, typeof((u,)))
73+
74+
if !(isinplace) & !(outofplace)
75+
error("$f must have signature f(u).")
76+
end
77+
78+
L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)
79+
80+
FunctionOperator(L, u, u;
81+
isinplace = isinplace, outofplace = outofplace,
82+
p = p, t = t, islinear = true,
83+
)
84+
end
85+
#

0 commit comments

Comments
 (0)