Skip to content

Commit 442afbb

Browse files
Merge pull request #91 from huanglangwen/oop4sp
fix oop jacobian for SparseMatrixCSC
2 parents 95e4b27 + 73591f7 commit 442afbb

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

src/differentiation/compute_jacobian_ad.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,15 @@ function forwarddiff_color_jacobian(f,x::AbstractArray{<:Number},jac_cache::Forw
108108
pick_inds = [i for i in 1:length(rows_index) if colorvec[cols_index[i]] == color_i]
109109
rows_index_c = rows_index[pick_inds]
110110
cols_index_c = cols_index[pick_inds]
111-
len_rows = length(pick_inds)
112-
unused_rows = setdiff(1:nrows,rows_index_c)
113-
perm_rows = sortperm(vcat(rows_index_c,unused_rows))
114-
cols_index_c = vcat(cols_index_c,zeros(Int,nrows-len_rows))[perm_rows]
115-
Ji = [j==cols_index_c[i] ? dx[i] : false for i in 1:nrows, j in 1:ncols]
111+
if J isa SparseMatrixCSC
112+
Ji = sparse(rows_index_c, cols_index_c, dx[rows_index_c],nrows,ncols)
113+
else
114+
len_rows = length(pick_inds)
115+
unused_rows = setdiff(1:nrows,rows_index_c)
116+
perm_rows = sortperm(vcat(rows_index_c,unused_rows))
117+
cols_index_c = vcat(cols_index_c,zeros(Int,nrows-len_rows))[perm_rows]
118+
Ji = [j==cols_index_c[i] ? dx[i] : false for i in 1:nrows, j in 1:ncols]
119+
end
116120
J = J + Ji
117121
color_i += 1
118122
(color_i > maxcolor) && return J

test/test_ad.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ forwarddiff_color_jacobian!(_J1, f, x, colorvec = repeat(1:3,10))
106106
fcalls = 0
107107
_J1 = forwarddiff_color_jacobian(oopf, x, colorvec = repeat(1:3,10), sparsity = _J, jac_prototype = _J)
108108
@test _J1 J
109+
@test typeof(_J1) == typeof(_J)
109110
@test fcalls == 1
110111

111112
@info "third passed"
@@ -146,6 +147,7 @@ _nsqJ = forwarddiff_color_jacobian(nsqf, x, colorvec = repeat(1:3,10), sparsity
146147
@test _nsqJ nsqJ
147148
_nsqJ = forwarddiff_color_jacobian(nsqf, x, jac_prototype = SMatrix{15,30}(nsqJ))
148149
@test _nsqJ nsqJ
150+
@test typeof(_nsqJ) == typeof(SMatrix{15,30}(nsqJ))
149151
_nsqJ = forwarddiff_color_jacobian(staticnsqf, SVector{30}(x), jac_prototype = SMatrix{15,30}(nsqJ))
150152
@test _nsqJ nsqJ
151153
_nsqJ = forwarddiff_color_jacobian(staticnsqf, SVector{30}(x), jac_prototype = SMatrix{15,30}(nsqJ), colorvec = repeat(1:3,10), sparsity = spnsqJ)

0 commit comments

Comments
 (0)