@@ -27,6 +27,7 @@ import scala.util.{Sorting, Try}
2727import scala .util .hashing .byteswap64
2828
2929import com .github .fommil .netlib .BLAS .{getInstance => blas }
30+ import com .google .common .collect .{Ordering => GuavaOrdering }
3031import org .apache .hadoop .fs .Path
3132import org .json4s .DefaultFormats
3233import org .json4s .JsonDSL ._
@@ -47,7 +48,7 @@ import org.apache.spark.sql.{DataFrame, Dataset}
4748import org .apache .spark .sql .functions ._
4849import org .apache .spark .sql .types ._
4950import org .apache .spark .storage .StorageLevel
50- import org .apache .spark .util .{ BoundedPriorityQueue , Utils }
51+ import org .apache .spark .util .Utils
5152import org .apache .spark .util .collection .{OpenHashMap , OpenHashSet , SortDataFormat , Sorter }
5253import org .apache .spark .util .random .XORShiftRandom
5354
@@ -456,37 +457,35 @@ class ALSModel private[ml] (
456457 num : Int ,
457458 blockSize : Int ): DataFrame = {
458459 import srcFactors .sparkSession .implicits ._
460+ import ALSModel .TopSelector
459461
460462 val srcFactorsBlocked = blockify(srcFactors.as[(Int , Array [Float ])], blockSize)
461463 val dstFactorsBlocked = blockify(dstFactors.as[(Int , Array [Float ])], blockSize)
462464 val partialRecs = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
463465 .as[(Array [Int ], Array [Float ], Array [Int ], Array [Float ])]
464466 .mapPartitions { iter =>
465467 var buffer : Array [Float ] = null
466- val pq = new BoundedPriorityQueue [( Int , Float )](num)( Ordering .by(_._2))
468+ var selector : TopSelector = null
467469 iter.flatMap { case (srcIds, srcMat, dstIds, dstMat) =>
468470 require(srcMat.length == srcIds.length * rank)
469471 require(dstMat.length == dstIds.length * rank)
470472 val m = srcIds.length
471473 val n = dstIds.length
472474 if (buffer == null || buffer.length < n) {
473475 buffer = Array .ofDim[Float ](n)
476+ selector = new TopSelector (buffer)
474477 }
475478
476479 Iterator .tabulate(m) { i =>
477480 // buffer = i-th vec in srcMat * dstMat
478481 BLAS .f2jBLAS.sgemv(" T" , rank, n, 1.0F , dstMat, 0 , rank,
479482 srcMat, i * rank, 1 , 0.0F , buffer, 0 , 1 )
480-
481- pq.clear()
482- var j = 0
483- while (j < n) { pq += dstIds(j) -> buffer(j); j += 1 }
484- val (kDstIds, kScores) = pq.toArray.sortBy(- _._2).unzip
485- (srcIds(i), kDstIds, kScores)
483+ val indices = selector.selectTopKIndices(Iterator .range(0 , n), num)
484+ (srcIds(i), indices.map(dstIds), indices.map(buffer))
486485 }
487486 } ++ {
488487 buffer = null
489- pq.clear()
488+ selector = null
490489 Iterator .empty
491490 }
492491 }
@@ -564,6 +563,21 @@ object ALSModel extends MLReadable[ALSModel] {
564563 model
565564 }
566565 }
566+
/**
 * Selects indices into `values` whose entries are the greatest.
 *
 * Indices are ranked by the Float stored at that position of `values`. The array is read
 * each time two indices are compared, so the contents seen are those present when
 * [[selectTopKIndices]] runs, not those present when the selector was constructed.
 *
 * @param values array that candidate indices point into
 */
private[recommendation] class TopSelector(val values: Array[Float]) {
  import scala.collection.JavaConverters._

  // Ordering used on the Float entries themselves; resolved once instead of per comparison.
  private val valueOrdering: Ordering[Float] = Ordering[Float]

  // Ranks two indices by the entries of `values` they refer to.
  private val indexOrdering = new GuavaOrdering[Int] {
    override def compare(left: Int, right: Int): Int = {
      val leftValue = values(left)
      val rightValue = values(right)
      valueOrdering.compare(leftValue, rightValue)
    }
  }

  /**
   * Picks the `k` greatest indices, ranked by their entries in `values`, out of `iterator`.
   *
   * The selection is delegated to Guava's `Ordering.greatestOf`; the order of the returned
   * elements and the handling of an invalid `k` are whatever that method provides.
   *
   * @param iterator candidate indices; each must be a valid position in `values`
   * @param k maximum number of indices to return
   * @return the selected indices as an array
   */
  def selectTopKIndices(iterator: Iterator[Int], k: Int): Array[Int] = {
    val candidates = iterator.asJava
    val selected = indexOrdering.greatestOf(candidates, k)
    selected.asScala.toArray
  }
}
567581}
568582
569583/**
0 commit comments