Skip to content

Commit fdcaadb

Browse files
longemen3000KristofferC
authored andcommitted
improve performance of hessians with static arrays
1 parent 42e0aa6 commit fdcaadb

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

ext/ForwardDiffStaticArraysExt.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,16 @@ end
8181
end
8282
end
8383

84+
@generated function extract_jacobian(::Type{T}, ydual::Partials{M}, x::S) where {M, T, S<:StaticArray}
85+
N = length(x)
86+
result = Expr(:tuple, [:(partials(T, ydual[$i], $j)) for i in 1:M, j in 1:N]...)
87+
return quote
88+
$(Expr(:meta, :inline))
89+
V = StaticArrays.similar_type(S, valtype(eltype($ydual)), Size($M, $N))
90+
return V($result)
91+
end
92+
end
93+
8494
@inline function ForwardDiff.vector_mode_jacobian(f::F, x::StaticArray) where {F}
8595
T = typeof(Tag(f, eltype(x)))
8696
return extract_jacobian(T, static_dual_eval(T, f, x), x)

test/HessianTest.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,4 +163,12 @@ end
163163
@test ForwardDiff.hessian(x->dot(x,H,x), zeros(3)) [2 6 10; 6 10 14; 10 14 18]
164164
end
165165

166+
@testset "allocation-free hessian with StaticArrays" begin
167+
#https://github.com/JuliaDiff/ForwardDiff.jl/issues/720
168+
g = r -> (r[1]^2 - 3) * (r[2]^2 - 2)
169+
x = SA_F32[0.5, 2.7]
170+
hres = DiffResults.HessianResult(x)
171+
@test @allocated(ForwardDiff.hessian!(hres, g, x))) == 0
172+
end
173+
166174
end # module

0 commit comments

Comments
 (0)