@@ -52,17 +52,17 @@ private[mllib] class GridPartitioner(
5252 * Returns the index of the partition the input coordinate belongs to.
5353 *
5454 * @param key The coordinate (i, j) or a tuple (i, j, k), where k is the inner index used in
55- * multiplication.
55+ * multiplication. k is ignored in computing partitions.
5656 * @return The index of the partition, which the coordinate belongs to.
5757 */
5858 override def getPartition (key : Any ): Int = {
5959 key match {
6060 case (i : Int , j : Int ) =>
6161 getPartitionId(i, j)
62- case (i : Int , j : Int , _) =>
62+ case (i : Int , j : Int , _ : Int ) =>
6363 getPartitionId(i, j)
6464 case _ =>
65- throw new IllegalArgumentException (s " Unrecognized key: $key" )
65+ throw new IllegalArgumentException (s " Unrecognized key: $key. " )
6666 }
6767 }
6868
@@ -73,7 +73,6 @@ private[mllib] class GridPartitioner(
7373 i / rowsPerPart + j / colsPerPart * rowPartitions
7474 }
7575
76- /** Checks whether the partitioners have the same characteristics */
7776 override def equals (obj : Any ): Boolean = {
7877 obj match {
7978 case r : GridPartitioner =>
@@ -87,10 +86,12 @@ private[mllib] class GridPartitioner(
8786
8887private [mllib] object GridPartitioner {
8988
89+ /** Creates a new [[GridPartitioner]] instance. */
9090 def apply (rows : Int , cols : Int , rowsPerPart : Int , colsPerPart : Int ): GridPartitioner = {
9191 new GridPartitioner (rows, cols, rowsPerPart, colsPerPart)
9292 }
9393
94+ /** Creates a new [[GridPartitioner]] instance with the input suggested number of partitions. */
9495 def apply (rows : Int , cols : Int , suggestedNumPartitions : Int ): GridPartitioner = {
9596 require(suggestedNumPartitions > 0 )
9697 val scale = 1.0 / math.sqrt(suggestedNumPartitions)
@@ -103,24 +104,25 @@ private[mllib] object GridPartitioner {
103104/**
104105 * Represents a distributed matrix in blocks of local matrices.
105106 *
106- * @param rdd The RDD of SubMatrices (local matrices) that form this matrix
107- * @param nRows Number of rows of this matrix. If the supplied value is less than or equal to zero,
108- * the number of rows will be calculated when `numRows` is invoked.
109- * @param nCols Number of columns of this matrix. If the supplied value is less than or equal to
110- * zero, the number of columns will be calculated when `numCols` is invoked.
107+ * @param blocks The RDD of sub-matrix blocks (blockRowIndex, blockColIndex, sub-matrix) that form
108+ * this distributed matrix.
111109 * @param rowsPerBlock Number of rows that make up each block. The blocks forming the final
112110 * rows are not required to have the given number of rows
113111 * @param colsPerBlock Number of columns that make up each block. The blocks forming the final
114112 * columns are not required to have the given number of columns
113+ * @param nRows Number of rows of this matrix. If the supplied value is less than or equal to zero,
114+ * the number of rows will be calculated when `numRows` is invoked.
115+ * @param nCols Number of columns of this matrix. If the supplied value is less than or equal to
116+ * zero, the number of columns will be calculated when `numCols` is invoked.
115117 */
116118class BlockMatrix (
117- val rdd : RDD [((Int , Int ), Matrix )],
118- private var nRows : Long ,
119- private var nCols : Long ,
119+ val blocks : RDD [((Int , Int ), Matrix )],
120120 val rowsPerBlock : Int ,
121- val colsPerBlock : Int ) extends DistributedMatrix with Logging {
121+ val colsPerBlock : Int ,
122+ private var nRows : Long ,
123+ private var nCols : Long ) extends DistributedMatrix with Logging {
122124
123- private type SubMatrix = ((Int , Int ), Matrix ) // ((blockRowIndex, blockColIndex), matrix)
125+ private type MatrixBlock = ((Int , Int ), Matrix ) // ((blockRowIndex, blockColIndex), sub- matrix)
124126
125127 /**
126128 * Alternate constructor for BlockMatrix without the input of the number of rows and columns.
@@ -135,45 +137,48 @@ class BlockMatrix(
135137 rdd : RDD [((Int , Int ), Matrix )],
136138 rowsPerBlock : Int ,
137139 colsPerBlock : Int ) = {
138- this (rdd, 0L , 0L , rowsPerBlock, colsPerBlock )
140+ this (rdd, rowsPerBlock, colsPerBlock, 0L , 0L )
139141 }
140142
141- private lazy val dims : (Long , Long ) = getDim
142-
143143 override def numRows (): Long = {
144- if (nRows <= 0L ) nRows = dims._1
144+ if (nRows <= 0L ) estimateDim()
145145 nRows
146146 }
147147
148148 override def numCols (): Long = {
149- if (nCols <= 0L ) nCols = dims._2
149+ if (nCols <= 0L ) estimateDim()
150150 nCols
151151 }
152152
153153 val numRowBlocks = math.ceil(numRows() * 1.0 / rowsPerBlock).toInt
154154 val numColBlocks = math.ceil(numCols() * 1.0 / colsPerBlock).toInt
155155
156156 private [mllib] var partitioner : GridPartitioner =
157- GridPartitioner (numRowBlocks, numColBlocks, suggestedNumPartitions = rdd.partitions.size)
158-
159- /** Returns the dimensions of the matrix. */
160- private def getDim : (Long , Long ) = {
161- val (rows, cols) = rdd.map { case ((blockRowIndex, blockColIndex), mat) =>
162- (blockRowIndex * rowsPerBlock + mat.numRows, blockColIndex * colsPerBlock + mat.numCols)
163- }.reduce((x0, x1) => (math.max(x0._1, x1._1), math.max(x0._2, x1._2)))
164-
165- (math.max(rows, nRows), math.max(cols, nCols))
157+ GridPartitioner (numRowBlocks, numColBlocks, suggestedNumPartitions = blocks.partitions.size)
158+
159+ /** Estimates the dimensions of the matrix. */
160+ private def estimateDim (): Unit = {
161+ val (rows, cols) = blocks.map { case ((blockRowIndex, blockColIndex), mat) =>
162+ (blockRowIndex.toLong * rowsPerBlock + mat.numRows,
163+ blockColIndex.toLong * colsPerBlock + mat.numCols)
164+ }.reduce { (x0, x1) =>
165+ (math.max(x0._1, x1._1), math.max(x0._2, x1._2))
166+ }
167+ if (nRows <= 0L ) nRows = rows
168+ assert(rows <= nRows, s " The number of rows $rows is more than claimed $nRows. " )
169+ if (nCols <= 0L ) nCols = cols
170+ assert(cols <= nCols, s " The number of columns $cols is more than claimed $nCols. " )
166171 }
167172
168- /** Cache the underlying RDD. */
169- def cache (): BlockMatrix = {
170- rdd .cache()
173+ /** Caches the underlying RDD. */
174+ def cache (): this . type = {
175+ blocks .cache()
171176 this
172177 }
173178
174- /** Set the storage level for the underlying RDD. */
175- def persist (storageLevel : StorageLevel ): BlockMatrix = {
176- rdd .persist(storageLevel)
179+ /** Persists the underlying RDD with the specified storage level. */
180+ def persist (storageLevel : StorageLevel ): this . type = {
181+ blocks .persist(storageLevel)
177182 this
178183 }
179184
@@ -185,22 +190,22 @@ class BlockMatrix(
185190 s " Int.MaxValue. Currently numCols: ${numCols()}" )
186191 require(numRows() * numCols() < Int .MaxValue , " The length of the values array must be " +
187192 s " less than Int.MaxValue. Currently numRows * numCols: ${numRows() * numCols()}" )
188- val nRows = numRows().toInt
189- val nCols = numCols().toInt
190- val mem = nRows * nCols / 125000
193+ val m = numRows().toInt
194+ val n = numCols().toInt
195+ val mem = m * n / 125000
191196 if (mem > 500 ) logWarning(s " Storing this matrix will require $mem MB of memory! " )
192197
193- val parts = rdd .collect()
194- val values = new Array [Double ](nRows * nCols )
195- parts .foreach { case ((blockRowIndex, blockColIndex), block ) =>
198+ val localBlocks = blocks .collect()
199+ val values = new Array [Double ](m * n )
200+ localBlocks .foreach { case ((blockRowIndex, blockColIndex), submat ) =>
196201 val rowOffset = blockRowIndex * rowsPerBlock
197202 val colOffset = blockColIndex * colsPerBlock
198- block .foreachActive { (i, j, v) =>
199- val indexOffset = (j + colOffset) * nRows + rowOffset + i
203+ submat .foreachActive { (i, j, v) =>
204+ val indexOffset = (j + colOffset) * m + rowOffset + i
200205 values(indexOffset) = v
201206 }
202207 }
203- new DenseMatrix (nRows, nCols , values)
208+ new DenseMatrix (m, n , values)
204209 }
205210
206211 /** Collects data and assembles a local dense breeze matrix (for test only). */
0 commit comments