-
Notifications
You must be signed in to change notification settings - Fork 28.9k
[SPARK-10599][MLLIB] Lower communication for block matrix multiplication #8757
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -54,12 +54,14 @@ private[mllib] class GridPartitioner( | |
| /** | ||
| * Returns the index of the partition the input coordinate belongs to. | ||
| * | ||
| * @param key The coordinate (i, j) or a tuple (i, j, k), where k is the inner index used in | ||
| * multiplication. k is ignored in computing partitions. | ||
| * @param key The partition id i (calculated through this method for coordinate (i, j) in | ||
| * `simulateMultiply`), the coordinate (i, j), or a tuple (i, j, k), where k is | ||
| * the inner index used in multiplication. k is ignored in computing partitions. | ||
| * @return The index of the partition, which the coordinate belongs to. | ||
| */ | ||
| override def getPartition(key: Any): Int = { | ||
| key match { | ||
| case i: Int => i | ||
| case (i: Int, j: Int) => | ||
| getPartitionId(i, j) | ||
| case (i: Int, j: Int, _: Int) => | ||
|
|
@@ -352,12 +354,49 @@ class BlockMatrix @Since("1.3.0") ( | |
| } | ||
| } | ||
|
|
||
| /** Block (i,j) --> Set of destination partitions */ | ||
| private type BlockDestinations = Map[(Int, Int), Set[Int]] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Document |
||
|
|
||
| /** | ||
| * Simulate the multiplication with just block indices in order to cut costs on communication, | ||
| * when we are actually shuffling the matrices. | ||
| * The `colsPerBlock` of this matrix must equal the `rowsPerBlock` of `other`. | ||
| * Exposed for tests. | ||
| * | ||
| * @param other The BlockMatrix to multiply | ||
| * @param partitioner The partitioner that will be used for the resulting matrix `C = A * B` | ||
| * @return A tuple of [[BlockDestinations]]. The first element maps each block of `this` to the | ||
| * set of partitions it must be shuffled to, and the second element is the equivalent Map | ||
| * for `other`. | ||
| */ | ||
| private[distributed] def simulateMultiply( | ||
| other: BlockMatrix, | ||
| partitioner: GridPartitioner): (BlockDestinations, BlockDestinations) = { | ||
| val leftMatrix = blockInfo.keys.collect() // blockInfo should already be cached | ||
| val rightMatrix = other.blocks.keys.collect() | ||
| val leftDestinations = leftMatrix.map { case (rowIndex, colIndex) => | ||
| val rightCounterparts = rightMatrix.filter(_._1 == colIndex) | ||
| val partitions = rightCounterparts.map(b => partitioner.getPartition((rowIndex, b._2))) | ||
| ((rowIndex, colIndex), partitions.toSet) | ||
| }.toMap | ||
| val rightDestinations = rightMatrix.map { case (rowIndex, colIndex) => | ||
| val leftCounterparts = leftMatrix.filter(_._2 == rowIndex) | ||
| val partitions = leftCounterparts.map(b => partitioner.getPartition((b._1, colIndex))) | ||
| ((rowIndex, colIndex), partitions.toSet) | ||
| }.toMap | ||
| (leftDestinations, rightDestinations) | ||
| } | ||
|
|
||
| /** | ||
| * Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock` | ||
| * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains | ||
| * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output | ||
| * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause | ||
| * some performance issues until support for multiplying two sparse matrices is added. | ||
| * | ||
| * Note: The behavior of multiply has changed in 1.6.0. `multiply` used to throw an error when | ||
| * there were blocks with duplicate indices. Now, the blocks with duplicate indices will be added | ||
| * with each other. | ||
| */ | ||
| @Since("1.3.0") | ||
| def multiply(other: BlockMatrix): BlockMatrix = { | ||
|
|
@@ -368,33 +407,30 @@ class BlockMatrix @Since("1.3.0") ( | |
| if (colsPerBlock == other.rowsPerBlock) { | ||
| val resultPartitioner = GridPartitioner(numRowBlocks, other.numColBlocks, | ||
| math.max(blocks.partitions.length, other.blocks.partitions.length)) | ||
| // Each block of A must be multiplied with the corresponding blocks in each column of B. | ||
| // TODO: Optimize to send block to a partition once, similar to ALS | ||
| val (leftDestinations, rightDestinations) = simulateMultiply(other, resultPartitioner) | ||
| // Each block of A must be multiplied with the corresponding blocks in the columns of B. | ||
| val flatA = blocks.flatMap { case ((blockRowIndex, blockColIndex), block) => | ||
| Iterator.tabulate(other.numColBlocks)(j => ((blockRowIndex, j, blockColIndex), block)) | ||
| val destinations = leftDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty) | ||
| destinations.map(j => (j, (blockRowIndex, blockColIndex, block))) | ||
| } | ||
| // Each block of B must be multiplied with the corresponding blocks in each row of A. | ||
| val flatB = other.blocks.flatMap { case ((blockRowIndex, blockColIndex), block) => | ||
| Iterator.tabulate(numRowBlocks)(i => ((i, blockColIndex, blockRowIndex), block)) | ||
| val destinations = rightDestinations.getOrElse((blockRowIndex, blockColIndex), Set.empty) | ||
| destinations.map(j => (j, (blockRowIndex, blockColIndex, block))) | ||
| } | ||
| val newBlocks: RDD[MatrixBlock] = flatA.cogroup(flatB, resultPartitioner) | ||
| .flatMap { case ((blockRowIndex, blockColIndex, _), (a, b)) => | ||
| if (a.size > 1 || b.size > 1) { | ||
| throw new SparkException("There are multiple MatrixBlocks with indices: " + | ||
| s"($blockRowIndex, $blockColIndex). Please remove them.") | ||
| } | ||
| if (a.nonEmpty && b.nonEmpty) { | ||
| val C = b.head match { | ||
| case dense: DenseMatrix => a.head.multiply(dense) | ||
| case sparse: SparseMatrix => a.head.multiply(sparse.toDense) | ||
| case _ => throw new SparkException(s"Unrecognized matrix type ${b.head.getClass}.") | ||
| val newBlocks = flatA.cogroup(flatB, resultPartitioner).flatMap { case (pId, (a, b)) => | ||
| a.flatMap { case (leftRowIndex, leftColIndex, leftBlock) => | ||
| b.filter(_._1 == leftColIndex).map { case (rightRowIndex, rightColIndex, rightBlock) => | ||
| val C = rightBlock match { | ||
| case dense: DenseMatrix => leftBlock.multiply(dense) | ||
| case sparse: SparseMatrix => leftBlock.multiply(sparse.toDense) | ||
| case _ => | ||
| throw new SparkException(s"Unrecognized matrix type ${rightBlock.getClass}.") | ||
| } | ||
| Iterator(((blockRowIndex, blockColIndex), C.toBreeze)) | ||
| } else { | ||
| Iterator() | ||
| ((leftRowIndex, rightColIndex), C.toBreeze) | ||
| } | ||
| }.reduceByKey(resultPartitioner, (a, b) => a + b) | ||
| .mapValues(Matrices.fromBreeze) | ||
| } | ||
| }.reduceByKey(resultPartitioner, (a, b) => a + b).mapValues(Matrices.fromBreeze) | ||
| // TODO: Try to use aggregateByKey instead of reduceByKey to get rid of intermediate matrices | ||
| new BlockMatrix(newBlocks, rowsPerBlock, other.colsPerBlock, numRows(), other.numCols()) | ||
| } else { | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Update documentation