diff --git a/src/differentiation/compute_jacobian_ad.jl b/src/differentiation/compute_jacobian_ad.jl index d2ef9154..782ae35a 100644 --- a/src/differentiation/compute_jacobian_ad.jl +++ b/src/differentiation/compute_jacobian_ad.jl @@ -11,11 +11,14 @@ end getsize(::Val{N}) where N = N getsize(N::Integer) = N void_setindex!(args...) = (setindex!(args...); return) +gettag(::Type{ForwardDiff.Dual{T}}) where {T} = T const default_chunk_size = ForwardDiff.pickchunksize +const SMALLTAG = ForwardDiff.Tag(missing,Float64) function ForwardColorJacCache(f::F,x,_chunksize = nothing; dx = nothing, + tag = nothing, colorvec=1:length(x), sparsity::Union{AbstractArray,Nothing}=nothing) where {F} @@ -25,15 +28,21 @@ function ForwardColorJacCache(f::F,x,_chunksize = nothing; chunksize = _chunksize end + if tag === nothing + T = typeof(ForwardDiff.Tag(f,eltype(vec(x)))) + else + T = tag + end + if x isa Array p = generate_chunked_partials(x,colorvec,chunksize) - t = similar(x,Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x)))),eltype(x),length(first(first(p)))}) + t = similar(x,Dual{T}) for i in eachindex(t) - t[i] = Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x)))),eltype(x),length(first(first(p)))}(x[i],ForwardDiff.Partials(first(p)[i])) + t[i] = Dual{T,eltype(x),length(first(first(p)))}(x[i],ForwardDiff.Partials(first(p)[i])) end else p = adapt.(parameterless_type(x),generate_chunked_partials(x,colorvec,chunksize)) - _t = Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x))))}.(vec(x),first(p)) + _t = Dual{T,eltype(x),getsize(chunksize)}.(vec(x),ForwardDiff.Partials.(first(p))) t = ArrayInterface.restructure(x,_t) end @@ -44,7 +53,7 @@ function ForwardColorJacCache(f::F,x,_chunksize = nothing; else tup = ArrayInterface.allowed_getindex(ArrayInterface.allowed_getindex(p,1),1) .* false _pi = adapt(parameterless_type(dx),[tup for i in 1:length(dx)]) - fx = reshape(Dual{typeof(ForwardDiff.Tag(f,eltype(vec(x))))}.(vec(dx),_pi),size(dx)...) + fx = reshape(Dual{T,eltype(dx),length(tup)}.(vec(dx),ForwardDiff.Partials.(_pi)),size(dx)...) _dx = dx end @@ -162,7 +171,7 @@ function forwarddiff_color_jacobian(J::AbstractMatrix{<:Number},f::F,x::Abstract for i in eachindex(p) partial_i = p[i] - t = reshape(Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}.(vecx, partial_i),size(t)) + t = reshape(eltype(t).(vecx, ForwardDiff.Partials.(partial_i)),size(t)) fx = f(t) if !(sparsity isa Nothing) for j in 1:chunksize @@ -230,7 +239,7 @@ function forwarddiff_color_jacobian_immutable(f,x::AbstractArray{<:Number},jac_c for i in eachindex(p) partial_i = p[i] - t = reshape(Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}.(vecx, partial_i),size(t)) + t = reshape(eltype(t).(vecx, ForwardDiff.Partials.(partial_i)),size(t)) fx = f(t) if !(sparsity isa Nothing) for j in 1:chunksize @@ -311,10 +320,10 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number}, if vect isa Array @inbounds @simd ivdep for j in eachindex(vect) - vect[j] = Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}(vecx[j], partial_i[j]) + vect[j] = eltype(t)(vecx[j], ForwardDiff.Partials(partial_i[j])) end else - vect .= Dual{typeof(ForwardDiff.Tag(f,eltype(vecx)))}.(vecx, partial_i) + vect .= eltype(t).(vecx, ForwardDiff.Partials.(partial_i)) end f(fx,t)