Skip to content

Commit 16789e2

Browse files
Merge pull request #196 from baggepinnen/patch-1
function barrier reduces impact of type instability
2 parents 2e9dfed + 9c75bb6 commit 16789e2

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

src/differentiation/compute_jacobian_ad.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,8 @@ function ForwardColorJacCache(f::F,x,_chunksize = nothing;
3636

3737
if x isa Array
3838
p = generate_chunked_partials(x,colorvec,chunksize)
39-
t = Array{Dual{T,eltype(x),length(first(first(p)))}}(undef,size(x))
40-
for i in eachindex(t)
41-
t[i] = Dual{T,eltype(x),length(first(first(p)))}(x[i],ForwardDiff.Partials(first(p)[i]))
42-
end
39+
DT = Dual{T,eltype(x),length(first(first(p)))}
40+
t = _get_t(DT, x, p)
4341
else
4442
p = adapt.(parameterless_type(x),generate_chunked_partials(x,colorvec,chunksize))
4543
_t = Dual{T,eltype(x),getsize(chunksize)}.(vec(x),ForwardDiff.Partials.(first(p)))
@@ -60,6 +58,15 @@ function ForwardColorJacCache(f::F,x,_chunksize = nothing;
6058
ForwardColorJacCache(t,fx,_dx,p,colorvec,sparsity,getsize(chunksize))
6159
end
6260

61+
# Function barrier for unknown constructor type
62+
function _get_t(::Type{DT}, x, p) where DT
63+
t = similar(x, DT)
64+
for i in eachindex(t)
65+
t[i] = DT(x[i],ForwardDiff.Partials(first(p)[i]))
66+
end
67+
t
68+
end
69+
6370
generate_chunked_partials(x,colorvec,N::Integer) = generate_chunked_partials(x,colorvec,Val(N))
6471
function generate_chunked_partials(x,colorvec,cs::Val{chunksize}) where chunksize
6572
maxcolor = maximum(colorvec)

0 commit comments

Comments
 (0)