Commit 70f9d7f

sueann authored and jkbradley committed
[SPARK-19535][ML] RecommendForAllUsers RecommendForAllItems for ALS on Dataframe
## What changes were proposed in this pull request?

This is a simple implementation of RecommendForAllUsers & RecommendForAllItems for the DataFrame version of ALS. It uses DataFrame operations (not a wrapper on the RDD implementation). Haven't benchmarked against a wrapper, but unit test examples do work.

## How was this patch tested?

Unit tests:

```
$ build/sbt
> mllib/testOnly *ALSSuite -- -z "recommendFor"
> mllib/testOnly
```

Author: Your Name <[email protected]>
Author: sueann <[email protected]>

Closes #17090 from sueann/SPARK-19535.
1 parent 369a148 commit 70f9d7f
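
For context, here is a minimal usage sketch of the API added by this commit. The ratings data, column names, and session setup below are illustrative assumptions, not part of the diff:

```scala
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.SparkSession

// Hypothetical session and ratings; any (user, item, rating) DataFrame works.
val spark = SparkSession.builder().appName("als-recommend-sketch").getOrCreate()
import spark.implicits._

val ratings = Seq((0, 3, 4.0f), (0, 4, 1.0f), (1, 3, 5.0f), (2, 5, 3.0f))
  .toDF("user", "item", "rating")

val model = new ALS()
  .setUserCol("user")
  .setItemCol("item")
  .setRatingCol("rating")
  .fit(ratings)

// New in this patch: per-user and per-item top-k recommendation DataFrames,
// shaped as (id, recommendations: Array[(id, rating)]).
model.recommendForAllUsers(3).show(truncate = false)
model.recommendForAllItems(3).show(truncate = false)
```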

4 files changed: +297, -9 lines changed

mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 70 additions & 9 deletions
```diff
@@ -40,7 +40,8 @@ import org.apache.spark.ml.util._
 import org.apache.spark.mllib.linalg.CholeskyDecomposition
 import org.apache.spark.mllib.optimization.NNLS
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.storage.StorageLevel
@@ -284,18 +285,20 @@ class ALSModel private[ml] (
   @Since("2.2.0")
   def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value)
 
+  private val predict = udf { (featuresA: Seq[Float], featuresB: Seq[Float]) =>
+    if (featuresA != null && featuresB != null) {
+      // TODO(SPARK-19759): try dot-producting on Seqs or another non-converted type for
+      // potential optimization.
+      blas.sdot(rank, featuresA.toArray, 1, featuresB.toArray, 1)
+    } else {
+      Float.NaN
+    }
+  }
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema)
-    // Register a UDF for DataFrame, and then
     // create a new column named map(predictionCol) by running the predict UDF.
-    val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
-      if (userFeatures != null && itemFeatures != null) {
-        blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1)
-      } else {
-        Float.NaN
-      }
-    }
     val predictions = dataset
       .join(userFactors,
         checkedCast(dataset($(userCol))) === userFactors("id"), "left")
@@ -327,6 +330,64 @@ class ALSModel private[ml] (
 
   @Since("1.6.0")
   override def write: MLWriter = new ALSModel.ALSModelWriter(this)
+
+  /**
+   * Returns top `numItems` items recommended for each user, for all users.
+   * @param numItems max number of recommendations for each user
+   * @return a DataFrame of (userCol: Int, recommendations), where recommendations are
+   *         stored as an array of (itemCol: Int, rating: Float) Rows.
+   */
+  @Since("2.2.0")
+  def recommendForAllUsers(numItems: Int): DataFrame = {
+    recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems)
+  }
+
+  /**
+   * Returns top `numUsers` users recommended for each item, for all items.
+   * @param numUsers max number of recommendations for each item
+   * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are
+   *         stored as an array of (userCol: Int, rating: Float) Rows.
+   */
+  @Since("2.2.0")
+  def recommendForAllItems(numUsers: Int): DataFrame = {
+    recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers)
+  }
+
+  /**
+   * Makes recommendations for all users (or items).
+   * @param srcFactors src factors for which to generate recommendations
+   * @param dstFactors dst factors used to make recommendations
+   * @param srcOutputColumn name of the column for the source ID in the output DataFrame
+   * @param dstOutputColumn name of the column for the destination ID in the output DataFrame
+   * @param num max number of recommendations for each record
+   * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are
+   *         stored as an array of (dstOutputColumn: Int, rating: Float) Rows.
+   */
+  private def recommendForAll(
+      srcFactors: DataFrame,
+      dstFactors: DataFrame,
+      srcOutputColumn: String,
+      dstOutputColumn: String,
+      num: Int): DataFrame = {
+    import srcFactors.sparkSession.implicits._
+
+    val ratings = srcFactors.crossJoin(dstFactors)
+      .select(
+        srcFactors("id"),
+        dstFactors("id"),
+        predict(srcFactors("features"), dstFactors("features")))
+    // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
+    val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
+    val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
+      .toDF("id", "recommendations")
+
+    val arrayType = ArrayType(
+      new StructType()
+        .add(dstOutputColumn, IntegerType)
+        .add("rating", FloatType)
+    )
+    recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType)
+  }
 }
 
 @Since("1.6.0")
```
mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala

Lines changed: 60 additions & 0 deletions

```diff
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.{Encoder, Encoders}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.util.BoundedPriorityQueue
+
+
+/**
+ * Works on rows of the form (K1, K2, V) where K1 & K2 are IDs and V is the score value. Finds
+ * the top `num` K2 items based on the given Ordering.
+ */
+private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: TypeTag]
+  (num: Int, ord: Ordering[(K2, V)])
+  extends Aggregator[(K1, K2, V), BoundedPriorityQueue[(K2, V)], Array[(K2, V)]] {
+
+  override def zero: BoundedPriorityQueue[(K2, V)] = new BoundedPriorityQueue[(K2, V)](num)(ord)
+
+  override def reduce(
+      q: BoundedPriorityQueue[(K2, V)],
+      a: (K1, K2, V)): BoundedPriorityQueue[(K2, V)] = {
+    q += {(a._2, a._3)}
+  }
+
+  override def merge(
+      q1: BoundedPriorityQueue[(K2, V)],
+      q2: BoundedPriorityQueue[(K2, V)]): BoundedPriorityQueue[(K2, V)] = {
+    q1 ++= q2
+  }
+
+  override def finish(r: BoundedPriorityQueue[(K2, V)]): Array[(K2, V)] = {
+    r.toArray.sorted(ord.reverse)
+  }
+
+  override def bufferEncoder: Encoder[BoundedPriorityQueue[(K2, V)]] = {
+    Encoders.kryo[BoundedPriorityQueue[(K2, V)]]
+  }
+
+  override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]()
+}
```
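
Semantically, the aggregator maintains one bounded priority queue per key: `reduce` offers each (K2, V) pair, the queue silently drops everything beyond the `num` largest under `ord`, and `finish` emits the survivors sorted largest-first. A plain-collections sketch of those semantics over a single key's rows (reference only; the real code uses Spark's `BoundedPriorityQueue`, which is `private[spark]`):

```scala
// Reference semantics of TopByKeyAggregator for one key, using plain collections.
def topByKey[K2, V](rows: Seq[(K2, V)], num: Int)(ord: Ordering[(K2, V)]): Array[(K2, V)] = {
  // Keep only the `num` largest pairs under `ord`, ordered largest-first,
  // mirroring the bounded queue plus the final `sorted(ord.reverse)`.
  rows.sorted(ord.reverse).take(num).toArray
}

val rows = Seq((3, 54f), (4, 44f), (5, 42f), (6, 28f))
// With Ordering.by(_._2), "largest" means highest score.
assert(topByKey(rows, 2)(Ordering.by(_._2)).toSeq == Seq((3, 54f), (4, 44f)))
```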

mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala

Lines changed: 94 additions & 0 deletions
```diff
@@ -22,6 +22,7 @@ import java.util.Random
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.WrappedArray
 import scala.collection.JavaConverters._
 import scala.language.existentials
 
@@ -660,6 +661,99 @@ class ALSSuite
       model.setColdStartStrategy(s).transform(data)
     }
   }
+
+  private def getALSModel = {
+    val spark = this.spark
+    import spark.implicits._
+
+    val userFactors = Seq(
+      (0, Array(6.0f, 4.0f)),
+      (1, Array(3.0f, 4.0f)),
+      (2, Array(3.0f, 6.0f))
+    ).toDF("id", "features")
+    val itemFactors = Seq(
+      (3, Array(5.0f, 6.0f)),
+      (4, Array(6.0f, 2.0f)),
+      (5, Array(3.0f, 6.0f)),
+      (6, Array(4.0f, 1.0f))
+    ).toDF("id", "features")
+    val als = new ALS().setRank(2)
+    new ALSModel(als.uid, als.getRank, userFactors, itemFactors)
+      .setUserCol("user")
+      .setItemCol("item")
+  }
+
+  test("recommendForAllUsers with k < num_items") {
+    val topItems = getALSModel.recommendForAllUsers(2)
+    assert(topItems.count() == 3)
+    assert(topItems.columns.contains("user"))
+
+    val expected = Map(
+      0 -> Array((3, 54f), (4, 44f)),
+      1 -> Array((3, 39f), (5, 33f)),
+      2 -> Array((3, 51f), (5, 45f))
+    )
+    checkRecommendations(topItems, expected, "item")
+  }
+
+  test("recommendForAllUsers with k = num_items") {
+    val topItems = getALSModel.recommendForAllUsers(4)
+    assert(topItems.count() == 3)
+    assert(topItems.columns.contains("user"))
+
+    val expected = Map(
+      0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)),
+      1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)),
+      2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f))
+    )
+    checkRecommendations(topItems, expected, "item")
+  }
+
+  test("recommendForAllItems with k < num_users") {
+    val topUsers = getALSModel.recommendForAllItems(2)
+    assert(topUsers.count() == 4)
+    assert(topUsers.columns.contains("item"))
+
+    val expected = Map(
+      3 -> Array((0, 54f), (2, 51f)),
+      4 -> Array((0, 44f), (2, 30f)),
+      5 -> Array((2, 45f), (0, 42f)),
+      6 -> Array((0, 28f), (2, 18f))
+    )
+    checkRecommendations(topUsers, expected, "user")
+  }
+
+  test("recommendForAllItems with k = num_users") {
+    val topUsers = getALSModel.recommendForAllItems(3)
+    assert(topUsers.count() == 4)
+    assert(topUsers.columns.contains("item"))
+
+    val expected = Map(
+      3 -> Array((0, 54f), (2, 51f), (1, 39f)),
+      4 -> Array((0, 44f), (2, 30f), (1, 26f)),
+      5 -> Array((2, 45f), (0, 42f), (1, 33f)),
+      6 -> Array((0, 28f), (2, 18f), (1, 16f))
+    )
+    checkRecommendations(topUsers, expected, "user")
+  }
+
+  private def checkRecommendations(
+      topK: DataFrame,
+      expected: Map[Int, Array[(Int, Float)]],
+      dstColName: String): Unit = {
+    val spark = this.spark
+    import spark.implicits._
+
+    assert(topK.columns.contains("recommendations"))
+    topK.as[(Int, Seq[(Int, Float)])].collect().foreach { case (id: Int, recs: Seq[(Int, Float)]) =>
+      assert(recs === expected(id))
+    }
+    topK.collect().foreach { row =>
+      val recs = row.getAs[WrappedArray[Row]]("recommendations")
+      assert(recs(0).fieldIndex(dstColName) == 0)
+      assert(recs(0).fieldIndex("rating") == 1)
+    }
+  }
 }
 
 class ALSCleanerSuite extends SparkFunSuite {
```
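
The expected values in these tests are hand-checkable: each predicted rating is just the dot product of the corresponding user and item factors from `getALSModel`. For example, the first expectation (user 0, item 3, rating 54f):

```scala
// user 0 factors (6, 4) dot item 3 factors (5, 6) = 6*5 + 4*6 = 54
val user0 = Array(6.0f, 4.0f)
val item3 = Array(5.0f, 6.0f)
val rating = user0.zip(item3).map { case (u, i) => u * i }.sum
assert(rating == 54f)
```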
mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala

Lines changed: 73 additions & 0 deletions

```diff
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Dataset
+
+
+class TopByKeyAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  private def getTopK(k: Int): Dataset[(Int, Array[(Int, Float)])] = {
+    val sqlContext = spark.sqlContext
+    import sqlContext.implicits._
+
+    val topKAggregator = new TopByKeyAggregator[Int, Int, Float](k, Ordering.by(_._2))
+    Seq(
+      (0, 3, 54f),
+      (0, 4, 44f),
+      (0, 5, 42f),
+      (0, 6, 28f),
+      (1, 3, 39f),
+      (2, 3, 51f),
+      (2, 5, 45f),
+      (2, 6, 18f)
+    ).toDS().groupByKey(_._1).agg(topKAggregator.toColumn)
+  }
+
+  test("topByKey with k < #items") {
+    val topK = getTopK(2)
+    assert(topK.count() === 3)
+
+    val expected = Map(
+      0 -> Array((3, 54f), (4, 44f)),
+      1 -> Array((3, 39f)),
+      2 -> Array((3, 51f), (5, 45f))
+    )
+    checkTopK(topK, expected)
+  }
+
+  test("topByKey with k > #items") {
+    val topK = getTopK(5)
+    assert(topK.count() === 3)
+
+    val expected = Map(
+      0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)),
+      1 -> Array((3, 39f)),
+      2 -> Array((3, 51f), (5, 45f), (6, 18f))
+    )
+    checkTopK(topK, expected)
+  }
+
+  private def checkTopK(
+      topK: Dataset[(Int, Array[(Int, Float)])],
+      expected: Map[Int, Array[(Int, Float)]]): Unit = {
+    topK.collect().foreach { case (id, recs) => assert(recs === expected(id)) }
+  }
+}
```
