Skip to content

Commit 14cdbf6

Browse files
author
Peng
committed
Optimize ALS recommendForAll
1 parent 2eaf4f3 commit 14cdbf6

File tree

1 file changed

+24
-31
lines changed

1 file changed

+24
-31
lines changed

mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala

Lines changed: 24 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -17,17 +17,17 @@
1717

1818
package org.apache.spark.mllib.recommendation
1919

20-
import java.io.IOException
21-
import java.lang.{Integer => JavaInteger}
22-
23-
import scala.collection.mutable
24-
20+
import breeze.linalg.min
2521
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
2622
import com.github.fommil.netlib.BLAS.{getInstance => blas}
23+
import java.io.IOException
24+
import java.lang.{Integer => JavaInteger}
2725
import org.apache.hadoop.fs.Path
2826
import org.json4s._
2927
import org.json4s.JsonDSL._
3028
import org.json4s.jackson.JsonMethods._
29+
import scala.collection.mutable
30+
import scala.collection.mutable.PriorityQueue
3131

3232
import org.apache.spark.SparkContext
3333
import org.apache.spark.annotation.Since
@@ -277,17 +277,23 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
277277
val srcBlocks = blockify(rank, srcFeatures)
278278
val dstBlocks = blockify(rank, dstFeatures)
279279
val ratings = srcBlocks.cartesian(dstBlocks).flatMap {
280-
case ((srcIds, srcFactors), (dstIds, dstFactors)) =>
281-
val m = srcIds.length
282-
val n = dstIds.length
283-
val ratings = srcFactors.transpose.multiply(dstFactors)
284-
val output = new Array[(Int, (Int, Double))](m * n)
285-
var k = 0
286-
ratings.foreachActive { (i, j, r) =>
287-
output(k) = (srcIds(i), (dstIds(j), r))
288-
k += 1
289-
}
290-
output.toSeq
280+
case (users, items) =>
281+
val m = users.size
282+
val n = min(items.size, num)
283+
val output = new Array[(Int, (Int, Double))](m * n)
284+
var j = 0
285+
users.foreach (user => {
286+
def order(a: (Int, Double)) = a._2
287+
val pq: PriorityQueue[(Int, Double)] = PriorityQueue()(Ordering.by(order))
288+
items.foreach (item => {
289+
val rate = blas.ddot(rank, user._2, 1, item._2, 1)
290+
pq.enqueue((item._1, rate))
291+
})
292+
for(i <- 0 to n-1)
293+
output(j + i) = (user._1, pq.dequeue())
294+
j += n
295+
})
296+
output.toSeq
291297
}
292298
ratings.topByKey(num)(Ordering.by(_._2))
293299
}
@@ -297,23 +303,10 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
297303
*/
298304
private def blockify(
299305
rank: Int,
300-
features: RDD[(Int, Array[Double])]): RDD[(Array[Int], DenseMatrix)] = {
306+
features: RDD[(Int, Array[Double])]): RDD[Seq[(Int, Array[Double])]] = {
301307
val blockSize = 4096 // TODO: tune the block size
302-
val blockStorage = rank * blockSize
303308
features.mapPartitions { iter =>
304-
iter.grouped(blockSize).map { grouped =>
305-
val ids = mutable.ArrayBuilder.make[Int]
306-
ids.sizeHint(blockSize)
307-
val factors = mutable.ArrayBuilder.make[Double]
308-
factors.sizeHint(blockStorage)
309-
var i = 0
310-
grouped.foreach { case (id, factor) =>
311-
ids += id
312-
factors ++= factor
313-
i += 1
314-
}
315-
(ids.result(), new DenseMatrix(rank, i, factors.result()))
316-
}
309+
iter.grouped(blockSize)
317310
}
318311
}
319312

0 commit comments

Comments
 (0)