Skip to content

Commit e7dafd4

Browse files
lee19mengxr
authored andcommitted
[SPARK-8563] [MLLIB] Fixed a bug so that IndexedRowMatrix.computeSVD().U.numCols = k
I'm sorry that I made #6949 closed by mistake. I pushed codes again. And, I added a test code. > There is a bug that `U.numCols() = self.nCols` in `IndexedRowMatrix.computeSVD()` It should have been `U.numCols() = k = svd.U.numCols()` > ``` self = U * sigma * V.transpose (m x n) = (m x n) * (k x k) * (k x n) //ASIS --> (m x n) = (m x k) * (k x k) * (k x n) //TOBE ``` Author: lee19 <[email protected]> Closes #6953 from lee19/MLlibBugfix and squashes the following commits: c1812a0 [lee19] [SPARK-8563] [MLlib] Used nRows instead of numRows() to reduce a burden. 4b9803b [lee19] [SPARK-8563] [MLlib] Fixed a build error. c2ccd89 [lee19] Added a unit test that validates matrix sizes of svd for [SPARK-8563][MLlib] 8373424 [lee19] [SPARK-8563][MLlib] Fixed a bug so that IndexedRowMatrix.computeSVD().U.numCols = k (cherry picked from commit e725262) Signed-off-by: Xiangrui Meng <[email protected]>
1 parent 1b5439f commit e7dafd4

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ class IndexedRowMatrix(
146146
val indexedRows = indices.zip(svd.U.rows).map { case (i, v) =>
147147
IndexedRow(i, v)
148148
}
149-
new IndexedRowMatrix(indexedRows, nRows, nCols)
149+
new IndexedRowMatrix(indexedRows, nRows, svd.U.numCols().toInt)
150150
} else {
151151
null
152152
}

mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,17 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
135135
assert(closeToZero(U * brzDiag(s) * V.t - localA))
136136
}
137137

138+
test("validate matrix sizes of svd") {
139+
val k = 2
140+
val A = new IndexedRowMatrix(indexedRows)
141+
val svd = A.computeSVD(k, computeU = true)
142+
assert(svd.U.numRows() === m)
143+
assert(svd.U.numCols() === k)
144+
assert(svd.s.size === k)
145+
assert(svd.V.numRows === n)
146+
assert(svd.V.numCols === k)
147+
}
148+
138149
test("validate k in svd") {
139150
val A = new IndexedRowMatrix(indexedRows)
140151
intercept[IllegalArgumentException] {

0 commit comments

Comments
 (0)