Skip to content

Commit 1c38d42

Browse files
rotationsymmetrymengxr
authored andcommitted
[SPARK-9175] [MLLIB] BLAS.gemm fails to update matrix C when alpha==0 and beta!=1
Fix BLAS.gemm to update matrix C when alpha==0 and beta!=1 Also include unit tests to verify the fix. mengxr brkyvz Author: Meihua Wu <[email protected]> Closes #7503 from rotationsymmetry/fix_BLAS_gemm and squashes the following commits: fce199c [Meihua Wu] Fix BLAS.gemm to update C when alpha==0 and beta!=1 (cherry picked from commit ff3c72d) Signed-off-by: Xiangrui Meng <[email protected]>
1 parent 429eedd commit 1c38d42

File tree

2 files changed

+17
-2
lines changed

2 files changed

+17
-2
lines changed

mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,8 @@ private[spark] object BLAS extends Serializable with Logging {
247247
B: DenseMatrix,
248248
beta: Double,
249249
C: DenseMatrix): Unit = {
250-
if (alpha == 0.0) {
251-
logDebug("gemm: alpha is equal to 0. Returning C.")
250+
if (alpha == 0.0 && beta == 1.0) {
251+
logDebug("gemm: alpha is equal to 0 and beta is equal to 1. Returning C.")
252252
} else {
253253
A match {
254254
case sparse: SparseMatrix =>

mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,14 @@ class BLASSuite extends FunSuite {
147147
val C6 = C1.copy
148148
val C7 = C1.copy
149149
val C8 = C1.copy
150+
val C13 = C1.copy
151+
val C14 = C1.copy
152+
val C15 = C1.copy
153+
val C16 = C1.copy
150154
val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0))
151155
val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0))
156+
val expected4 = new DenseMatrix(4, 2, Array(5.0, 0.0, 10.0, 5.0, 0.0, 0.0, 5.0, 0.0))
157+
val expected5 = C1.copy
152158

153159
gemm(1.0, dA, B, 2.0, C1)
154160
gemm(1.0, sA, B, 2.0, C2)
@@ -181,6 +187,15 @@ class BLASSuite extends FunSuite {
181187
assert(C6 ~== expected2 absTol 1e-15)
182188
assert(C7 ~== expected3 absTol 1e-15)
183189
assert(C8 ~== expected3 absTol 1e-15)
190+
191+
gemm(0, dA, B, 5, C13)
192+
gemm(0, sA, B, 5, C14)
193+
gemm(0, dA, B, 1, C15)
194+
gemm(0, sA, B, 1, C16)
195+
assert(C13 ~== expected4 absTol 1e-15)
196+
assert(C14 ~== expected4 absTol 1e-15)
197+
assert(C15 ~== expected5 absTol 1e-15)
198+
assert(C16 ~== expected5 absTol 1e-15)
184199
}
185200

186201
test("gemv") {

0 commit comments

Comments
 (0)