Skip to content

Commit 2fec1ca

Browse files
Merge pull request #35 from pkj-m/chunksize
Limit size of partials by chunk size
2 parents 371823d + 448df5b commit 2fec1ca

File tree

1 file changed

+50
-25
lines changed

1 file changed

+50
-25
lines changed

src/differentiation/compute_jacobian_ad.jl

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,34 @@ end
4444
"""
    generate_chunked_partials(x, color, N)

Build the Boolean seed partials for a colored Jacobian, split into chunks.

For every color `c` in `1:maximum(color)` a seed column is formed whose `j`-th
entry is `color[j] == c`. The columns are grouped into chunks of
`getsize(default_chunk_size(maximum(color)))` columns each; the last chunk is
padded with `false` columns so that every chunk has the same width.

# Arguments
- `x`: the input whose `length` fixes the number of seed rows.
- `color`: color assignment, indexed as `color[j]` for `j in 1:length(x)`.
- `N`: an `Integer` or `Val{N}`. It only selects the method; the body does not
  read it, the chunk width is derived from the number of colors instead.

# Returns
A vector with one entry per chunk. Entry `i` is a vector with one tuple per
element of `x`, holding that element's seed values for the colors of chunk `i`.
"""
generate_chunked_partials(x,color,N::Integer) = generate_chunked_partials(x,color,Val(N))

function generate_chunked_partials(x,color,::Val{N}) where N
    ncolors = maximum(color)
    chunksize = getsize(default_chunk_size(ncolors))
    num_of_chunks = cld(ncolors, chunksize)

    # One column per color, rounded up to a whole number of chunks. `falses`
    # (rather than an `undef` BitMatrix) guarantees that the padding columns of
    # the last chunk are zero seeds instead of uninitialized bits.
    partials = falses(length(x), num_of_chunks * chunksize)
    for color_i in 1:ncolors
        for j in 1:length(x)
            partials[j, color_i] = color[j] == color_i
        end
    end

    chunked_partials = Array{Array{NTuple,1},1}(undef, num_of_chunks)
    for i in 1:num_of_chunks
        # Columns belonging to chunk i; each row of the slice becomes one tuple.
        cols = ((i - 1) * chunksize + 1):(i * chunksize)
        chunked_partials[i] = Tuple.(eachrow(partials[:, cols]))
    end

    return chunked_partials
end
6276

6377
function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
@@ -79,23 +93,34 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
7993
p = jac_cache.p
8094
color = jac_cache.color
8195

82-
# TODO: Should compute on each p[i] and decompress
83-
t .= Dual{typeof(f)}.(x, p)
84-
f(fx, t)
85-
86-
if J isa SparseMatrixCSC
87-
rows_index, cols_index, val = findnz(J)
88-
for color_i in 1:maximum(color)
89-
dx .= partials.(fx,color_i)
90-
for i in 1:length(cols_index)
91-
if color[cols_index[i]]==color_i
92-
J[rows_index[i],cols_index[i]] = dx[rows_index[i]]
96+
color_i = 1
97+
chunksize = getsize(default_chunk_size(maximum(color)))
98+
99+
for i in 1:length(p)
100+
partial_i = p[i]
101+
t .= Dual{typeof(f)}.(x, partial_i)
102+
f(fx,t)
103+
104+
if J isa SparseMatrixCSC
105+
rows_index, cols_index, val = findnz(J)
106+
for j in 1:chunksize
107+
dx .= partials.(fx, j)
108+
for k in 1:length(cols_index)
109+
if color[cols_index[k]] == color_i
110+
J[rows_index[k], cols_index[k]] = dx[rows_index[k]]
111+
end
112+
end
113+
color_i += 1
114+
end
115+
else
116+
for j in 1:chunksize
117+
col_index = (i-1)*chunksize + j
118+
J[:, col_index] .= partials.(fx, color_i)
119+
color_i += 1
120+
if color_i == maximum(color) + 1
121+
color_i = 1
93122
end
94123
end
95-
end
96-
else # Compute the compressed version
97-
for color_i in 1:maximum(color)
98-
J[:,i] .= partials.(fx,color_i)
99124
end
100125
end
101126
nothing

0 commit comments

Comments
 (0)