@@ -463,7 +463,7 @@ class BlockMatrix @Since("1.3.0") (
463463 val a11Inv = new BlockMatrix (a11RDD, this .rowsPerBlock, this .colsPerBlock)
464464
465465 val S = a22.subtract(a21.multiply(a11Inv.multiply(a12)))
466- return S
466+ S
467467 }
468468
469469 /** Returns a rectangular (sub)BlockMatrix with block ranges as specified. Block Ranges
@@ -487,10 +487,10 @@ class BlockMatrix @Since("1.3.0") (
487487 val colMin = blockColRange._1 ; val colMax = blockColRange._2
488488 val extractedSeq = this .blocks.filter{ case ((x, y), matrix) =>
489489 x >= rowMin && x<= rowMax && // finding blocks
490- y >= colMin && y<= colMax }.map{ // shifting indices
490+ y >= colMin && y<= colMax }.map { // shifting indices
491491 case (((x, y), matrix) ) => ((x- rowMin, y- colMin), matrix)
492492 }
493- return new BlockMatrix (extractedSeq, rowsPerBlock, colsPerBlock)
493+ new BlockMatrix (extractedSeq, rowsPerBlock, colsPerBlock)
494494 }
495495
496496 /** computes the LU decomposition of a Single Block from BlockMatrix using the
@@ -507,16 +507,33 @@ class BlockMatrix @Since("1.3.0") (
507507 val L = lowerTriangular(PLU ._1) - diag(diag(PLU ._1)) + diag(DenseVector .fill(k){1.0 })
508508 val U = upperTriangular(PLU ._1);
509509 var P = diag(DenseVector .fill(k){1.0 })
510- val Pi = diag(DenseVector .fill(k){1.0 })
510+ var Pi = diag(DenseVector .fill(k){1.0 })
511511 // size of square matrix
512- for (i <- 0 to (k - 1 )) { // i test populating permutation matrix
513- val I = i match {case 0 => k - 1 case _ => i - 1 }
512+ // populating permutation matrix
513+ var i = 0
514+ while (i < k) {
515+ val I = {
516+ if (i == 0 ){k - 1 }
517+ else {i - 1 }
518+ }
514519 val J = PLU ._2(i) - 1
515- if (i != J ) { Pi (i, J ) += 1.0 ; Pi (J , i) += 1.0 ; Pi (i, i) -= 1.0 ; Pi (J , J ) -= 1.0 }
520+ if (i != J ) {
521+ Pi (i, J ) += 1.0
522+ Pi (J , i) += 1.0
523+ Pi (i, i) -= 1.0
524+ Pi (J , J ) -= 1.0
525+ }
516526 P = Pi * P // constructor Pi*P for PA=LU
517- if (i != J ) { Pi (i, J ) -= 1.0 ; Pi (J , i) -= 1.0 ; Pi (i, i) += 1.0 ; Pi (J , J ) += 1.0 }
527+ // resetting Pi for next iteration
528+ if (i != J ) {
529+ Pi (i, J ) -= 1.0
530+ Pi (J , i) -= 1.0
531+ Pi (i, i) += 1.0
532+ Pi (J , J ) += 1.0
533+ }
534+ i += 1
518535 }
519- return List (P , L , U )
536+ List (P , L , U )
520537 }
521538
522539
@@ -532,10 +549,10 @@ class BlockMatrix @Since("1.3.0") (
532549 */
533550 private [mllib] def shiftIndices (rowMin : Int , colMin : Int ): RDD [((Int , Int ), Matrix )] = {
534551 // This routine recovers the absolute indexing of the block matrices for reassembly
535- val extractedSeq = this .blocks.map{ // shifting indices
552+ val extractedSeq = this .blocks.map { // shifting indices
536553 case (((x, y), matrix)) => ((x + rowMin, y + colMin), matrix)
537554 }
538- return extractedSeq
555+ extractedSeq
539556 }
540557
541558 /** Computes the LU Decomposition of a Square Matrix. For a matrix A of size (n x n)
@@ -572,6 +589,40 @@ class BlockMatrix @Since("1.3.0") (
572589 // accessing the spark context
573590 val sc = this .blocks.sparkContext
574591
592+ /** LUSequences is a class that is defined to make the recursiveSequencesBuild section
593+ * more readable.
594+ *
595+ * These are passed as an RDD of blocks:
596+ * @param p the permutation matrix.
597+ * @param l the lower diagonal matrix.
598+ * @param u the upper diagonal matrix.
599+ * @param lInv the inverse lower diagonal matrix (only populating (i,i) cells).
600+ * @param uInv the inverse upper diagonal matrix (only populating (i,i) cells).
601+ * @param lDiag the lower diagonal matrices (only populating (i,i) cells).
602+ * @param uDiag the upper diagonal matrices (only populating (i,i) cells).
603+ * This is passed as a BlockMatrix
604+ * @param a the Schur Complement from the previous iteration, treated as the source matrix
605+ * for the next iteraton.
606+ *
607+ *
608+ @Since("1.6.0")
609+ */
610+ class LUSequences (p : RDD [((Int , Int ), Matrix )], l : RDD [((Int , Int ), Matrix )],
611+ u : RDD [((Int , Int ), Matrix )],
612+ lInv : RDD [((Int , Int ), Matrix )], uInv : RDD [((Int , Int ), Matrix )],
613+ lDiag : RDD [((Int , Int ), Matrix )], uDiag : RDD [((Int , Int ), Matrix )],
614+ a : BlockMatrix ) {
615+ val P = p // the permutation matrix.
616+ val L = l // the lower diagonal matrix.
617+ val U = u // the upper diagonal matrix.
618+ val Li = lInv // the inverse lower diagonal matrix (only populating (i,i) cells).
619+ val Ui = uInv // the inverse upper diagonal matrix (only populating (i,i) cells).
620+ val LD = lDiag // the lower diagonal matrices (only populating (i,i) cells).
621+ val UD = uDiag // he upper diagonal matrices (only populating (i,i) cells).
622+ val A = a // the Schur Complement: BlockMatrix from the previous iteration, treated
623+ // as the source matrix for the next iteraton.
624+ }
625+
575626 /** Recursive Sequence Build is a nested recursion method that builds up all of the
576627 * sequences that are converted to BlockMatrix classes for large matrix
577628 * multiplication operations. The Schur Complement is calculated at each
@@ -586,32 +637,17 @@ class BlockMatrix @Since("1.3.0") (
586637 * the cascading Schur calculations, while for LD, (i, j<i) blocks are populated.
587638 *
588639 * @param rowI
589- * @param prevTuple
640+ * @param prev
590641 * @return dP, dL, dU, dLi, dUi, LD, UD, S All are RDDs of Sequences that are
591642 * iteratively built, while S is a BlockMatrix used in the recursion loop
592643 * @since 1.6.0
593644 */
594- def recursiveSequencesBuild (rowI : Int , prevTuple :
595- (RDD [((Int , Int ), Matrix )], RDD [((Int , Int ), Matrix )],
596- RDD [((Int , Int ), Matrix )], RDD [((Int , Int ), Matrix )],
597- RDD [((Int , Int ), Matrix )], RDD [((Int , Int ), Matrix )],
598- RDD [((Int , Int ), Matrix )],
599- BlockMatrix )):
600- (RDD [((Int , Int ), Matrix )], RDD [((Int , Int ), Matrix )],
601- RDD [((Int , Int ), Matrix )], RDD [((Int , Int ), Matrix )],
602- RDD [((Int , Int ), Matrix )], RDD [((Int , Int ), Matrix )],
603- RDD [((Int , Int ), Matrix )],
604- BlockMatrix ) = {
605- val prevP = prevTuple._1;
606- val prevL = prevTuple._2; val prevU = prevTuple._3
607- val prevLi = prevTuple._4; val prevUi = prevTuple._5
608- val prevLD = prevTuple._6; val prevUD = prevTuple._7
609- val ABlock = prevTuple._8
610-
611- val rowsRel = ABlock .numRowBlocks; val colsRel = ABlock .numColBlocks
645+ def recursiveSequencesBuild (rowI : Int , prev : LUSequences ): LUSequences = {
646+
647+ val rowsRel = prev.A .numRowBlocks; val colsRel = prev.A .numColBlocks
612648 val topRangeRel = (0 , 0 ); val botRangeRel = (1 , rowsRel - 1 )
613649 val topRangeAbs = (rowI, rowI); val botRangeAbs = (rowI + 1 , rowsAbs - 1 )
614- val PLU : List [BDM [Double ]] = ABlock .singleBlockPLU;
650+ val PLU : List [BDM [Double ]] = prev. A .singleBlockPLU;
615651 val PBrz = PLU (0 ); val LBrz = PLU (1 ); val UBrz = PLU (2 )
616652
617653 val P = Matrices .dense(PBrz .rows, PBrz .cols, PBrz .toArray)
@@ -626,28 +662,26 @@ class BlockMatrix @Since("1.3.0") (
626662 val ZB = BDM .zeros[Double ](LBrz .rows, LBrz .cols)
627663 val Z = Matrices .dense(LBrz .rows, LBrz .cols, ZB .toArray)
628664 val lastZ = sc.parallelize(Seq (((rowsAbs- 1 , colsAbs- 1 ), Z )))
629- val nextTuple = (nextP ++ prevP, nextL ++ prevL, nextU ++ prevU,
630- lastZ ++ prevLi, lastZ ++ prevUi,
631- lastZ ++ prevLD, lastZ ++ prevUD,
632- ABlock )
633- return nextTuple
665+ val nextTuple = new LUSequences (nextP ++ prev.P , nextL ++ prev.L , nextU ++ prev.U ,
666+ lastZ ++ prev.Li , lastZ ++ prev.Ui ,
667+ lastZ ++ prev.LD , lastZ ++ prev.UD , prev.A )
668+ nextTuple
634669 }
635670 else { // recursion block
636- val SBlock = ABlock .SchurComplement
637671 val Li = Matrices .dense(LBrz .rows, LBrz .cols, inv(LBrz ).toArray)
638672 val Ui = Matrices .dense(UBrz .rows, UBrz .cols, inv(UBrz ).toArray)
639673 val nextLi = sc.parallelize(Seq (((rowI, rowI), Li )))
640674 val nextUi = sc.parallelize(Seq (((rowI, rowI), Ui )))
641675
642- val nextLD = ABlock .subBlock(botRangeRel, topRangeRel).
676+ val nextLD = prev. A .subBlock(botRangeRel, topRangeRel).
643677 shiftIndices(botRangeAbs._1, topRangeAbs._1)
644- val nextUD = ABlock .subBlock(topRangeRel, botRangeRel).
678+ val nextUD = prev. A .subBlock(topRangeRel, botRangeRel).
645679 shiftIndices(topRangeAbs._1, botRangeAbs._1)
646680
647- val nextTuple = (nextP ++ prevP , nextL ++ prevL , nextU ++ prevU ,
648- nextLi ++ prevLi , nextUi ++ prevUi ,
649- prevLD ++ nextLD, prevUD ++ nextUD, SBlock )
650- return recursiveSequencesBuild(rowI + 1 , nextTuple)
681+ val nextTuple = new LUSequences (nextP ++ prev. P , nextL ++ prev. L , nextU ++ prev. U ,
682+ nextLi ++ prev. Li , nextUi ++ prev. Ui ,
683+ prev. LD ++ nextLD, prev. UD ++ nextUD, prev. A . SchurComplement )
684+ recursiveSequencesBuild(rowI + 1 , nextTuple)
651685 }
652686 }
653687
@@ -678,21 +712,21 @@ class BlockMatrix @Since("1.3.0") (
678712 val nextUD = this .subBlock(topRange, botRange).
679713 shiftIndices(topRange._1, botRange._1)
680714
681- val nextTuple = (nextP, nextL, nextU, nextLi, nextUi,
715+ val nextTuple = new LUSequences (nextP, nextL, nextU, nextLi, nextUi,
682716 firstZ ++ nextLD, firstZ ++ nextUD,
683717 this .SchurComplement )
684718
685719 // call to recursive build after initialization step
686- val allSequences = recursiveSequencesBuild(1 , nextTuple)
720+ val lastSequences = recursiveSequencesBuild(1 , nextTuple)
687721 val rowsPerBlock = this .rowsPerBlock;
688722 val colsPerBlock = this .colsPerBlock
689- val dP = new BlockMatrix (allSequences._1 , rowsPerBlock, colsPerBlock)
690- val dL = new BlockMatrix (allSequences._2 , rowsPerBlock, colsPerBlock)
691- val dU = new BlockMatrix (allSequences._3 , rowsPerBlock, colsPerBlock)
692- val dLi = new BlockMatrix (allSequences._4 , rowsPerBlock, colsPerBlock)
693- val dUi = new BlockMatrix (allSequences._5 , rowsPerBlock, colsPerBlock)
694- val LD = new BlockMatrix (allSequences._6 , rowsPerBlock, colsPerBlock)
695- val UD = new BlockMatrix (allSequences._7 , rowsPerBlock, colsPerBlock)
723+ val dP = new BlockMatrix (lastSequences. P , rowsPerBlock, colsPerBlock)
724+ val dL = new BlockMatrix (lastSequences. L , rowsPerBlock, colsPerBlock)
725+ val dU = new BlockMatrix (lastSequences. U , rowsPerBlock, colsPerBlock)
726+ val dLi = new BlockMatrix (lastSequences. Li , rowsPerBlock, colsPerBlock)
727+ val dUi = new BlockMatrix (lastSequences. Ui , rowsPerBlock, colsPerBlock)
728+ val LD = new BlockMatrix (lastSequences. LD , rowsPerBlock, colsPerBlock)
729+ val UD = new BlockMatrix (lastSequences. UD , rowsPerBlock, colsPerBlock)
696730
697731 // Large Matrix Multiplication Operations
698732 // dL and dU are the sets of L and U Matrices along the diagonal blocks,
@@ -709,7 +743,7 @@ class BlockMatrix @Since("1.3.0") (
709743 // U = ( d[Linv] * dP * UD + dU )
710744 val UFin = dLi.multiply(dP.multiply(UD )).add(dU)
711745 // val UFin = dLi.multiply(UD).add(dU)
712- return (PFin , LFin , UFin , dLi, dUi)
746+ (PFin , LFin , UFin , dLi, dUi)
713747 }
714748
715749
@@ -728,7 +762,7 @@ class BlockMatrix @Since("1.3.0") (
728762 val P = PLU ._1
729763 val L = PLU ._2
730764 val U = PLU ._3
731- return (P , L , U )
765+ (P , L , U )
732766 }
733767
734768/** For the matrix Equation AX=B, where A is NxN blocks, and X, B are matrices of
@@ -786,8 +820,8 @@ class BlockMatrix @Since("1.3.0") (
786820 val currentY = new BlockMatrix (prevY.blocks ++ nextY.blocks,
787821 this .rowsPerBlock, this .colsPerBlock)
788822
789- if (m == N - 1 ){return currentY} // terminal case
790- else { return recursiveYBuild(m + 1 , currentY)} // recursive case
823+ if (m == N - 1 ){currentY} // terminal case
824+ else {recursiveYBuild(m + 1 , currentY)} // recursive case
791825 }
792826
793827 // Solving LY = PB for Y using (see docs):
@@ -813,8 +847,8 @@ class BlockMatrix @Since("1.3.0") (
813847 this .rowsPerBlock, this .colsPerBlock)
814848
815849
816- if (mRev == 0 ){return currentX} // terminal case
817- else { return recursiveXBuild(mRev - 1 , currentX)} // recursive case
850+ if (mRev == 0 ){currentX} // terminal case
851+ else {recursiveXBuild(mRev - 1 , currentX)} // recursive case
818852 }
819853
820854 // Solving UX = Y for X
@@ -833,7 +867,7 @@ class BlockMatrix @Since("1.3.0") (
833867 this .rowsPerBlock, this .colsPerBlock)
834868
835869 val X = recursiveXBuild(mRev- 1 , firstX)
836- return X
870+ X
837871
838872 }
839873}
0 commit comments