Skip to content

Commit 507177a

Browse files
Merge pull request #40 from pkj-m/row-partition
Partition matrix by rows
2 parents 8e1fb66 + ccabdc3 commit 507177a

File tree

4 files changed

+52
-20
lines changed

4 files changed

+52
-20
lines changed

README.md

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,19 @@ this can be significant savings.
9393
The API for computing the color vector is:
9494

9595
```julia
96-
matrix_colors(A::AbstractMatrix,alg::ColoringAlgorithm = GreedyD1Color())
96+
matrix_colors(A::AbstractMatrix,alg::ColoringAlgorithm = GreedyD1Color(); partition_by_rows::Bool = false)
9797
```
9898

9999
The first argument is the abstract matrix which represents the sparsity pattern
100100
of the Jacobian. The second argument is the optional choice of coloring algorithm.
101101
It will default to a greedy distance 1 coloring, though if your special matrix
102102
type has more information, like is a `Tridiagonal` or `BlockBandedMatrix`, the
103-
color vector will be analytically calculated instead.
103+
color vector will be analytically calculated instead. The variable argument
104+
`partition_by_rows` allows you to partition the Jacobian on the basis of rows instead
105+
of columns and generate a corresponding coloring vector which can be used for
106+
reverse-mode AD. Default value is false.
104107

105-
The result is a vector which assigns a color to each row of the matrix.
108+
The result is a vector which assigns a color to each column (or row) of the matrix.
106109

107110
### Color-Assisted Differentiation
108111

@@ -195,14 +198,14 @@ autonum_hesvec(f,x,v)
195198
numback_hesvec!(du,f,x,v,
196199
cache1 = similar(v),
197200
cache2 = similar(v))
198-
201+
199202
numback_hesvec(f,x,v)
200203

201204
# Currently errors! See https://github.com/FluxML/Zygote.jl/issues/241
202205
autoback_hesvec!(du,f,x,v,
203206
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v),
204207
cache3 = ForwardDiff.Dual{DeivVecTag}.(x, v))
205-
208+
206209
autoback_hesvec(f,x,v)
207210
```
208211

@@ -211,7 +214,7 @@ the former almost always being more efficient and is thus recommended. `numback`
211214
`autoback` methods are numerical/ForwardDiff over reverse mode automatic differentiation
212215
respectively, where the reverse-mode AD is provided by Zygote.jl. Currently these methods
213216
are not competitive against `numauto`, but as Zygote.jl gets optimized these will likely
214-
be the fastest.
217+
be the fastest.
215218

216219
In addition,
217220
the following forms allow you to provide a gradient function `g(dx,x)` or `dx=g(x)`

src/coloring/high_level.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@ struct ContractionColor <: ColoringAlgorithm end
1111
The coloring defaults to a greedy distance-1 coloring.
1212
1313
"""
14-
function matrix_colors(A::AbstractMatrix,alg::ColoringAlgorithm = GreedyD1Color())
14+
function matrix_colors(A::AbstractMatrix,alg::ColoringAlgorithm = GreedyD1Color(); partition_by_rows::Bool = false)
1515
_A = A isa SparseMatrixCSC ? A : sparse(A) # Avoid the copy
16-
A_graph = matrix2graph(_A)
16+
A_graph = matrix2graph(_A, partition_by_rows)
1717
color_graph(A_graph,alg)
1818
end
1919

2020
"""
2121
matrix_colors(A::Union{Array,UpperTriangular,LowerTriangular})
2222
23-
The color vector for dense matrix and triangular matrix is simply
23+
The color vector for dense matrix and triangular matrix is simply
2424
`[1,2,3,...,size(A,2)]`
2525
"""
2626
function matrix_colors(A::Union{Array,UpperTriangular,LowerTriangular})
@@ -77,4 +77,4 @@ function matrix_colors(A::BandedBlockBandedMatrix)
7777
startinds=[endinds[i]-ncolors[i]+1 for i in 1:blockwidth]
7878
colors=[_cycle(startinds[blockcolors[i]]:endinds[blockcolors[i]],cols[i]) for i in 1:nblock]
7979
vcat(colors...)
80-
end
80+
end

src/coloring/matrix2graph.jl

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,23 +6,39 @@ sparse matrix, columns are represented with vertices
66
and 2 vertices are connected with an edge only if
77
the two columns are mutually orthogonal.
88
"""
9-
function matrix2graph(SparseMatrix::SparseMatrixCSC{T,Int}) where T<:Number
9+
function matrix2graph(SparseMatrix::SparseMatrixCSC{T,Int}, partition_by_rows::Bool) where T<:Number
1010
dropzeros(SparseMatrix)
1111
(rows_index, cols_index, val) = findnz(SparseMatrix)
1212

13-
V = cols = size(SparseMatrix, 2)
13+
cols = size(SparseMatrix, 2)
1414
rows = size(SparseMatrix, 1)
1515

16+
partition_by_rows ? V = rows : V = cols
17+
1618
inner = SimpleGraph(V)
1719
graph = VSafeGraph(inner)
1820

19-
for i = 1:length(cols_index)
20-
cur_col = cols_index[i]
21-
for j = 1:(i-1)
22-
next_col = cols_index[j]
23-
if cur_col != next_col
24-
if rows_index[i] == rows_index[j]
25-
add_edge!(graph, cur_col, next_col)
21+
if partition_by_rows
22+
for i = 1:length(rows_index)
23+
cur_row = rows_index[i]
24+
for j = 1:(i-1)
25+
next_row = rows_index[j]
26+
if cur_row != next_row
27+
if cols_index[i] == cols_index[j]
28+
add_edge!(graph, cur_row, next_row)
29+
end
30+
end
31+
end
32+
end
33+
else
34+
for i = 1:length(cols_index)
35+
cur_col = cols_index[i]
36+
for j = 1:(i-1)
37+
next_col = cols_index[j]
38+
if cur_col != next_col
39+
if rows_index[i] == rows_index[j]
40+
add_edge!(graph, cur_col, next_col)
41+
end
2642
end
2743
end
2844
end

test/test_matrix2graph.jl

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ end
1717

1818
for i in 1:20
1919
matrix = matrices[i]
20-
g = matrix2graph(matrix)
20+
g = matrix2graph(matrix, false)
2121
for e in edges(g)
2222
src = LG.src(e)
2323
dst = LG.dst(e)
@@ -27,3 +27,16 @@ for i in 1:20
2727
@test pr != 0
2828
end
2929
end
30+
31+
for i in 1:20
32+
matrix = matrices[i]
33+
g = matrix2graph(matrix, true)
34+
for e in edges(g)
35+
src = LG.src(e)
36+
dst = LG.dst(e)
37+
row1 = abs.(matrix[src, :])
38+
row2 = abs.(matrix[dst, :])
39+
pr = row1' * row2
40+
@test pr != 0
41+
end
42+
end

0 commit comments

Comments
 (0)