Skip to content

Commit 0e6685d

Browse files
Merge pull request #30 from JuliaDiffEq/jacvecs
implement jacvec products
2 parents de24917 + 4546782 commit 0e6685d

File tree

5 files changed

+370
-6
lines changed

5 files changed

+370
-6
lines changed

Project.toml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
11
name = "SparseDiffTools"
22
uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
3-
authors = ["Pankaj Mishra <[email protected]>","Chris Rackauckas <[email protected]>"]
3+
authors = ["Pankaj Mishra <[email protected]>", "Chris Rackauckas <[email protected]>"]
44
version = "0.1.0"
55

66
[deps]
7+
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
8+
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
9+
DiffEqDiffTools = "01453d9d-ee7c-5054-8395-0335cb756afa"
710
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
811
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
912
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1013
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1114
VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"
12-
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
13-
BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0"
1415

1516
[compat]
1617
julia = "1"
1718

1819
[extras]
1920
DiffEqDiffTools = "01453d9d-ee7c-5054-8395-0335cb756afa"
21+
IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153"
2022
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2123

2224
[targets]
23-
test = ["Test", "DiffEqDiffTools"]
25+
test = ["Test", "DiffEqDiffTools", "IterativeSolvers"]

src/SparseDiffTools.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module SparseDiffTools
22

3-
using SparseArrays, LinearAlgebra, BandedMatrices, BlockBandedMatrices, LightGraphs, VertexSafeGraphs
3+
using SparseArrays, LinearAlgebra, BandedMatrices, BlockBandedMatrices,
4+
LightGraphs, VertexSafeGraphs, DiffEqDiffTools, ForwardDiff
45
using BlockBandedMatrices:blocksize,nblocks
56
using ForwardDiff: Dual, jacobian, partials, DEFAULT_CHUNK_THRESHOLD
67

@@ -9,12 +10,21 @@ export contract_color,
910
matrix2graph,
1011
matrix_colors,
1112
forwarddiff_color_jacobian!,
12-
ForwardColorJacCache
13+
ForwardColorJacCache,
14+
auto_jacvec,auto_jacvec!,
15+
num_jacvec,num_jacvec!,
16+
num_hesvec,num_hesvec!,
17+
numauto_hesvec,numauto_hesvec!,
18+
autonum_hesvec,autonum_hesvec!,
19+
num_hesvecgrad,num_hesvecgrad!,
20+
auto_hesvecgrad,auto_hesvecgrad!,
21+
JacVec,HesVec,HesVecGrad
1322

1423
include("coloring/high_level.jl")
1524
include("coloring/contraction_coloring.jl")
1625
include("coloring/greedy_d1_coloring.jl")
1726
include("coloring/matrix2graph.jl")
1827
include("differentiation/compute_jacobian_ad.jl")
28+
include("differentiation/jaches_products.jl")
1929

