From 276591c25418cb0766fa82853fa361d113b02e2c Mon Sep 17 00:00:00 2001 From: Bjarne Fruergaard Date: Thu, 29 Sep 2016 11:32:08 +0200 Subject: [PATCH 1/4] fix gemv --- .../main/scala/org/apache/spark/ml/linalg/BLAS.scala | 10 ++++++++-- .../scala/org/apache/spark/mllib/linalg/BLAS.scala | 10 ++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala index 41b0c6c89a647..00fc5a5b2e418 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala @@ -638,12 +638,18 @@ private[spark] object BLAS extends Serializable { val indEnd = Arows(rowCounter + 1) var sum = 0.0 var k = 0 - while (k < xNnz && i < indEnd) { + while (i < indEnd && k < xNnz) { if (xIndices(k) == Acols(i)) { sum += Avals(i) * xValues(k) + k += 1 + i += 1 + } + else if (xIndices(k) < Acols(i)) { + k += 1 + } + else { i += 1 } - k += 1 } yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter) rowCounter += 1 diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 6a85608706974..314f2c00f5db6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -637,12 +637,18 @@ private[spark] object BLAS extends Serializable with Logging { val indEnd = Arows(rowCounter + 1) var sum = 0.0 var k = 0 - while (k < xNnz && i < indEnd) { + while (i < indEnd && k < xNnz) { if (xIndices(k) == Acols(i)) { sum += Avals(i) * xValues(k) + k += 1 + i += 1 + } + else if (xIndices(k) < Acols(i)) { + k += 1 + } + else { i += 1 } - k += 1 } yValues(rowCounter) = sum * alpha + beta * yValues(rowCounter) rowCounter += 1 From 60bc8a4c2d5fa5feb7bf8bf15b483c37a410b017 Mon Sep 17 00:00:00 2001 From: Bjarne Fruergaard Date: Thu, 29 Sep 2016 11:32:35 +0200 Subject: [PATCH 2/4] additional tests for gemv --- .../scala/org/apache/spark/ml/linalg/BLASSuite.scala | 11 +++++++++++ .../org/apache/spark/mllib/linalg/BLASSuite.scala | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala index 8a9f49792c1cd..efd7fd2af1749 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala @@ -392,6 +392,17 @@ class BLASSuite extends SparkMLFunSuite { } } + val y17 = new DenseVector(Array(0.0, 0.0)) + val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0)) + .transpose + val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0)) + + val expected4 = new DenseVector(Array(5.0, 4.0)) + + gemv(1.0, sA3, sx3, 0.0, y17) + + assert(y17 ~== expected4 absTol 1e-15) + val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 80da03cc2efeb..6a0be51811ea1 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -392,6 +392,17 @@ class BLASSuite extends SparkFunSuite { } } + val y17 = new DenseVector(Array(0.0, 0.0)) + val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0)) + .transpose + val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0)) + + val expected4 = new DenseVector(Array(5.0, 4.0)) + + gemv(1.0, sA3, sx3, 0.0, y17) + + assert(y17 ~== expected4 absTol 1e-15) + val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = From 334a830f70b9cce3b022356be66ad04096b65ace Mon Sep 17 00:00:00 2001 From: Bjarne Fruergaard Date: Thu, 29 Sep 2016 14:25:24 +0200 Subject: [PATCH 3/4] style --- .../src/main/scala/org/apache/spark/ml/linalg/BLAS.scala | 6 ++---- .../src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala index 00fc5a5b2e418..4ca19f3387f07 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala +++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/BLAS.scala @@ -643,11 +643,9 @@ private[spark] object BLAS extends Serializable { sum += Avals(i) * xValues(k) k += 1 i += 1 - } - else if (xIndices(k) < Acols(i)) { + } else if (xIndices(k) < Acols(i)) { k += 1 - } - else { + } else { i += 1 } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 314f2c00f5db6..0cd68a633c0b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -642,11 +642,9 @@ private[spark] object BLAS extends Serializable with Logging { sum += Avals(i) * xValues(k) k += 1 i += 1 - } - else if (xIndices(k) < Acols(i)) { + } else if (xIndices(k) < Acols(i)) { k += 1 - } - else { + } else { i += 1 } } From cdc36422f6e73dbe239c9d7ed568bbb138fae266 Mon Sep 17 00:00:00 2001 From: Bjarne Fruergaard Date: Thu, 29 Sep 2016 14:41:38 +0200 Subject: [PATCH 4/4] adds tests validating the (working) non-transpose case --- .../test/scala/org/apache/spark/ml/linalg/BLASSuite.scala | 6 ++++++ .../scala/org/apache/spark/mllib/linalg/BLASSuite.scala | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala index efd7fd2af1749..6e72a5fff0a91 100644 --- a/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala +++ b/mllib-local/src/test/scala/org/apache/spark/ml/linalg/BLASSuite.scala @@ -393,15 +393,21 @@ class BLASSuite extends SparkMLFunSuite { } val y17 = new DenseVector(Array(0.0, 0.0)) + val y18 = y17.copy + val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0)) .transpose + val sA4 = + new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0)) val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0)) val expected4 = new DenseVector(Array(5.0, 4.0)) gemv(1.0, sA3, sx3, 0.0, y17) + gemv(1.0, sA4, sx3, 0.0, y18) assert(y17 ~== expected4 absTol 1e-15) + assert(y18 ~== expected4 absTol 1e-15) val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 6a0be51811ea1..6e68c1c9d36c8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -393,15 +393,21 @@ class BLASSuite extends SparkFunSuite { } val y17 = new DenseVector(Array(0.0, 0.0)) + val y18 = y17.copy + val sA3 = new SparseMatrix(3, 2, Array(0, 2, 4), Array(1, 2, 0, 1), Array(2.0, 1.0, 1.0, 2.0)) .transpose + val sA4 = + new SparseMatrix(2, 3, Array(0, 1, 3, 4), Array(1, 0, 1, 0), Array(1.0, 2.0, 2.0, 1.0)) val sx3 = new SparseVector(3, Array(1, 2), Array(2.0, 1.0)) val expected4 = new DenseVector(Array(5.0, 4.0)) gemv(1.0, sA3, sx3, 0.0, y17) + gemv(1.0, sA4, sx3, 0.0, y18) assert(y17 ~== expected4 absTol 1e-15) + assert(y18 ~== expected4 absTol 1e-15) val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0))