Skip to content

Commit b645968

Browse files
committed
use gemv
1 parent b923b56 commit b645968

File tree

1 file changed

+7
-7
lines changed
  • mllib/src/main/scala/org/apache/spark/ml/recommendation

1 file changed

+7
-7
lines changed

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -469,18 +469,18 @@ class ALSModel private[ml] (
469469
require(dstMat.length == dstIds.length * rank)
470470
val m = srcIds.length
471471
val n = dstIds.length
472-
if (buffer == null || buffer.length < m * n) {
473-
buffer = Array.ofDim[Float](m * n)
472+
if (buffer == null || buffer.length < n) {
473+
buffer = Array.ofDim[Float](n)
474474
}
475475

476-
BLAS.f2jBLAS.sgemm("T", "N", m, n, rank, 1.0F,
477-
srcMat, rank, dstMat, rank, 0.0F, buffer, m)
478-
479476
Iterator.range(0, m).flatMap { i =>
480-
val srcId = srcIds(i)
477+
BLAS.f2jBLAS.sgemv("T", rank, n, 1.0F, dstMat, 0, rank,
478+
srcMat, i * rank, 1, 0.0F, buffer, 0, 1)
479+
481480
pq.clear()
482481
var j = 0
483-
while (j < n) { pq += dstIds(j) -> buffer(i + j * m); j += 1 }
482+
while (j < n) { pq += dstIds(j) -> buffer(j); j += 1 }
483+
val srcId = srcIds(i)
484484
pq.iterator.map { case (dstId, value) => (srcId, dstId, value) }
485485
}
486486
} ++ {

0 commit comments

Comments
 (0)