2030
end # module
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
struct DeivVecTag end
2+
3+
# J(f(x))*v
4+
function auto_jacvec!(du, f, x, v,
5+
cache1 = ForwardDiff.Dual{DeivVecTag}.(x, v),
6+
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v))
7+
cache1 .= Dual{DeivVecTag}.(x, v)
8+
f(cache2,cache1)
9+
du .= partials.(cache2, 1)
10+
end
11+
function auto_jacvec(f, x, v)
12+
partials.(f(Dual{DeivVecTag}.(x, v)), 1)
13+
end
14+
15+
function num_jacvec!(du,f,x,v,cache1 = similar(v),
16+
cache2 = similar(v);
17+
compute_f0 = true)
18+
compute_f0 && (f(cache1,x))
19+
T = eltype(x)
20+
# Should it be min? max? mean?
21+
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
22+
@. x += ϵ*v
23+
f(cache2,x)
24+
@. x -= ϵ*v
25+
@. du = (cache2 - cache1)/ϵ
26+
end
27+
28+
function num_jacvec(f,x,v,f0=nothing)
29+
f0 === nothing ? _f0 = f(x) : _f0 = f0
30+
T = eltype(x)
31+
# Should it be min? max? mean?
32+
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(minimum(x)))
33+
(f(x.+ϵ.*v) .- f(x))./ϵ
34+
end
35+
36+
function num_hesvec!(du,f,x,v,
37+
cache1 = similar(v),
38+
cache2 = similar(v),
39+
cache3 = similar(v))
40+
cache = DiffEqDiffTools.GradientCache(v[1],cache1,Val{:central})
41+
g = let f=f,cache=cache
42+
(dx,x) -> DiffEqDiffTools.finite_difference_gradient!(dx,f,x,cache)
43+
end
44+
T = eltype(x)
45+
# Should it be min? max? mean?
46+
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
47+
@. x += ϵ*v
48+
g(cache2,x)
49+
@. x -= 2ϵ*v
50+
g(cache3,x)
51+
@. du = (cache2 - cache3)/(2ϵ)
52+
end
53+
54+
function num_hesvec(f,x,v)
55+
g = (x) -> DiffEqDiffTools.finite_difference_gradient(f,x)
56+
T = eltype(x)
57+
# Should it be min? max? mean?
58+
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
59+
x += ϵ*v
60+
gxp = g(x)
61+
x -= 2ϵ*v
62+
gxm = g(x)
63+
(gxp - gxm)/(2ϵ)
64+
end
65+
66+
function numauto_hesvec!(du,f,x,v,
67+
cache = ForwardDiff.GradientConfig(f,v),
68+
cache1 = similar(v),
69+
cache2 = similar(v))
70+
g = let f=f,x=x,cache=cache
71+
g = (dx,x) -> ForwardDiff.gradient!(dx,f,x,cache)
72+
end
73+
T = eltype(x)
74+
# Should it be min? max? mean?
75+
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
76+
@. x += ϵ*v
77+
g(cache1,x)
78+
@. x -= 2ϵ*v
79+
g(cache2,x)
80+
@. du = (cache1 - cache2)/(2ϵ)
81+
end
82+
83+
function numauto_hesvec(f,x,v)
84+
g = (x) -> ForwardDiff.gradient(f,x)
85+
T = eltype(x)
86+
# Should it be min? max? mean?
87+
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
88+
x += ϵ*v
89+
gxp = g(x)
90+
x -= 2ϵ*v
91+
gxm = g(x)
92+
(gxp - gxm)/(2ϵ)
93+
end
94+
95+
function autonum_hesvec!(du,f,x,v,
96+
cache1 = similar(v),
97+
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v),
98+
cache3 = ForwardDiff.Dual{DeivVecTag}.(x, v))
99+
cache = DiffEqDiffTools.GradientCache(v[1],cache1,Val{:central})
100+
g = (dx,x) -> DiffEqDiffTools.finite_difference_gradient!(dx,f,x,cache)
101+
cache2 .= Dual{DeivVecTag}.(x, v)
102+
g(cache3,cache2)
103+
du .= partials.(cache3, 1)
104+
end
105+
106+
function autonum_hesvec(f,x,v)
107+
g = (x) -> DiffEqDiffTools.finite_difference_gradient(f,x)
108+
partials.(g(Dual{DeivVecTag}.(x, v)), 1)
109+
end
110+
111+
function num_hesvecgrad!(du,g,x,v,
112+
cache2 = similar(v),
113+
cache3 = similar(v))
114+
T = eltype(x)
115+
# Should it be min? max? mean?
116+
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
117+
@. x += ϵ*v
118+
g(cache2,x)
119+
@. x -= 2ϵ*v
120+
g(cache3,x)
121+
@. du = (cache2 - cache3)/(2ϵ)
122+
end
123+
124+
function num_hesvecgrad(g,x,v)
125+
T = eltype(x)
126+
# Should it be min? max? mean?
127+
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
128+
x += ϵ*v
129+
gxp = g(x)
130+
x -= 2ϵ*v
131+
gxm = g(x)
132+
(gxp - gxm)/(2ϵ)
133+
end
134+
135+
function auto_hesvecgrad!(du,g,x,v,
136+
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v),
137+
cache3 = ForwardDiff.Dual{DeivVecTag}.(x, v))
138+
cache2 .= Dual{DeivVecTag}.(x, v)
139+
g(cache3,cache2)
140+
du .= partials.(cache3, 1)
141+
end
142+
143+
function auto_hesvecgrad(g,x,v)
144+
partials.(g(Dual{DeivVecTag}.(x, v)), 1)
145+
end
146+
147+
### Operator Forms
148+
149+
mutable struct JacVec{F,T1,T2,uType}
150+
f::F
151+
cache1::T1
152+
cache2::T2
153+
u::uType
154+
autodiff::Bool
155+
end
156+
157+
function JacVec(f,u::AbstractArray;autodiff=true)
158+
if autodiff
159+
cache1 = ForwardDiff.Dual{DeivVecTag}.(u, u)
160+
cache2 = ForwardDiff.Dual{DeivVecTag}.(u, u)
161+
else
162+
cache1 = similar(u)
163+
cache2 = similar(u)
164+
end
165+
JacVec(f,cache1,cache2,u,autodiff)
166+
end
167+
168+
Base.size(L::JacVec) = (length(L.cache1),length(L.cache1))
169+
Base.size(L::JacVec,i::Int) = length(L.cache1)
170+
Base.:*(L::JacVec,x::AbstractVector) = L.autodiff ? auto_jacvec(_u->L.f(_u),L.u,x) : num_jacvec(_u->L.f(_u),L.u,x)
171+
172+
function LinearAlgebra.mul!(du::AbstractVector,L::JacVec,v::AbstractVector)
173+
if L.autodiff
174+
auto_jacvec!(du,(_du,_u)->L.f(_du,_u),L.u,v,L.cache1,L.cache2)
175+
else
176+
num_jacvec!(du,(_du,_u)->L.f(_du,_u),L.u,v,L.cache1,L.cache2)
177+
end
178+
end
179+
180+
mutable struct HesVec{F,T1,T2,uType}
181+
f::F
182+
cache1::T1
183+
cache2::T2
184+
cache3::T2
185+
u::uType
186+
autodiff::Bool
187+
end
188+
189+
function HesVec(f,u::AbstractArray;autodiff=true)
190+
if autodiff
191+
cache1 = ForwardDiff.GradientConfig(f,u)
192+
cache2 = similar(u)
193+
cache3 = similar(u)
194+
else
195+
cache1 = similar(u)
196+
cache2 = similar(u)
197+
cache3 = similar(u)
198+
end
199+
HesVec(f,cache1,cache2,cache3,u,autodiff)
200+
end
201+
202+
Base.size(L::HesVec) = (length(L.cache2),length(L.cache2))
203+
Base.size(L::HesVec,i::Int) = length(L.cache2)
204+
Base.:*(L::HesVec,x::AbstractVector) = L.autodiff ? numauto_hesvec(L.f,L.u,x) : num_hesvec(L.f,L.u,x)
205+
206+
function LinearAlgebra.mul!(du::AbstractVector,L::HesVec,v::AbstractVector)
207+
if L.autodiff
208+
numauto_hesvec!(du,L.f,L.u,v,L.cache1,L.cache2,L.cache3)
209+
else
210+
num_hesvec!(du,L.f,L.u,v,L.cache1,L.cache2,L.cache3)
211+
end
212+
end
213+
214+
mutable struct HesVecGrad{G,T1,T2,uType}
215+
g::G
216+
cache1::T1
217+
cache2::T2
218+
u::uType
219+
autodiff::Bool
220+
end
221+
222+
function HesVecGrad(g,u::AbstractArray;autodiff=true)
223+
if autodiff
224+
cache1 = ForwardDiff.Dual{DeivVecTag}.(u, u)
225+
cache2 = ForwardDiff.Dual{DeivVecTag}.(u, u)
226+
else
227+
cache1 = similar(u)
228+
cache2 = similar(u)
229+
end
230+
HesVecGrad(g,cache1,cache2,u,autodiff)
231+
end
232+
233+
Base.size(L::HesVecGrad) = (length(L.cache2),length(L.cache2))
234+
Base.size(L::HesVecGrad,i::Int) = length(L.cache2)
235+
Base.:*(L::HesVecGrad,x::AbstractVector) = L.autodiff ? auto_hesvecgrad(L.g,L.u,x) : num_hesvecgrad(L.g,L.u,x)
236+
237+
function LinearAlgebra.mul!(du::AbstractVector,L::HesVecGrad,v::AbstractVector)
238+
if L.autodiff
239+
auto_hesvecgrad!(du,L.g,L.u,v,L.cache1,L.cache2)
240+
else
241+
num_hesvecgrad!(du,L.g,L.u,v,L.cache1,L.cache2)
242+
end
243+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ using Test
88
@testset "AD using color vector" begin include("test_ad.jl") end
99
@testset "Integration test" begin include("test_integration.jl") end
1010
@testset "Special matrices" begin include("test_specialmatrices.jl") end
11+
@testset "Jac Vecs and Hes Vecs" begin include("test_jaches_products.jl") end

0 commit comments

Comments
 (0)