Skip to content

Commit 1024c54

Browse files
Merge pull request #206 from Vaibhavdixit02/customtaghesscache
Move tag to an arg in `ForwardAutoColorHesCache` to allow passing custom tags
2 parents da1da48 + c0b9ada commit 1024c54

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

src/differentiation/compute_hessian_ad.jl

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,14 @@ function make_hessian_buffers(colorvec, x)
1919
return (;ncolors, D, buffer, G1, G2)
2020
end
2121

22-
function ForwardColorHesCache(f,
23-
x::AbstractVector{<:Number},
24-
colorvec::AbstractVector{<:Integer}=eachindex(x),
22+
function ForwardColorHesCache(f,
23+
x::AbstractVector{<:Number},
24+
colorvec::AbstractVector{<:Integer}=eachindex(x),
2525
sparsity::Union{AbstractMatrix, Nothing}=nothing,
2626
g! = (G, x, grad_config) -> ForwardDiff.gradient!(G, f, x, grad_config))
2727
ncolors, D, buffer, G, G2 = make_hessian_buffers(colorvec, x)
2828
grad_config = ForwardDiff.GradientConfig(f, x)
29-
29+
3030
# If user supplied their own gradient function, make sure it has the right
3131
# signature (i.e. g!(G, x) or g!(G, x, grad_config::ForwardDiff.GradientConfig))
3232
if ! hasmethod(g!, (typeof(G), typeof(G), typeof(grad_config)))
@@ -38,16 +38,16 @@ function ForwardColorHesCache(f,
3838
else
3939
g1! = g!
4040
end
41-
41+
4242
if sparsity === nothing
4343
sparsity = sparse(ones(length(x), length(x)))
4444
end
4545
return ForwardColorHesCache(sparsity, colorvec, ncolors, D, buffer, g1!, grad_config, G, G2)
4646
end
4747

48-
function numauto_color_hessian!(H::AbstractMatrix{<:Number},
49-
f,
50-
x::AbstractArray{<:Number},
48+
function numauto_color_hessian!(H::AbstractMatrix{<:Number},
49+
f,
50+
x::AbstractArray{<:Number},
5151
hes_cache::ForwardColorHesCache;
5252
safe = true)
5353
ϵ = cbrt(eps(eltype(x)))
@@ -69,18 +69,18 @@ function numauto_color_hessian!(H::AbstractMatrix{<:Number},
6969
return H
7070
end
7171

72-
function numauto_color_hessian!(H::AbstractMatrix{<:Number},
73-
f,
72+
function numauto_color_hessian!(H::AbstractMatrix{<:Number},
73+
f,
7474
x::AbstractArray{<:Number},
75-
colorvec::AbstractVector{<:Integer}=eachindex(x),
75+
colorvec::AbstractVector{<:Integer}=eachindex(x),
7676
sparsity::Union{AbstractMatrix, Nothing}=nothing)
7777
hes_cache = ForwardColorHesCache(f, x, colorvec, sparsity)
7878
numauto_color_hessian!(H, f, x, hes_cache)
7979
return H
8080
end
8181

82-
function numauto_color_hessian(f,
83-
x::AbstractArray{<:Number},
82+
function numauto_color_hessian(f,
83+
x::AbstractArray{<:Number},
8484
hes_cache::ForwardColorHesCache)
8585
H = convert.(eltype(x), hes_cache.sparsity)
8686
numauto_color_hessian!(H, f, x, hes_cache)
@@ -89,7 +89,7 @@ end
8989

9090
function numauto_color_hessian(f,
9191
x::AbstractArray{<:Number},
92-
colorvec::AbstractVector{<:Integer}=eachindex(x),
92+
colorvec::AbstractVector{<:Integer}=eachindex(x),
9393
sparsity::Union{AbstractMatrix, Nothing}=nothing)
9494
hes_cache = ForwardColorHesCache(f, x, colorvec, sparsity)
9595
H = convert.(eltype(x), hes_cache.sparsity)
@@ -113,13 +113,13 @@ struct AutoAutoTag end
113113
function ForwardAutoColorHesCache(f,
114114
x::AbstractVector{V},
115115
colorvec::AbstractVector{<:Integer}=eachindex(x),
116-
sparsity::Union{AbstractMatrix,Nothing}=nothing) where V
116+
sparsity::Union{AbstractMatrix,Nothing}=nothing,
117+
tag::ForwardDiff.Tag = ForwardDiff.Tag(AutoAutoTag(), V)) where V
117118

118119
if sparsity === nothing
119120
sparsity = sparse(ones(length(x), length(x)))
120121
end
121122

122-
tag = ForwardDiff.Tag(AutoAutoTag(), V)
123123
chunksize = ForwardDiff.pickchunksize(maximum(colorvec))
124124
chunk = ForwardDiff.Chunk(chunksize)
125125

@@ -130,7 +130,7 @@ function ForwardAutoColorHesCache(f,
130130
g! = (G, x) -> ForwardDiff.gradient!(G, f, x, gradient_config, Val(false))
131131

132132
jac_cache = ForwardColorJacCache(g!, x; colorvec, sparsity, tag=outer_tag)
133-
133+
134134
return ForwardAutoColorHesCache(jac_cache, g!, sparsity, colorvec)
135135
end
136136

0 commit comments

Comments
 (0)