diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 088f6a682be8..1b856bda45e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -27,6 +27,7 @@ import scala.util.{Sorting, Try}
 import scala.util.hashing.byteswap64
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import com.google.common.collect.{Ordering => GuavaOrdering}
 import org.apache.hadoop.fs.Path
 import org.json4s.DefaultFormats
 import org.json4s.JsonDSL._
@@ -47,7 +48,7 @@ import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, Utils}
+import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
 import org.apache.spark.util.random.XORShiftRandom
 
@@ -456,30 +457,39 @@ class ALSModel private[ml] (
       num: Int,
       blockSize: Int): DataFrame = {
     import srcFactors.sparkSession.implicits._
+    import scala.collection.JavaConverters._
 
     val srcFactorsBlocked = blockify(srcFactors.as[(Int, Array[Float])], blockSize)
     val dstFactorsBlocked = blockify(dstFactors.as[(Int, Array[Float])], blockSize)
     val ratings = srcFactorsBlocked.crossJoin(dstFactorsBlocked)
-      .as[(Seq[(Int, Array[Float])], Seq[(Int, Array[Float])])]
-      .flatMap { case (srcIter, dstIter) =>
-        val m = srcIter.size
-        val n = math.min(dstIter.size, num)
-        val output = new Array[(Int, Int, Float)](m * n)
-        var i = 0
-        val pq = new BoundedPriorityQueue[(Int, Float)](num)(Ordering.by(_._2))
-        srcIter.foreach { case (srcId, srcFactor) =>
-          dstIter.foreach { case (dstId, dstFactor) =>
-            // We use F2jBLAS which is faster than a call to native BLAS for vector dot product
-            val score = BLAS.f2jBLAS.sdot(rank, srcFactor, 1, dstFactor, 1)
-            pq += dstId -> score
+      .as[(Array[Int], Array[Float], Array[Int], Array[Float])]
+      .mapPartitions { iter =>
+        var scores: Array[Float] = null
+        var idxOrd: GuavaOrdering[Int] = null
+        iter.flatMap { case (srcIds, srcMat, dstIds, dstMat) =>
+          require(srcMat.length == srcIds.length * rank)
+          require(dstMat.length == dstIds.length * rank)
+          val m = srcIds.length
+          val n = dstIds.length
+          if (scores == null || scores.length < n) {
+            scores = Array.ofDim[Float](n)
+            idxOrd = new GuavaOrdering[Int] {
+              override def compare(left: Int, right: Int): Int = {
+                Ordering[Float].compare(scores(left), scores(right))
+              }
+            }
           }
-          pq.foreach { case (dstId, score) =>
-            output(i) = (srcId, dstId, score)
-            i += 1
+
+          Iterator.range(0, m).flatMap { i =>
+            // buffer = i-th vec in srcMat * dstMat
+            BLAS.f2jBLAS.sgemv("T", rank, n, 1.0F, dstMat, 0, rank,
+              srcMat, i * rank, 1, 0.0F, scores, 0, 1)
+
+            val srcId = srcIds(i)
+            idxOrd.greatestOf(Iterator.range(0, n).asJava, num).asScala
+              .iterator.map { j => (srcId, dstIds(j), scores(j)) }
           }
-          pq.clear()
         }
-        output.toSeq
       }
     // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
     val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
@@ -499,9 +509,12 @@ class ALSModel private[ml] (
    */
   private def blockify(
       factors: Dataset[(Int, Array[Float])],
-      blockSize: Int): Dataset[Seq[(Int, Array[Float])]] = {
+      blockSize: Int): Dataset[(Array[Int], Array[Float])] = {
     import factors.sparkSession.implicits._
-    factors.mapPartitions(_.grouped(blockSize))
+    factors.mapPartitions { iter =>
+      iter.grouped(blockSize)
+        .map(block => (block.map(_._1).toArray, block.flatMap(_._2).toArray))
+    }
   }
 }
 
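
For context: the patch swaps the old per-pair `sdot` + `BoundedPriorityQueue` loop for one level-2 BLAS call (`sgemv`) per source vector, scoring it against a whole destination block at once, then picks the top `num` candidate indices with Guava's `Ordering.greatestOf` over a reused score buffer. Below is a minimal standalone sketch of that technique, not the patch itself: it assumes only Guava on the classpath, `naiveSgemv` is a hypothetical plain-Scala stand-in for the `BLAS.f2jBLAS.sgemv("T", ...)` call, and `TopKSketch`/`recommendTopK` are illustrative names that do not appear in the diff.

```scala
import scala.collection.JavaConverters._

import com.google.common.collect.{Ordering => GuavaOrdering}

object TopKSketch {

  // Mirrors the patched blockify: pack a block of (id, factor) pairs into
  // an id array plus one flattened factor matrix holding rank floats per id.
  def blockify(
      factors: Seq[(Int, Array[Float])],
      blockSize: Int): Iterator[(Array[Int], Array[Float])] = {
    factors.grouped(blockSize)
      .map(block => (block.map(_._1).toArray, block.flatMap(_._2).toArray))
  }

  // Hypothetical stand-in for sgemv("T", rank, n, 1.0F, dstMat, 0, rank,
  // srcMat, srcOffset, 1, 0.0F, scores, 0, 1):
  // scores(j) = dot(j-th dst vector, src vector at srcOffset), j in [0, n).
  def naiveSgemv(
      rank: Int,
      n: Int,
      dstMat: Array[Float],
      srcMat: Array[Float],
      srcOffset: Int,
      scores: Array[Float]): Unit = {
    var j = 0
    while (j < n) {
      var acc = 0.0f
      var k = 0
      while (k < rank) {
        acc += dstMat(j * rank + k) * srcMat(srcOffset + k)
        k += 1
      }
      scores(j) = acc
      j += 1
    }
  }

  // Score one src block against one dst block and emit the top `num`
  // (srcId, dstId, score) triples per src id.
  def recommendTopK(
      rank: Int,
      num: Int,
      src: (Array[Int], Array[Float]),
      dst: (Array[Int], Array[Float])): Iterator[(Int, Int, Float)] = {
    val (srcIds, srcMat) = src
    val (dstIds, dstMat) = dst
    val n = dstIds.length
    val scores = Array.ofDim[Float](n)
    // Rank candidate indices through the shared score buffer, as the
    // patch's GuavaOrdering does: no per-candidate tuples are allocated.
    val idxOrd = new GuavaOrdering[Int] {
      override def compare(left: Int, right: Int): Int =
        Ordering[Float].compare(scores(left), scores(right))
    }
    Iterator.range(0, srcIds.length).flatMap { i =>
      naiveSgemv(rank, n, dstMat, srcMat, i * rank, scores)
      val srcId = srcIds(i)
      idxOrd.greatestOf(Iterator.range(0, n).asJava, num).asScala
        .iterator.map { j => (srcId, dstIds(j), scores(j)) }
    }
  }

  def main(args: Array[String]): Unit = {
    val rank = 2
    val srcFactors = Seq(10 -> Array(1f, 0f), 11 -> Array(0f, 1f))
    val dstFactors = Seq(100 -> Array(3f, 0f), 101 -> Array(0f, 2f), 102 -> Array(1f, 1f))
    for {
      srcBlock <- blockify(srcFactors, blockSize = 2)
      dstBlock <- blockify(dstFactors, blockSize = 4)
      rec <- recommendTopK(rank, num = 2, src = srcBlock, dst = dstBlock)
    } println(rec) // (10,100,3.0), (10,102,1.0), (11,101,2.0), (11,102,1.0)
  }
}
```

The shared `scores` buffer is safe because `flatMap` drains each inner iterator before the next `naiveSgemv` call overwrites the buffer, and `greatestOf` materializes its result eagerly; only the final lazy `.iterator.map` reads entries, and those are still valid. The patch relies on the same property at the `mapPartitions` level, which is why its `scores` and `idxOrd` are `var`s reused across all rows of a partition.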