Commit baeadd0

Move PQ outside of foreach and update comments

Author: Nick Pentreath
1 parent 29d6777 commit baeadd0

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 6 additions & 5 deletions
@@ -363,15 +363,12 @@ class ALSModel private[ml] (
    * relatively efficient, the approach implemented here is significantly more efficient.
    *
    * This approach groups factors into blocks and computes the top-k elements per block,
-   * using Level 1 BLAS (dot) and an efficient BoundedPriorityQueue. It then computes the
+   * using Level 1 BLAS (dot) and an efficient [[BoundedPriorityQueue]]. It then computes the
    * global top-k by aggregating the per block top-k elements with a [[TopByKeyAggregator]].
    * This significantly reduces the size of intermediate and shuffle data.
    * This is the DataFrame equivalent to the approach used in
    * [[org.apache.spark.mllib.recommendation.MatrixFactorizationModel]].
    *
-   * Compared with BLAS.dot, the hand-written version used below is more efficient than a call
-   * to the native BLAS backend and the same performance as the fallback F2jBLAS backend.
-   *
    * @param srcFactors src factors for which to generate recommendations
    * @param dstFactors dst factors used to make recommendations
    * @param srcOutputColumn name of the column for the source ID in the output DataFrame
@@ -397,12 +394,15 @@ class ALSModel private[ml] (
         val n = math.min(dstIter.size, num)
         val output = new Array[(Int, Int, Float)](m * n)
         var j = 0
+        val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2))
         srcIter.foreach { case (srcId, srcFactor) =>
-          val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2))
           dstIter.foreach { case (dstId, dstFactor) =>
             /**
              * The below code is equivalent to
              * val score = blas.sdot(rank, srcFactor, 1, dstFactor, 1)
+             * Compared with BLAS.dot, the hand-written version used below is more efficient than
+             * a call to the native BLAS backend and the same performance as the fallback
+             * F2jBLAS backend.
              */
             var score = 0.0f
             var k = 0
@@ -420,6 +420,7 @@ class ALSModel private[ml] (
             i += 1
           }
           j += n
+          pq.clear()
         }
         output.toSeq
       }
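
Note on the change: the commit allocates the queue once per block and calls `pq.clear()` after each source row, rather than constructing a new queue inside the outer `foreach`. The sketch below illustrates that reuse pattern in isolation. It is not the Spark implementation: `BoundedPriorityQueue` is Spark-internal, so a plain `scala.collection.mutable.PriorityQueue` capped by hand stands in for it, and the names `TopKSketch`, `dot`, `topK` and `byScoreAscending` are hypothetical.

```scala
import scala.collection.mutable

// Hypothetical stand-alone sketch; not part of ALS.scala.
object TopKSketch {

  // Explicit ordering on the score, so no implicit Ordering[Float] is needed.
  val byScoreAscending: Ordering[(Int, Float)] = new Ordering[(Int, Float)] {
    def compare(x: (Int, Float), y: (Int, Float)): Int =
      java.lang.Float.compare(x._2, y._2)
  }

  // Hand-written dot product, in the style of the loop shown in the diff.
  def dot(a: Array[Float], b: Array[Float]): Float = {
    var score = 0.0f
    var k = 0
    while (k < a.length) {
      score += a(k) * b(k)
      k += 1
    }
    score
  }

  // For every source row, keep the `num` highest-scoring destination rows.
  def topK(
      src: Seq[(Int, Array[Float])],
      dst: Seq[(Int, Array[Float])],
      num: Int): Seq[(Int, Int, Float)] = {
    // Reversed ordering turns the queue into a min-heap: head is the lowest kept score.
    // The queue is created once, outside the loop over source rows.
    val pq = mutable.PriorityQueue.empty[(Int, Float)](byScoreAscending.reverse)
    val out = mutable.ArrayBuffer.empty[(Int, Int, Float)]
    src.foreach { case (srcId, srcFactor) =>
      dst.foreach { case (dstId, dstFactor) =>
        val score = dot(srcFactor, dstFactor)
        if (pq.size < num) {
          pq.enqueue((dstId, score))
        } else if (pq.head._2 < score) {
          // Evict the current lowest score to make room for a better one.
          pq.dequeue()
          pq.enqueue((dstId, score))
        }
      }
      // Emit this row's results, highest score first.
      pq.toSeq.sorted(byScoreAscending.reverse).foreach { case (dstId, score) =>
        out += ((srcId, dstId, score))
      }
      // Reset the shared queue so the next source row starts empty.
      pq.clear()
    }
    out.toSeq
  }
}
```

In this sketch, omitting the final `pq.clear()` would let entries from one source row leak into the next row's results, which is why the reset has to accompany moving the allocation out of the loop.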
