Skip to content

Commit ebbce79

Browse files
Merge pull request #198 from baggepinnen/autoauto
RFC: add auto-auto hessian
2 parents 108df97 + b51681f commit ebbce79

File tree

3 files changed

+103
-0
lines changed

3 files changed

+103
-0
lines changed

src/SparseDiffTools.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@ export contract_color,
3030
ForwardColorJacCache,
3131
numauto_color_hessian!,
3232
numauto_color_hessian,
33+
autoauto_color_hessian!,
34+
autoauto_color_hessian,
3335
ForwardColorHesCache,
36+
ForwardAutoColorHesCache,
3437
auto_jacvec,auto_jacvec!,
3538
num_jacvec,num_jacvec!,
3639
num_vecjac,num_vecjac!,

src/differentiation/compute_hessian_ad.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,76 @@ function numauto_color_hessian(f,
9696
numauto_color_hessian!(H, f, x, hes_cache)
9797
return H
9898
end
99+
100+
101+
102+
## autoauto_color_hessian
103+
104+
mutable struct ForwardAutoColorHesCache{TJC,TG,TS,TC}
105+
jac_cache::TJC
106+
grad!::TG
107+
sparsity::TS
108+
colorvec::TC
109+
end
110+
111+
struct AutoAutoTag end
112+
113+
function ForwardAutoColorHesCache(f,
114+
x::AbstractVector{V},
115+
colorvec::AbstractVector{<:Integer}=eachindex(x),
116+
sparsity::Union{AbstractMatrix,Nothing}=nothing) where V
117+
118+
if sparsity === nothing
119+
sparsity = sparse(ones(length(x), length(x)))
120+
end
121+
122+
tag = ForwardDiff.Tag(AutoAutoTag(), V)
123+
chunksize = ForwardDiff.pickchunksize(maximum(colorvec))
124+
chunk = ForwardDiff.Chunk(chunksize)
125+
126+
jacobian_config = ForwardDiff.JacobianConfig(f, x, chunk, tag)
127+
gradient_config = ForwardDiff.GradientConfig(f, jacobian_config.duals, chunk, tag)
128+
129+
outer_tag = get_tag(jacobian_config.duals)
130+
g! = (G, x) -> ForwardDiff.gradient!(G, f, x, gradient_config, Val(false))
131+
132+
jac_cache = ForwardColorJacCache(g!, x; colorvec, sparsity, tag=outer_tag)
133+
134+
return ForwardAutoColorHesCache(jac_cache, g!, sparsity, colorvec)
135+
end
136+
137+
function autoauto_color_hessian!(H::AbstractMatrix{<:Number},
138+
f,
139+
x::AbstractArray{<:Number},
140+
hes_cache::ForwardAutoColorHesCache)
141+
142+
forwarddiff_color_jacobian!(H, hes_cache.grad!, x, hes_cache.jac_cache)
143+
end
144+
145+
function autoauto_color_hessian!(H::AbstractMatrix{<:Number},
146+
f,
147+
x::AbstractArray{<:Number},
148+
colorvec::AbstractVector{<:Integer}=eachindex(x),
149+
sparsity::Union{AbstractMatrix,Nothing}=nothing)
150+
hes_cache = ForwardAutoColorHesCache(f, x, colorvec, sparsity)
151+
autoauto_color_hessian!(H, f, x, hes_cache)
152+
return H
153+
end
154+
155+
function autoauto_color_hessian(f,
156+
x::AbstractArray{<:Number},
157+
hes_cache::ForwardAutoColorHesCache)
158+
H = convert.(eltype(x), hes_cache.sparsity)
159+
autoauto_color_hessian!(H, f, x, hes_cache)
160+
return H
161+
end
162+
163+
function autoauto_color_hessian(f,
164+
x::AbstractArray{<:Number},
165+
colorvec::AbstractVector{<:Integer}=eachindex(x),
166+
sparsity::Union{AbstractMatrix,Nothing}=nothing)
167+
hes_cache = ForwardAutoColorHesCache(f, x, colorvec, sparsity)
168+
H = convert.(eltype(x), hes_cache.sparsity)
169+
autoauto_color_hessian!(H, f, x, hes_cache)
170+
return H
171+
end

test/test_sparse_hessian.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,30 @@ for (i, hescache) in enumerate([hescache1, hescache2, hescache3, hescache4, hesc
9898
# for _ in 1:100)
9999
# @test t_unsafe <= t_safe
100100
end
101+
102+
103+
hescache1 = ForwardAutoColorHesCache(fscalar, x, colors, sparsity)
104+
hescache2 = ForwardAutoColorHesCache(fscalar, x)
105+
106+
107+
for (i, hescache) in enumerate([hescache1, hescache2])
108+
109+
H = SparseDiffTools.autoauto_color_hessian(fscalar, x, colors, sparsity)
110+
H1 = SparseDiffTools.autoauto_color_hessian(fscalar, x, hescache)
111+
H2 = SparseDiffTools.autoauto_color_hessian(fscalar, x)
112+
@test all(isapprox.(Hforward, H, rtol=1e-6))
113+
@test all(isapprox.(H, H1, rtol=1e-6))
114+
@test all(isapprox.(H2, H1, rtol=1e-6))
115+
116+
H1 = similar(H)
117+
118+
SparseDiffTools.autoauto_color_hessian!(H1, fscalar, x, collect(hescache.colorvec), hescache.sparsity)
119+
@test all(isapprox.(H1, H))
120+
121+
SparseDiffTools.autoauto_color_hessian!(H2, fscalar, x)
122+
@test all(isapprox.(H2, H))
123+
124+
SparseDiffTools.autoauto_color_hessian!(H1, fscalar, x, hescache)
125+
@test all(isapprox.(H1, H))
126+
127+
end

0 commit comments

Comments
 (0)