Skip to content

Commit 032805d

Browse files
committed
addressing @dbtsai review comments
1 parent 28b61b8 commit 032805d

File tree

1 file changed

+93
-59
lines changed

1 file changed

+93
-59
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala

Lines changed: 93 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ class BlockMatrix @Since("1.3.0") (
463463
val a11Inv = new BlockMatrix(a11RDD, this.rowsPerBlock, this.colsPerBlock)
464464

465465
val S = a22.subtract(a21.multiply(a11Inv.multiply(a12)))
466-
return S
466+
S
467467
}
468468

469469
/** Returns a rectangular (sub)BlockMatrix with block ranges as specified. Block Ranges
@@ -487,10 +487,10 @@ class BlockMatrix @Since("1.3.0") (
487487
val colMin = blockColRange._1 ; val colMax = blockColRange._2
488488
val extractedSeq = this.blocks.filter{ case((x, y), matrix) =>
489489
x >= rowMin && x<= rowMax && // finding blocks
490-
y >= colMin && y<= colMax }.map{ // shifting indices
490+
y >= colMin && y<= colMax }.map { // shifting indices
491491
case(((x, y), matrix) ) => ((x-rowMin, y-colMin), matrix)
492492
}
493-
return new BlockMatrix(extractedSeq, rowsPerBlock, colsPerBlock)
493+
new BlockMatrix(extractedSeq, rowsPerBlock, colsPerBlock)
494494
}
495495

496496
/** computes the LU decomposition of a Single Block from BlockMatrix using the
@@ -507,16 +507,33 @@ class BlockMatrix @Since("1.3.0") (
507507
val L = lowerTriangular(PLU._1) - diag(diag(PLU._1)) + diag(DenseVector.fill(k){1.0})
508508
val U = upperTriangular(PLU._1);
509509
var P = diag(DenseVector.fill(k){1.0})
510-
val Pi = diag(DenseVector.fill(k){1.0})
510+
var Pi = diag(DenseVector.fill(k){1.0})
511511
// size of square matrix
512-
for(i <- 0 to (k - 1)) { // i test populating permutation matrix
513-
val I = i match {case 0 => k - 1 case _ => i - 1}
512+
// populating permutation matrix
513+
var i = 0
514+
while (i < k) {
515+
val I = {
516+
if (i == 0){k - 1}
517+
else {i - 1}
518+
}
514519
val J = PLU._2(i) -1
515-
if (i != J) { Pi(i, J) += 1.0; Pi(J, i) += 1.0; Pi(i, i) -= 1.0; Pi(J, J) -= 1.0}
520+
if (i != J) {
521+
Pi(i, J) += 1.0
522+
Pi(J, i) += 1.0
523+
Pi(i, i) -= 1.0
524+
Pi(J, J) -= 1.0
525+
}
516526
P = Pi * P // constructor Pi*P for PA=LU
517-
if (i != J) { Pi(i, J) -= 1.0; Pi(J, i) -= 1.0; Pi(i, i) += 1.0; Pi(J, J) += 1.0}
527+
// resetting Pi for next iteration
528+
if (i != J) {
529+
Pi(i, J) -= 1.0
530+
Pi(J, i) -= 1.0
531+
Pi(i, i) += 1.0
532+
Pi(J, J) += 1.0
533+
}
534+
i += 1
518535
}
519-
return List(P, L, U)
536+
List(P, L, U)
520537
}
521538

522539

@@ -532,10 +549,10 @@ class BlockMatrix @Since("1.3.0") (
532549
*/
533550
private [mllib] def shiftIndices(rowMin: Int, colMin: Int): RDD[((Int, Int), Matrix)] = {
534551
// This routine recovers the absolute indexing of the block matrices for reassembly
535-
val extractedSeq = this.blocks.map{ // shifting indices
552+
val extractedSeq = this.blocks.map { // shifting indices
536553
case(((x, y), matrix)) => ((x + rowMin, y + colMin), matrix)
537554
}
538-
return extractedSeq
555+
extractedSeq
539556
}
540557

541558
/** Computes the LU Decomposition of a Square Matrix. For a matrix A of size (n x n)
@@ -572,6 +589,40 @@ class BlockMatrix @Since("1.3.0") (
572589
// accessing the spark context
573590
val sc = this.blocks.sparkContext
574591

592+
/** LUSequences is a class that is defined to make the recursiveSequencesBuild section
593+
* more readable.
594+
*
595+
* These are passed as an RDD of blocks:
596+
* @param p the permutation matrix.
597+
* @param l the lower diagonal matrix.
598+
* @param u the upper diagonal matrix.
599+
* @param lInv the inverse lower diagonal matrix (only populating (i,i) cells).
600+
* @param uInv the inverse upper diagonal matrix (only populating (i,i) cells).
601+
* @param lDiag the lower diagonal matrices (only populating (i,i) cells).
602+
* @param uDiag the upper diagonal matrices (only populating (i,i) cells).
603+
* This is passed as a BlockMatrix
604+
* @param a the Schur Complement from the previous iteration, treated as the source matrix
605+
* for the next iteraton.
606+
*
607+
*
608+
@Since("1.6.0")
609+
*/
610+
class LUSequences(p: RDD[((Int, Int), Matrix)], l: RDD[((Int, Int), Matrix)],
611+
u: RDD[((Int, Int), Matrix)],
612+
lInv: RDD[((Int, Int), Matrix)], uInv: RDD[((Int, Int), Matrix)],
613+
lDiag: RDD[((Int, Int), Matrix)], uDiag: RDD[((Int, Int), Matrix)],
614+
a: BlockMatrix) {
615+
val P = p // the permutation matrix.
616+
val L = l // the lower diagonal matrix.
617+
val U = u // the upper diagonal matrix.
618+
val Li = lInv // the inverse lower diagonal matrix (only populating (i,i) cells).
619+
val Ui = uInv // the inverse upper diagonal matrix (only populating (i,i) cells).
620+
val LD = lDiag // the lower diagonal matrices (only populating (i,i) cells).
621+
val UD = uDiag // he upper diagonal matrices (only populating (i,i) cells).
622+
val A = a // the Schur Complement: BlockMatrix from the previous iteration, treated
623+
// as the source matrix for the next iteraton.
624+
}
625+
575626
/** Recursive Sequence Build is a nested recursion method that builds up all of the
576627
* sequences that are converted to BlockMatrix classes for large matrix
577628
* multiplication operations. The Schur Complement is calculated at each
@@ -586,32 +637,17 @@ class BlockMatrix @Since("1.3.0") (
586637
* the cascading Schur calculations, while for LD, (i, j<i) blocks are populated.
587638
*
588639
* @param rowI
589-
* @param prevTuple
640+
* @param prev
590641
* @return dP, dL, dU, dLi, dUi, LD, UD, S All are RDDs of Sequences that are
591642
* iteratively built, while S is a BlockMatrix used in the recursion loop
592643
* @since 1.6.0
593644
*/
594-
def recursiveSequencesBuild(rowI: Int, prevTuple:
595-
(RDD[((Int, Int), Matrix)], RDD[((Int, Int), Matrix)],
596-
RDD[((Int, Int), Matrix)], RDD[((Int, Int), Matrix)],
597-
RDD[((Int, Int), Matrix)], RDD[((Int, Int), Matrix)],
598-
RDD[((Int, Int), Matrix)],
599-
BlockMatrix)):
600-
(RDD[((Int, Int), Matrix)], RDD[((Int, Int), Matrix)],
601-
RDD[((Int, Int), Matrix)], RDD[((Int, Int), Matrix)],
602-
RDD[((Int, Int), Matrix)], RDD[((Int, Int), Matrix)],
603-
RDD[((Int, Int), Matrix)],
604-
BlockMatrix) = {
605-
val prevP = prevTuple._1;
606-
val prevL = prevTuple._2; val prevU = prevTuple._3
607-
val prevLi = prevTuple._4; val prevUi = prevTuple._5
608-
val prevLD = prevTuple._6; val prevUD = prevTuple._7
609-
val ABlock = prevTuple._8
610-
611-
val rowsRel = ABlock.numRowBlocks; val colsRel = ABlock.numColBlocks
645+
def recursiveSequencesBuild(rowI: Int, prev: LUSequences): LUSequences = {
646+
647+
val rowsRel = prev.A.numRowBlocks; val colsRel = prev.A.numColBlocks
612648
val topRangeRel = (0, 0); val botRangeRel = (1, rowsRel - 1)
613649
val topRangeAbs = (rowI, rowI); val botRangeAbs = (rowI + 1, rowsAbs - 1 )
614-
val PLU: List[BDM[Double]] = ABlock.singleBlockPLU;
650+
val PLU: List[BDM[Double]] = prev.A.singleBlockPLU;
615651
val PBrz = PLU(0); val LBrz = PLU(1); val UBrz = PLU(2)
616652

617653
val P = Matrices.dense(PBrz.rows, PBrz.cols, PBrz.toArray)
@@ -626,28 +662,26 @@ class BlockMatrix @Since("1.3.0") (
626662
val ZB = BDM.zeros[Double](LBrz.rows, LBrz.cols)
627663
val Z = Matrices.dense(LBrz.rows, LBrz.cols, ZB.toArray)
628664
val lastZ = sc.parallelize(Seq(((rowsAbs-1, colsAbs-1), Z)))
629-
val nextTuple = (nextP ++ prevP, nextL ++ prevL, nextU ++ prevU,
630-
lastZ ++ prevLi, lastZ ++ prevUi,
631-
lastZ ++ prevLD, lastZ ++ prevUD,
632-
ABlock)
633-
return nextTuple
665+
val nextTuple = new LUSequences(nextP ++ prev.P, nextL ++ prev.L, nextU ++ prev.U,
666+
lastZ ++ prev.Li, lastZ ++ prev.Ui,
667+
lastZ ++ prev.LD, lastZ ++ prev.UD, prev.A)
668+
nextTuple
634669
}
635670
else { // recursion block
636-
val SBlock = ABlock.SchurComplement
637671
val Li = Matrices.dense(LBrz.rows, LBrz.cols, inv(LBrz).toArray)
638672
val Ui = Matrices.dense(UBrz.rows, UBrz.cols, inv(UBrz).toArray)
639673
val nextLi = sc.parallelize(Seq(((rowI, rowI), Li)))
640674
val nextUi = sc.parallelize(Seq(((rowI, rowI), Ui)))
641675

642-
val nextLD = ABlock.subBlock(botRangeRel, topRangeRel).
676+
val nextLD = prev.A.subBlock(botRangeRel, topRangeRel).
643677
shiftIndices(botRangeAbs._1, topRangeAbs._1)
644-
val nextUD = ABlock.subBlock(topRangeRel, botRangeRel).
678+
val nextUD = prev.A.subBlock(topRangeRel, botRangeRel).
645679
shiftIndices(topRangeAbs._1, botRangeAbs._1)
646680

647-
val nextTuple = (nextP ++ prevP, nextL ++ prevL, nextU ++ prevU,
648-
nextLi ++ prevLi, nextUi ++ prevUi,
649-
prevLD ++ nextLD, prevUD ++ nextUD, SBlock)
650-
return recursiveSequencesBuild(rowI + 1, nextTuple)
681+
val nextTuple = new LUSequences(nextP ++ prev.P, nextL ++ prev.L, nextU ++ prev.U,
682+
nextLi ++ prev.Li, nextUi ++ prev.Ui,
683+
prev.LD ++ nextLD, prev.UD ++ nextUD, prev.A.SchurComplement)
684+
recursiveSequencesBuild(rowI + 1, nextTuple)
651685
}
652686
}
653687

@@ -678,21 +712,21 @@ class BlockMatrix @Since("1.3.0") (
678712
val nextUD = this.subBlock(topRange, botRange).
679713
shiftIndices(topRange._1, botRange._1)
680714

681-
val nextTuple = (nextP, nextL, nextU, nextLi, nextUi,
715+
val nextTuple = new LUSequences(nextP, nextL, nextU, nextLi, nextUi,
682716
firstZ ++ nextLD, firstZ ++ nextUD,
683717
this.SchurComplement)
684718

685719
// call to recursive build after initialization step
686-
val allSequences = recursiveSequencesBuild(1, nextTuple)
720+
val lastSequences = recursiveSequencesBuild(1, nextTuple)
687721
val rowsPerBlock = this.rowsPerBlock;
688722
val colsPerBlock = this.colsPerBlock
689-
val dP = new BlockMatrix(allSequences._1, rowsPerBlock, colsPerBlock)
690-
val dL = new BlockMatrix(allSequences._2, rowsPerBlock, colsPerBlock)
691-
val dU = new BlockMatrix(allSequences._3, rowsPerBlock, colsPerBlock)
692-
val dLi = new BlockMatrix(allSequences._4, rowsPerBlock, colsPerBlock)
693-
val dUi = new BlockMatrix(allSequences._5, rowsPerBlock, colsPerBlock)
694-
val LD = new BlockMatrix(allSequences._6, rowsPerBlock, colsPerBlock)
695-
val UD = new BlockMatrix(allSequences._7, rowsPerBlock, colsPerBlock)
723+
val dP = new BlockMatrix(lastSequences.P, rowsPerBlock, colsPerBlock)
724+
val dL = new BlockMatrix(lastSequences.L, rowsPerBlock, colsPerBlock)
725+
val dU = new BlockMatrix(lastSequences.U, rowsPerBlock, colsPerBlock)
726+
val dLi = new BlockMatrix(lastSequences.Li, rowsPerBlock, colsPerBlock)
727+
val dUi = new BlockMatrix(lastSequences.Ui, rowsPerBlock, colsPerBlock)
728+
val LD = new BlockMatrix(lastSequences.LD, rowsPerBlock, colsPerBlock)
729+
val UD = new BlockMatrix(lastSequences.UD, rowsPerBlock, colsPerBlock)
696730

697731
// Large Matrix Multiplication Operations
698732
// dL and dU are the sets of L and U Matrices along the diagonal blocks,
@@ -709,7 +743,7 @@ class BlockMatrix @Since("1.3.0") (
709743
// U = ( d[Linv] * dP * UD + dU )
710744
val UFin = dLi.multiply(dP.multiply(UD)).add(dU)
711745
// val UFin = dLi.multiply(UD).add(dU)
712-
return (PFin, LFin, UFin, dLi, dUi)
746+
(PFin, LFin, UFin, dLi, dUi)
713747
}
714748

715749

@@ -728,7 +762,7 @@ class BlockMatrix @Since("1.3.0") (
728762
val P = PLU._1
729763
val L = PLU._2
730764
val U = PLU._3
731-
return (P, L, U)
765+
(P, L, U)
732766
}
733767

734768
/** For the matrix Equation AX=B, where A is NxN blocks, and X, B are matrices of
@@ -786,8 +820,8 @@ class BlockMatrix @Since("1.3.0") (
786820
val currentY = new BlockMatrix(prevY.blocks ++ nextY.blocks,
787821
this.rowsPerBlock, this.colsPerBlock)
788822

789-
if (m == N - 1){return currentY} // terminal case
790-
else { return recursiveYBuild(m + 1, currentY)} // recursive case
823+
if (m == N - 1){currentY} // terminal case
824+
else {recursiveYBuild(m + 1, currentY)} // recursive case
791825
}
792826

793827
// Solving LY = PB for Y using (see docs):
@@ -813,8 +847,8 @@ class BlockMatrix @Since("1.3.0") (
813847
this.rowsPerBlock, this.colsPerBlock)
814848

815849

816-
if (mRev == 0 ){return currentX} // terminal case
817-
else { return recursiveXBuild(mRev - 1, currentX)} // recursive case
850+
if (mRev == 0 ){currentX} // terminal case
851+
else {recursiveXBuild(mRev - 1, currentX)} // recursive case
818852
}
819853

820854
// Solving UX = Y for X
@@ -833,7 +867,7 @@ class BlockMatrix @Since("1.3.0") (
833867
this.rowsPerBlock, this.colsPerBlock)
834868

835869
val X = recursiveXBuild(mRev-1, firstX)
836-
return X
870+
X
837871

838872
}
839873
}

0 commit comments

Comments
 (0)