Skip to content

Commit 15b01e5

Browse files
committed
zygote
1 parent 686827c commit 15b01e5

File tree

6 files changed

+86
-10
lines changed

6 files changed

+86
-10
lines changed

src/SparseDiffTools.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,10 @@ parameterless_type(x::Type) = __parameterless_type(x)
7171

7272
function __init__()
7373
@require Zygote="e88e6eb3-aa80-5325-afca-941959d7151f" begin
74-
export numback_hesvec, numback_hesvec!, autoback_hesvec, autoback_hesvec!,
75-
auto_vecjac, auto_vecjac!
74+
export numback_hesvec, numback_hesvec!,
75+
autoback_hesvec, autoback_hesvec!,
76+
auto_vecjac, auto_vecjac!,
77+
ZygoteVecJac, ZygoteHesVec
7678

7779
include("differentiation/vecjac_products_zygote.jl")
7880
include("differentiation/jaches_products_zygote.jl")

src/differentiation/jaches_products_zygote.jl

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,21 @@ function numback_hesvec(f, x, v)
2525
end
2626

2727
function autoback_hesvec!(dy, f, x, v,
28-
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
28+
cache1 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
2929
eltype(x), 1
3030
}.(x,
3131
ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))),
32-
cache3 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
32+
cache2 = Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))),
3333
eltype(x), 1
3434
}.(x,
3535
ForwardDiff.Partials.(Tuple.(reshape(v, size(x))))))
3636
g = let f = f
3737
(dx, x) -> dx .= first(Zygote.gradient(f, x))
3838
end
39-
cache2 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1
39+
cache1 .= Dual{typeof(ForwardDiff.Tag(DeivVecTag, eltype(x))), eltype(x), 1
4040
}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
41-
g(cache3, cache2)
42-
dy .= partials.(cache3, 1)
41+
g(cache2, cache1)
42+
dy .= partials.(cache2, 1)
4343
end
4444

4545
function autoback_hesvec(f, x, v)
@@ -48,3 +48,38 @@ function autoback_hesvec(f, x, v)
4848
}.(x, ForwardDiff.Partials.(Tuple.(reshape(v, size(x)))))
4949
ForwardDiff.partials.(g(y), 1)
5050
end
51+
52+
### Operator Forms
53+
54+
function ZygoteHesVec(f, u::AbstractArray, p = nothing, t = nothing; autodiff = true)
55+
56+
if autodiff
57+
cache1 = Dual{
58+
typeof(ForwardDiff.Tag(DeivVecTag(),eltype(u))), eltype(u), 1
59+
}.(u, ForwardDiff.Partials.(tuple.(u)))
60+
cache2 = copy(u)
61+
else
62+
cache1 = similar(u)
63+
cache2 = similar(u)
64+
end
65+
66+
cache = (cache1, cache2,)
67+
68+
vecprod = autodiff ? autoback_hesvec : numback_hesvec
69+
vecprod! = autodiff ? autoback_hesvec! : numback_hesvec!
70+
71+
outofplace = static_hasmethod(f, typeof((u,)))
72+
isinplace = static_hasmethod(f, typeof((u,)))
73+
74+
if !(isinplace) & !(outofplace)
75+
error("$f must have signature f(u).")
76+
end
77+
78+
L = FwdModeAutoDiffVecProd(f, u, cache, vecprod, vecprod!)
79+
80+
FunctionOperator(L, u, u;
81+
isinplace = isinplace, outofplace = outofplace,
82+
p = p, t = t, islinear = true,
83+
)
84+
end
85+
#

src/differentiation/vecjac_products_zygote.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@ function auto_vecjac(f, x, v)
77
vv, back = Zygote.pullback(f, x)
88
return vec(back(reshape(v, size(vv)))[1])
99
end
10+
11+
const ZygoteVecJac = VecJac
12+
#

test/test_ad.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@ using SparseArrays, Test
44
using LinearAlgebra
55
using BlockBandedMatrices, ArrayInterfaceBlockBandedMatrices
66
using BandedMatrices, ArrayInterfaceBandedMatrices
7-
using StaticArrays
8-
using ArrayInterfaceStaticArrays
7+
using StaticArrays#, StaticArrayInterface
98

109
fcalls = 0
1110
function f(dx, x)

test/test_jaches_products.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,33 @@ dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)≈a*numauto_hesvec(g,x,v)+b*_dy r
113113
out = similar(v)
114114
gmres!(out, L, v)
115115

116+
@info "ZygoteHesVec"
117+
using Zygote
118+
x = rand(N)
119+
v = rand(N)
120+
121+
L = ZygoteHesVec(g, x, autodiff = false)
122+
@test L * x numback_hesvec(g, x, x) rtol = 1e-2
123+
@test L * v numback_hesvec(g, x, v) rtol = 1e-2
124+
@test mul!(dy, L, v)numback_hesvec(g, x, v) rtol=1e-2
125+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*numback_hesvec(g,x,v) + b*_dy rtol=1e-2
126+
update_coefficients!(L, v, nothing, 0.0)
127+
@test mul!(dy, L, v)numback_hesvec(g, v, v) rtol=1e-2
128+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b) a*numback_hesvec(g,x,v) + b*_dy rtol=1e-2
129+
130+
L = HesVec(g, x)
131+
@test L * x autoback_hesvec(g, x, x)
132+
@test L * v autoback_hesvec(g, x, v)
133+
@test mul!(dy, L, v)autoback_hesvec(g, x, v) rtol=1e-8
134+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8
135+
update_coefficients!(L, v, nothing, 0.0)
136+
@test mul!(dy, L, v)autoback_hesvec(g, v, v) rtol=1e-8
137+
dy=rand(N);_dy=copy(dy);@test mul!(dy,L,v,a,b)a*autoback_hesvec(g,x,v)+b*_dy rtol=1e-8
138+
139+
out = similar(v)
140+
gmres!(out, L, v)
141+
142+
116143
@info "HesVecGrad"
117144

118145
x = rand(N)

test/test_vecjac_products.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ v = rand(Float32, N)
1212
f(du,u,p,t) = mul!(du, A, u)
1313
f(u,p,t) = A * u
1414

15-
# VecJac
15+
@info "VecJac"
1616

1717
L = VecJac(f, x)
1818
actual_vjp = Zygote.jacobian(x -> f(x, nothing, 0.0), x)[1]' * v
@@ -21,4 +21,14 @@ update_coefficients!(L, v, nothing, 0.0)
2121
L = VecJac(f, x; autodiff = false)
2222
update_coefficients!(L, v, nothing, 0.0)
2323
@test L * v actual_vjp
24+
25+
@info "ZygoteVecJac"
26+
27+
L = ZygoteVecJac(f, x)
28+
actual_vjp = Zygote.jacobian(x -> f(x, nothing, 0.0), x)[1]' * v
29+
update_coefficients!(L, v, nothing, 0.0)
30+
@test L * v actual_vjp
31+
L = ZygoteVecJac(f, x; autodiff = false)
32+
update_coefficients!(L, v, nothing, 0.0)
33+
@test L * v actual_vjp
2434
#

0 commit comments

Comments
 (0)