From dfd30d8188069b65f931c02c503b7127d93a5aa6 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 16 Sep 2022 21:50:33 -0400 Subject: [PATCH 1/3] using StaticArraysCore --- Project.toml | 8 +++++--- src/ForwardDiff.jl | 3 ++- src/apiutils.jl | 2 +- src/dual.jl | 4 ++-- src/gradient.jl | 2 +- src/jacobian.jl | 2 +- 6 files changed, 12 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index a75b616d..75786509 100644 --- a/Project.toml +++ b/Project.toml @@ -13,7 +13,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [compat] Calculus = "0.5" @@ -25,7 +25,8 @@ LogExpFunctions = "0.3" NaNMath = "1" Preferences = "1" SpecialFunctions = "1, 2" -StaticArrays = "1.5" +StaticArrays = "1.5.7" +StaticArraysCore = "1.3.0" julia = "1.6" [extras] @@ -33,7 +34,8 @@ Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9" DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Calculus", "DiffTests", "SparseArrays", "Test", "InteractiveUtils"] +test = ["Calculus", "DiffTests", "SparseArrays", "StaticArrays", "Test", "InteractiveUtils"] diff --git a/src/ForwardDiff.jl b/src/ForwardDiff.jl index 93d3b246..7af0993a 100644 --- a/src/ForwardDiff.jl +++ b/src/ForwardDiff.jl @@ -2,7 +2,8 @@ module ForwardDiff using DiffRules, DiffResults using DiffResults: DiffResult, MutableDiffResult, ImmutableDiffResult -using StaticArrays +using StaticArraysCore +using StaticArraysCore: StaticArray, StaticMatrix if VERSION >= v"1.6" using Preferences end diff --git a/src/apiutils.jl b/src/apiutils.jl index 971c368c..90052263 100644 --- a/src/apiutils.jl +++ b/src/apiutils.jl @@ -21,7 +21,7 @@ end @generated function dualize(::Type{T}, x::StaticArray) where T N = length(x) dx = Expr(:tuple, [:(Dual{T}(x[$i], chunk, Val{$i}())) for i in 1:N]...) - V = StaticArrays.similar_type(x, Dual{T,eltype(x),N}) + V = StaticArraysCore.similar_type(x, Dual{T,eltype(x),N}) return quote chunk = Chunk{$N}() $(Expr(:meta, :inline)) diff --git a/src/dual.jl b/src/dual.jl index 7e86e9b2..a3cd64b7 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -731,7 +731,7 @@ function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N Dual{Tg}.(λ, tuple.(parts...)) end -function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N} +function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArraysCore.StaticMatrix}) where {Tg,T<:Real,N} λ,Q = eigen(Symmetric(value.(parent(A)))) parts = ntuple(j -> diag(Q' * getindex.(partials.(A), j) * Q), N) Dual{Tg}.(λ, tuple.(parts...)) @@ -766,7 +766,7 @@ function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}}) where {Tg,T<:Real,N} Eigen(λ,Dual{Tg}.(Q, tuple.(parts...))) end -function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N} +function LinearAlgebra.eigen(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArraysCore.StaticMatrix}) where {Tg,T<:Real,N} λ = eigvals(A) _,Q = eigen(Symmetric(value.(parent(A)))) parts = ntuple(j -> Q*_lyap_div!(Q' * getindex.(partials.(A), j) * Q - Diagonal(getindex.(partials.(λ), j)), value.(λ)), N) diff --git a/src/gradient.jl b/src/gradient.jl index f9f173eb..1b52e1c7 100644 --- a/src/gradient.jl +++ b/src/gradient.jl @@ -59,7 +59,7 @@ gradient(f, x::Real) = throw(DimensionMismatch("gradient(f, x) expects that x is result = Expr(:tuple, [:(partials(T, y, $i)) for i in 1:length(x)]...) return quote $(Expr(:meta, :inline)) - V = StaticArrays.similar_type(S, valtype($y)) + V = StaticArraysCore.similar_type(S, valtype($y)) return V($result) end end diff --git a/src/jacobian.jl b/src/jacobian.jl index bcda61d7..aceb2476 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -101,7 +101,7 @@ jacobian(f, x::Real) = throw(DimensionMismatch("jacobian(f, x) expects that x is result = Expr(:tuple, [:(partials(T, ydual[$i], $j)) for i in 1:M, j in 1:N]...) return quote $(Expr(:meta, :inline)) - V = StaticArrays.similar_type(S, valtype(eltype($ydual)), Size($M, $N)) + V = StaticArraysCore.similar_type(S, valtype(eltype($ydual)), Size($M, $N)) return V($result) end end From a2b200091698112a6599afd129489c8fc347189e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 11 Dec 2022 22:56:49 -0500 Subject: [PATCH 2/3] work around length problem --- src/apiutils.jl | 5 ++++- src/gradient.jl | 3 ++- src/jacobian.jl | 6 ++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/apiutils.jl b/src/apiutils.jl index 90052263..47e2a072 100644 --- a/src/apiutils.jl +++ b/src/apiutils.jl @@ -19,7 +19,7 @@ end ################################### @generated function dualize(::Type{T}, x::StaticArray) where T - N = length(x) + N = _static_length(StaticArraysCore.Size(x)) dx = Expr(:tuple, [:(Dual{T}(x[$i], chunk, Val{$i}())) for i in 1:N]...) V = StaticArraysCore.similar_type(x, Dual{T,eltype(x),N}) return quote @@ -29,6 +29,9 @@ end end end +# This works around length(::Type{StaticArray}) not being defined in this world-age: +_static_length(::StaticArraysCore.Size{s}) where {s} = StaticArraysCore.tuple_prod(s) + @inline static_dual_eval(::Type{T}, f, x::StaticArray) where T = f(dualize(T, x)) function vector_mode_dual_eval!(f::F, cfg::Union{JacobianConfig,GradientConfig}, x) where {F} diff --git a/src/gradient.jl b/src/gradient.jl index 1b52e1c7..4f6208a7 100644 --- a/src/gradient.jl +++ b/src/gradient.jl @@ -56,7 +56,8 @@ gradient(f, x::Real) = throw(DimensionMismatch("gradient(f, x) expects that x is ##################### @generated function extract_gradient(::Type{T}, y::Real, x::S) where {T,S<:StaticArray} - result = Expr(:tuple, [:(partials(T, y, $i)) for i in 1:length(x)]...) + N = _static_length(StaticArraysCore.Size(S)) + result = Expr(:tuple, [:(partials(T, y, $i)) for i in 1:N]...) return quote $(Expr(:meta, :inline)) V = StaticArraysCore.similar_type(S, valtype($y)) diff --git a/src/jacobian.jl b/src/jacobian.jl index aceb2476..00655cf6 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -96,8 +96,10 @@ jacobian(f, x::Real) = throw(DimensionMismatch("jacobian(f, x) expects that x is # result extraction # ##################### -@generated function extract_jacobian(::Type{T}, ydual::StaticArray, x::S) where {T,S<:StaticArray} - M, N = length(ydual), length(x) +@generated function extract_jacobian(::Type{T}, ydual::Sy, x::S) where {T, Sy<:StaticArray, S<:StaticArray} + # M, N = length(ydual), length(x) + M = _static_length(StaticArraysCore.Size(Sy)) + N = _static_length(StaticArraysCore.Size(S)) result = Expr(:tuple, [:(partials(T, ydual[$i], $j)) for i in 1:M, j in 1:N]...) return quote $(Expr(:meta, :inline)) From aeae29b46be09e2d2e3189bf5bf02f26d15fb42c Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 11 Dec 2022 23:25:32 -0500 Subject: [PATCH 3/3] move similar_type inside quote --- src/apiutils.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/apiutils.jl b/src/apiutils.jl index 47e2a072..9616df2a 100644 --- a/src/apiutils.jl +++ b/src/apiutils.jl @@ -18,14 +18,14 @@ end # vector mode function evaluation # ################################### -@generated function dualize(::Type{T}, x::StaticArray) where T - N = _static_length(StaticArraysCore.Size(x)) +@generated function dualize(::Type{T}, x::S) where {T, S<:StaticArray} + N = _static_length(StaticArraysCore.Size(S)) dx = Expr(:tuple, [:(Dual{T}(x[$i], chunk, Val{$i}())) for i in 1:N]...) - V = StaticArraysCore.similar_type(x, Dual{T,eltype(x),N}) return quote + V = StaticArraysCore.similar_type(S, Dual{$T, $(eltype(x)), $N}) chunk = Chunk{$N}() $(Expr(:meta, :inline)) - return $V($(dx)) + return V($(dx)) end end