diff --git a/ext/ForwardDiffStaticArraysExt.jl b/ext/ForwardDiffStaticArraysExt.jl index 63f841db..e97da9d5 100644 --- a/ext/ForwardDiffStaticArraysExt.jl +++ b/ext/ForwardDiffStaticArraysExt.jl @@ -3,7 +3,7 @@ module ForwardDiffStaticArraysExt using ForwardDiff, StaticArrays using ForwardDiff.LinearAlgebra using ForwardDiff.DiffResults -using ForwardDiff: Dual, partials, GradientConfig, JacobianConfig, HessianConfig, Tag, Chunk, +using ForwardDiff: Dual, partials, npartials, Partials, GradientConfig, JacobianConfig, HessianConfig, Tag, Chunk, gradient, hessian, jacobian, gradient!, hessian!, jacobian!, extract_gradient!, extract_jacobian!, extract_value!, vector_mode_gradient, vector_mode_gradient!, @@ -71,8 +71,9 @@ end @inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig) where {F} = jacobian!(result, f, x) @inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where {F} = jacobian!(result, f, x) -@generated function extract_jacobian(::Type{T}, ydual::StaticArray, x::S) where {T,S<:StaticArray} - M, N = length(ydual), length(x) +@generated function extract_jacobian(::Type{T}, ydual::Union{StaticArray,Partials}, x::S) where {T,S<:StaticArray} + M = ydual <: Partials ? npartials(ydual) : length(ydual) + N = length(x) result = Expr(:tuple, [:(partials(T, ydual[$i], $j)) for i in 1:M, j in 1:N]...) return quote $(Expr(:meta, :inline)) diff --git a/test/HessianTest.jl b/test/HessianTest.jl index 4c667e5e..8be72ee5 100644 --- a/test/HessianTest.jl +++ b/test/HessianTest.jl @@ -163,4 +163,15 @@ end @test ForwardDiff.hessian(x->dot(x,H,x), zeros(3)) ≈ [2 6 10; 6 10 14; 10 14 18] end +#https://github.com/JuliaDiff/ForwardDiff.jl/issues/720 +@testset "allocation-free hessian with StaticArrays" begin + function hessian_allocs() + g = r -> (r[1]^2 - 3) * (r[2]^2 - 2) + x = SVector(0.5, 2.8) + hres = DiffResults.HessianResult(x) + return @allocated(ForwardDiff.hessian!(hres, g, x)) + end + @test iszero(hessian_allocs()) +end + end # module