44
44
generate_chunked_partials (x,color,N:: Integer ) = generate_chunked_partials (x,color,Val (N))
45
45
function generate_chunked_partials (x,color,:: Val{N} ) where N
46
46
47
- # TODO : should only go up to the chunksize each time, and should
48
- # generate p[i] different parts, each with less than the chunksize
47
+ chunksize = getsize (default_chunk_size (maximum (color)))
48
+ num_of_chunks = Int64 (ceil (maximum (color) / chunksize))
49
+
50
+ padding_size = (chunksize - (maximum (color) % chunksize)) % chunksize
51
+
52
+ partials = BitMatrix (undef, length (x), maximum (color))
53
+ partial = BitMatrix (undef, length (x), chunksize)
54
+ chunked_partials = Array {Array{NTuple,1},1} (undef, num_of_chunks)
49
55
50
- partials_array = BitMatrix (undef, length (x), maximum (color))
51
56
for color_i in 1 : maximum (color)
52
- for i in eachindex (x)
53
- if color[i]== color_i
54
- partials_array[i,color_i] = true
55
- else
56
- partials_array[i,color_i] = false
57
- end
57
+ for j in 1 : length (x)
58
+ partials[j, color_i] = color[j] == color_i
59
+ end
60
+ end
61
+
62
+ padding_matrix = BitMatrix (undef, length (x), padding_size)
63
+ partials = hcat (partials, padding_matrix)
64
+
65
+ for i in 1 : num_of_chunks
66
+ partial[:,1 ] .= partials[:,(i- 1 )* chunksize+ 1 ]
67
+ for j in 2 : chunksize
68
+ partial[:,j] .= partials[:,(i- 1 )* chunksize+ j]
58
69
end
70
+ chunked_partials[i] = Tuple .(eachrow (partial))
59
71
end
60
- p = Tuple .(eachrow (partials_array))
72
+
73
+ chunked_partials
74
+
61
75
end
62
76
63
77
function forwarddiff_color_jacobian! (J:: AbstractMatrix{<:Number} ,
@@ -79,23 +93,34 @@ function forwarddiff_color_jacobian!(J::AbstractMatrix{<:Number},
79
93
p = jac_cache. p
80
94
color = jac_cache. color
81
95
82
- # TODO : Should compute on each p[i] and decompress
83
- t .= Dual {typeof(f)} .(x, p)
84
- f (fx, t)
85
-
86
- if J isa SparseMatrixCSC
87
- rows_index, cols_index, val = findnz (J)
88
- for color_i in 1 : maximum (color)
89
- dx .= partials .(fx,color_i)
90
- for i in 1 : length (cols_index)
91
- if color[cols_index[i]]== color_i
92
- J[rows_index[i],cols_index[i]] = dx[rows_index[i]]
96
+ color_i = 1
97
+ chunksize = getsize (default_chunk_size (maximum (color)))
98
+
99
+ for i in 1 : length (p)
100
+ partial_i = p[i]
101
+ t .= Dual {typeof(f)} .(x, partial_i)
102
+ f (fx,t)
103
+
104
+ if J isa SparseMatrixCSC
105
+ rows_index, cols_index, val = findnz (J)
106
+ for j in 1 : chunksize
107
+ dx .= partials .(fx, j)
108
+ for k in 1 : length (cols_index)
109
+ if color[cols_index[k]] == color_i
110
+ J[rows_index[k], cols_index[k]] = dx[rows_index[k]]
111
+ end
112
+ end
113
+ color_i += 1
114
+ end
115
+ else
116
+ for j in 1 : chunksize
117
+ col_index = (i- 1 )* chunksize + j
118
+ J[:, col_index] .= partials .(fx, color_i)
119
+ color_i += 1
120
+ if color_i == maximum (color) + 1
121
+ color_i = 1
93
122
end
94
123
end
95
- end
96
- else # Compute the compressed version
97
- for color_i in 1 : maximum (color)
98
- J[:,i] .= partials .(fx,color_i)
99
124
end
100
125
end
101
126
nothing
0 commit comments