@@ -54,12 +54,14 @@ private[mllib] class GridPartitioner(
5454 /**
5555 * Returns the index of the partition the input coordinate belongs to.
5656 *
57- * @param key The coordinate (i, j) or a tuple (i, j, k), where k is the inner index used in
58- * multiplication. k is ignored in computing partitions.
57+ * @param key The partition id i (calculated through this method for coordinate (i, j) in
58+ * `simulateMultiply`), the coordinate (i, j), or a tuple (i, j, k), where k is
59+ * the inner index used in multiplication. k is ignored in computing partitions.
5960 * @return The index of the partition, which the coordinate belongs to.
6061 */
6162 override def getPartition (key : Any ): Int = {
6263 key match {
64+ case i : Int => i
6365 case (i : Int , j : Int ) =>
6466 getPartitionId(i, j)
6567 case (i : Int , j : Int , _ : Int ) =>
@@ -352,12 +354,49 @@ class BlockMatrix @Since("1.3.0") (
352354 }
353355 }
354356
357+ /** Block (i,j) --> Set of destination partitions */
358+ private type BlockDestinations = Map [(Int , Int ), Set [Int ]]
359+
/**
 * Simulates the multiplication using only block indices, in order to cut communication costs
 * when the matrices are actually shuffled: each block is later sent only to the partitions
 * listed for it here.
 * The `colsPerBlock` of this matrix must equal the `rowsPerBlock` of `other`.
 * Exposed for tests.
 *
 * @param other The BlockMatrix to multiply
 * @param partitioner The partitioner that will be used for the resulting matrix `C = A * B`
 * @return A tuple of [[BlockDestinations]]. The first element maps each block of `this` to the
 *         set of partitions it must be shuffled to, and the second element is the
 *         corresponding map for the blocks of `other`.
 */
private[distributed] def simulateMultiply(
    other: BlockMatrix,
    partitioner: GridPartitioner): (BlockDestinations, BlockDestinations) = {
  val leftMatrix = blockInfo.keys.collect() // blockInfo should already be cached
  val rightMatrix = other.blocks.keys.collect()

  // Index the block coordinates once so that each block finds its counterparts with a map
  // lookup instead of rescanning the whole opposite key array.
  // A strict `map` is used rather than `mapValues`, which returns a lazy view on older Scala
  // versions and would redo the projection on every lookup.
  // Row index of B -> column indices of the blocks of B present in that row.
  val rightColsByRow: Map[Int, Array[Int]] =
    rightMatrix.groupBy(_._1).map { case (row, keys) => (row, keys.map(_._2)) }
  // Column index of A -> row indices of the blocks of A present in that column.
  val leftRowsByCol: Map[Int, Array[Int]] =
    leftMatrix.groupBy(_._2).map { case (col, keys) => (col, keys.map(_._1)) }

  // Block A(i, k) contributes to C(i, j) for every existing block B(k, j).
  val leftDestinations = leftMatrix.map { case (rowIndex, colIndex) =>
    val rightCols = rightColsByRow.getOrElse(colIndex, Array.empty[Int])
    val partitions = rightCols.map(j => partitioner.getPartition((rowIndex, j)))
    ((rowIndex, colIndex), partitions.toSet)
  }.toMap

  // Block B(k, j) contributes to C(i, j) for every existing block A(i, k).
  val rightDestinations = rightMatrix.map { case (rowIndex, colIndex) =>
    val leftRows = leftRowsByCol.getOrElse(rowIndex, Array.empty[Int])
    val partitions = leftRows.map(i => partitioner.getPartition((i, colIndex)))
    ((rowIndex, colIndex), partitions.toSet)
  }.toMap

  (leftDestinations, rightDestinations)
}
389+
355390 /**
356391 * Left multiplies this [[BlockMatrix ]] to `other`, another [[BlockMatrix ]]. The `colsPerBlock`
357392 * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains
358393 * [[SparseMatrix ]], they will have to be converted to a [[DenseMatrix ]]. The output
359394 * [[BlockMatrix ]] will only consist of blocks of [[DenseMatrix ]]. This may cause
360395 * some performance issues until support for multiplying two sparse matrices is added.
396+ *
397+ * Note: The behavior of `multiply` changed in 1.6.0. `multiply` used to throw an error when
398+ * there were blocks with duplicate indices. Now, blocks with duplicate indices are added
399+ * together.
361400 */
362401 @ Since (" 1.3.0" )
363402 def multiply (other : BlockMatrix ): BlockMatrix = {
@@ -368,33 +407,30 @@ class BlockMatrix @Since("1.3.0") (
368407 if (colsPerBlock == other.rowsPerBlock) {
369408 val resultPartitioner = GridPartitioner (numRowBlocks, other.numColBlocks,
370409 math.max(blocks.partitions.length, other.blocks.partitions.length))
371- // Each block of A must be multiplied with the corresponding blocks in each column of B.
372- // TODO: Optimize to send block to a partition once, similar to ALS
410+ val (leftDestinations, rightDestinations) = simulateMultiply(other, resultPartitioner)
411+ // Each block of A must be multiplied with the corresponding blocks in the columns of B.
373412 val flatA = blocks.flatMap { case ((blockRowIndex, blockColIndex), block) =>
374- Iterator .tabulate(other.numColBlocks)(j => ((blockRowIndex, j, blockColIndex), block))
413+ val destinations = leftDestinations.getOrElse((blockRowIndex, blockColIndex), Set .empty)
414+ destinations.map(j => (j, (blockRowIndex, blockColIndex, block)))
375415 }
376416 // Each block of B must be multiplied with the corresponding blocks in each row of A.
377417 val flatB = other.blocks.flatMap { case ((blockRowIndex, blockColIndex), block) =>
378- Iterator .tabulate(numRowBlocks)(i => ((i, blockColIndex, blockRowIndex), block))
418+ val destinations = rightDestinations.getOrElse((blockRowIndex, blockColIndex), Set .empty)
419+ destinations.map(j => (j, (blockRowIndex, blockColIndex, block)))
379420 }
380- val newBlocks : RDD [MatrixBlock ] = flatA.cogroup(flatB, resultPartitioner)
381- .flatMap { case ((blockRowIndex, blockColIndex, _), (a, b)) =>
382- if (a.size > 1 || b.size > 1 ) {
383- throw new SparkException (" There are multiple MatrixBlocks with indices: " +
384- s " ( $blockRowIndex, $blockColIndex). Please remove them. " )
385- }
386- if (a.nonEmpty && b.nonEmpty) {
387- val C = b.head match {
388- case dense : DenseMatrix => a.head.multiply(dense)
389- case sparse : SparseMatrix => a.head.multiply(sparse.toDense)
390- case _ => throw new SparkException (s " Unrecognized matrix type ${b.head.getClass}. " )
421+ val newBlocks = flatA.cogroup(flatB, resultPartitioner).flatMap { case (pId, (a, b)) =>
422+ a.flatMap { case (leftRowIndex, leftColIndex, leftBlock) =>
423+ b.filter(_._1 == leftColIndex).map { case (rightRowIndex, rightColIndex, rightBlock) =>
424+ val C = rightBlock match {
425+ case dense : DenseMatrix => leftBlock.multiply(dense)
426+ case sparse : SparseMatrix => leftBlock.multiply(sparse.toDense)
427+ case _ =>
428+ throw new SparkException (s " Unrecognized matrix type ${rightBlock.getClass}. " )
391429 }
392- Iterator (((blockRowIndex, blockColIndex), C .toBreeze))
393- } else {
394- Iterator ()
430+ ((leftRowIndex, rightColIndex), C .toBreeze)
395431 }
396- }.reduceByKey(resultPartitioner, (a, b) => a + b)
397- .mapValues(Matrices .fromBreeze)
432+ }
433+ }.reduceByKey(resultPartitioner, (a, b) => a + b) .mapValues(Matrices .fromBreeze)
398434 // TODO: Try to use aggregateByKey instead of reduceByKey to get rid of intermediate matrices
399435 new BlockMatrix (newBlocks, rowsPerBlock, other.colsPerBlock, numRows(), other.numCols())
400436 } else {
0 commit comments