Skip to content

Commit 7dd2b91

Browse files
committed
use guava ordering
1 parent 7861b7b commit 7dd2b91

File tree

1 file changed

+23
-9
lines changed
  • mllib/src/main/scala/org/apache/spark/ml/recommendation

1 file changed

+23
-9
lines changed

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import scala.util.{Sorting, Try}
2727
import scala.util.hashing.byteswap64
2828

2929
import com.github.fommil.netlib.BLAS.{getInstance => blas}
30+
import com.google.common.collect.{Ordering => GuavaOrdering}
3031
import org.apache.hadoop.fs.Path
3132
import org.json4s.DefaultFormats
3233
import org.json4s.JsonDSL._
@@ -47,7 +48,7 @@ import org.apache.spark.sql.{DataFrame, Dataset}
4748
import org.apache.spark.sql.functions._
4849
import org.apache.spark.sql.types._
4950
import org.apache.spark.storage.StorageLevel
50-
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
51+
import org.apache.spark.util.Utils
5152
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
5253
import org.apache.spark.util.random.XORShiftRandom
5354

@@ -456,37 +457,35 @@ class ALSModel private[ml] (
456457
num: Int,
457458
blockSize: Int): DataFrame = {
458459
import srcFactors.sparkSession.implicits._
460+
import ALSModel.TopSelector
459461

460462
val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])], blockSize)
461463
val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])], blockSize)
462464
val partialRecs = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
463465
.as[(Array[Int], Array[Float], Array[Int], Array[Float])]
464466
.mapPartitions { iter =>
465467
var buffer: Array[Float] = null
466-
val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2))
468+
var selector: TopSelector = null
467469
iter.flatMap { case (srcIds, srcMat, dstIds, dstMat) =>
468470
require(srcMat.length == srcIds.length * rank)
469471
require(dstMat.length == dstIds.length * rank)
470472
val m = srcIds.length
471473
val n = dstIds.length
472474
if (buffer == null || buffer.length < n) {
473475
buffer = Array.ofDim[Float](n)
476+
selector = new TopSelector(buffer)
474477
}
475478

476479
Iterator.tabulate(m) { i =>
477480
// buffer = i-th vec in srcMat * dstMat
478481
BLAS.f2jBLAS.sgemv("T", rank, n, 1.0F, dstMat, 0, rank,
479482
srcMat, i * rank, 1, 0.0F, buffer, 0, 1)
480-
481-
pq.clear()
482-
var j = 0
483-
while (j < n) { pq += dstIds(j) -> buffer(j); j += 1 }
484-
val (kDstIds, kScores) = pq.toArray.sortBy(-_._2).unzip
485-
(srcIds(i), kDstIds, kScores)
483+
val indices = selector.selectTopKIndices(Iterator.range(0, n), num)
484+
(srcIds(i), indices.map(dstIds), indices.map(buffer))
486485
}
487486
} ++ {
488487
buffer = null
489-
pq.clear()
488+
selector = null
490489
Iterator.empty
491490
}
492491
}
@@ -564,6 +563,21 @@ object ALSModel extends MLReadable[ALSModel] {
564563
model
565564
}
566565
}
566+
567+
/** select top indices based on values. */
568+
private[recommendation] class TopSelector(val values: Array[Float]) {
569+
import scala.collection.JavaConverters._
570+
571+
private val indexOrdering = new GuavaOrdering[Int] {
572+
override def compare(left: Int, right: Int): Int = {
573+
Ordering[Float].compare(values(left), values(right))
574+
}
575+
}
576+
577+
def selectTopKIndices(iterator: Iterator[Int], k: Int): Array[Int] = {
578+
indexOrdering.greatestOf(iterator.asJava, k).asScala.toArray
579+
}
580+
}
567581
}
568582

569583
/**

0 commit comments

Comments
 (0)