Skip to content

Commit a2b2000

Browse files
committed
work around length problem
1 parent dfd30d8 commit a2b2000

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

src/apiutils.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ end
1919
###################################
2020

2121
@generated function dualize(::Type{T}, x::StaticArray) where T
22-
N = length(x)
22+
N = _static_length(StaticArraysCore.Size(x))
2323
dx = Expr(:tuple, [:(Dual{T}(x[$i], chunk, Val{$i}())) for i in 1:N]...)
2424
V = StaticArraysCore.similar_type(x, Dual{T,eltype(x),N})
2525
return quote
@@ -29,6 +29,9 @@ end
2929
end
3030
end
3131

32+
# This works around length(::Type{StaticArray}) not being defined in this world-age:
33+
_static_length(::StaticArraysCore.Size{s}) where {s} = StaticArraysCore.tuple_prod(s)
34+
3235
@inline static_dual_eval(::Type{T}, f, x::StaticArray) where T = f(dualize(T, x))
3336

3437
function vector_mode_dual_eval!(f::F, cfg::Union{JacobianConfig,GradientConfig}, x) where {F}

src/gradient.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ gradient(f, x::Real) = throw(DimensionMismatch("gradient(f, x) expects that x is
5656
#####################
5757

5858
@generated function extract_gradient(::Type{T}, y::Real, x::S) where {T,S<:StaticArray}
59-
result = Expr(:tuple, [:(partials(T, y, $i)) for i in 1:length(x)]...)
59+
N = _static_length(StaticArraysCore.Size(S))
60+
result = Expr(:tuple, [:(partials(T, y, $i)) for i in 1:N]...)
6061
return quote
6162
$(Expr(:meta, :inline))
6263
V = StaticArraysCore.similar_type(S, valtype($y))

src/jacobian.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,10 @@ jacobian(f, x::Real) = throw(DimensionMismatch("jacobian(f, x) expects that x is
9696
# result extraction #
9797
#####################
9898

99-
@generated function extract_jacobian(::Type{T}, ydual::StaticArray, x::S) where {T,S<:StaticArray}
100-
M, N = length(ydual), length(x)
99+
@generated function extract_jacobian(::Type{T}, ydual::Sy, x::S) where {T, Sy<:StaticArray, S<:StaticArray}
100+
# M, N = length(ydual), length(x)
101+
M = _static_length(StaticArraysCore.Size(Sy))
102+
N = _static_length(StaticArraysCore.Size(S))
101103
result = Expr(:tuple, [:(partials(T, ydual[$i], $j)) for i in 1:M, j in 1:N]...)
102104
return quote
103105
$(Expr(:meta, :inline))

0 commit comments

Comments
 (0)