@@ -363,15 +363,12 @@ class ALSModel private[ml] (
363363 * relatively efficient, the approach implemented here is significantly more efficient.
364364 *
365365 * This approach groups factors into blocks and computes the top-k elements per block,
366- * using Level 1 BLAS (dot) and an efficient BoundedPriorityQueue. It then computes the
366+ * using Level 1 BLAS (dot) and an efficient [[ BoundedPriorityQueue ]] . It then computes the
367367 * global top-k by aggregating the per block top-k elements with a [[TopByKeyAggregator ]].
368368 * This significantly reduces the size of intermediate and shuffle data.
369369 * This is the DataFrame equivalent to the approach used in
370370 * [[org.apache.spark.mllib.recommendation.MatrixFactorizationModel ]].
371371 *
372- * Compared with BLAS.dot, the hand-written version used below is more efficient than a call
373- * to the native BLAS backend and the same performance as the fallback F2jBLAS backend.
374- *
375372 * @param srcFactors src factors for which to generate recommendations
376373 * @param dstFactors dst factors used to make recommendations
377374 * @param srcOutputColumn name of the column for the source ID in the output DataFrame
@@ -397,12 +394,15 @@ class ALSModel private[ml] (
397394 val n = math.min(dstIter.size, num)
398395 val output = new Array [(Int , Int , Float )](m * n)
399396 var j = 0
397+ val pq = new BoundedPriorityQueue [(Int , Float )](num)(Ordering .by(_._2))
400398 srcIter.foreach { case (srcId, srcFactor) =>
401- val pq = new BoundedPriorityQueue [(Int , Float )](num)(Ordering .by(_._2))
402399 dstIter.foreach { case (dstId, dstFactor) =>
403400 /**
404401 * The below code is equivalent to
405402 * val score = blas.sdot(rank, srcFactor, 1, dstFactor, 1)
403+ * Compared with BLAS.dot, the hand-written version used below is more efficient than
404+ * a call to the native BLAS backend and the same performance as the fallback
405+ * F2jBLAS backend.
406406 */
407407 var score = 0.0f
408408 var k = 0
@@ -420,6 +420,7 @@ class ALSModel private[ml] (
420420 i += 1
421421 }
422422 j += n
423+ pq.clear()
423424 }
424425 output.toSeq
425426 }
0 commit comments