From 12dae7308c831d76a2e282ca7ee57e8718b0d3c8 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Thu, 8 Jan 2015 14:05:57 +0530 Subject: [PATCH 1/2] [SPARK-4406] FIX: Validate k in SVD --- .../mllib/linalg/distributed/IndexedRowMatrix.scala | 2 ++ .../linalg/distributed/IndexedRowMatrixSuite.scala | 10 ++++++++++ .../mllib/linalg/distributed/RowMatrixSuite.scala | 12 ++++++++++++ 3 files changed, 24 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 36d8cadd2bdd7..f53b3774e062d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -102,6 +102,8 @@ class IndexedRowMatrix( k: Int, computeU: Boolean = false, rCond: Double = 1e-9): SingularValueDecomposition[IndexedRowMatrix, Matrix] = { + + require(k >= 1, "k should be at least one.") val indices = rows.map(_.index) val svd = toRowMatrix().computeSVD(k, computeU, rCond) val U = if (computeU) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index e25bc02b06c9a..66853ad79435a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -113,6 +113,16 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(closeToZero(U * brzDiag(s) * V.t - localA)) } + test("validate k in svd") { + val A = new IndexedRowMatrix(indexedRows) + try { + A.computeSVD(-1) + } catch { + case ie: IllegalArgumentException => + } + } + + def closeToZero(G: BDM[Double]): Boolean = { G.valuesIterator.map(math.abs).sum < 1e-6 } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index dbf55ff81ca99..d294d54a5e661 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -171,6 +171,18 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { } } + test("validate k in svd") { + for (mat <- Seq(denseMat, sparseMat)) { + for (mode <- Seq("auto", "local-svd", "local-eigs", "dist-eigs")) { + try { + mat.computeSVD(-1, computeU = true, 1e-6, 300, 1e-10, mode) + } catch { + case ie: IllegalArgumentException => + } + } + } + } + def closeToZero(G: BDM[Double]): Boolean = { G.valuesIterator.map(math.abs).sum < 1e-6 } From 64e6d2d993dc30403fc07c547622b2d7a123c77e Mon Sep 17 00:00:00 2001 From: MechCoder Date: Fri, 9 Jan 2015 02:09:49 +0530 Subject: [PATCH 2/2] TST: Add better test errors and messages --- .../mllib/linalg/distributed/IndexedRowMatrix.scala | 3 ++- .../spark/mllib/linalg/distributed/RowMatrix.scala | 2 +- .../mllib/linalg/distributed/IndexedRowMatrixSuite.scala | 9 +++------ .../spark/mllib/linalg/distributed/RowMatrixSuite.scala | 8 ++------ 4 files changed, 8 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index f53b3774e062d..181f507516485 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -103,7 +103,8 @@ class IndexedRowMatrix( computeU: Boolean = false, rCond: Double = 1e-9): SingularValueDecomposition[IndexedRowMatrix, Matrix] = { - require(k >= 1, "k should be at least one.") + val n = numCols().toInt + require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.") val indices = rows.map(_.index) val svd = toRowMatrix().computeSVD(k, computeU, rCond) val U = if (computeU) { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index fbd35e372f9b1..d5abba6a4b645 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -212,7 +212,7 @@ class RowMatrix( tol: Double, mode: String): SingularValueDecomposition[RowMatrix, Matrix] = { val n = numCols().toInt - require(k > 0 && k <= n, s"Request up to n singular values but got k=$k and n=$n.") + require(k > 0 && k <= n, s"Requested k singular values but got k=$k and numCols=$n.") object SVDMode extends Enumeration { val LocalARPACK, LocalLAPACK, DistARPACK = Value diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 66853ad79435a..741cd4997b853 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -115,14 +115,11 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { test("validate k in svd") { val A = new IndexedRowMatrix(indexedRows) - try { - A.computeSVD(-1) - } catch { - case ie: IllegalArgumentException => - } + intercept[IllegalArgumentException] { + A.computeSVD(-1) + } } - def closeToZero(G: BDM[Double]): Boolean = { G.valuesIterator.map(math.abs).sum < 1e-6 } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index d294d54a5e661..3309713e91f87 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -173,12 +173,8 @@ class RowMatrixSuite extends FunSuite with MLlibTestSparkContext { test("validate k in svd") { for (mat <- Seq(denseMat, sparseMat)) { - for (mode <- Seq("auto", "local-svd", "local-eigs", "dist-eigs")) { - try { - mat.computeSVD(-1, computeU = true, 1e-6, 300, 1e-10, mode) - } catch { - case ie: IllegalArgumentException => - } + intercept[IllegalArgumentException] { + mat.computeSVD(-1) } } }