Skip to content

Commit 5dcf85c

Browse files
committed
[SPARK-5322] Added transpose functionality to BlockMatrix
1 parent a3dc618 commit 5dcf85c

File tree

2 files changed

+30
-0
lines changed

2 files changed

+30
-0
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,15 @@ class BlockMatrix(
232232
new DenseMatrix(m, n, values)
233233
}
234234

235+
/** Transpose this `BlockMatrix`. Returns a new `BlockMatrix` instance sharing the
236+
* same underlying data. */
237+
def transpose: BlockMatrix = {
238+
val transposedBlocks = blocks.map { case ((blockRowIndex, blockColIndex), mat) =>
239+
((blockColIndex, blockRowIndex), mat.transpose)
240+
}
241+
new BlockMatrix(transposedBlocks, colsPerBlock, rowsPerBlock, numCols(), numRows())
242+
}
243+
235244
/** Collects data and assembles a local dense breeze matrix (for test only). */
236245
private[mllib] def toBreeze(): BDM[Double] = {
237246
val localMat = toLocalMatrix()

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,4 +146,25 @@ class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext {
146146
assert(gridBasedMat.toLocalMatrix() === dense)
147147
assert(gridBasedMat.toBreeze() === expected)
148148
}
149+
150+
test("transpose") {
151+
val expected = BDM(
152+
(1.0, 0.0, 3.0, 0.0, 0.0),
153+
(0.0, 2.0, 1.0, 1.0, 0.0),
154+
(0.0, 1.0, 1.0, 2.0, 1.0),
155+
(0.0, 0.0, 0.0, 1.0, 5.0))
156+
157+
val AT = gridBasedMat.transpose
158+
assert(AT.numRows() === gridBasedMat.numCols())
159+
assert(AT.numCols() === gridBasedMat.numRows())
160+
assert(AT.toBreeze() === expected)
161+
162+
// partitioner must update as well
163+
val originalPartitioner = gridBasedMat.partitioner
164+
val ATpartitioner = AT.partitioner
165+
assert(originalPartitioner.colsPerPart === ATpartitioner.rowsPerPart)
166+
assert(originalPartitioner.rowsPerPart === ATpartitioner.colsPerPart)
167+
assert(originalPartitioner.cols === ATpartitioner.rows)
168+
assert(originalPartitioner.rows === ATpartitioner.cols)
169+
}
149170
}

0 commit comments

Comments
 (0)