41 changes: 41 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -39,6 +39,7 @@ import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.CholeskyDecomposition
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.mllib.recommendation.MatrixFactorizationModel
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
@@ -236,6 +237,8 @@ class ALSModel private[ml] (
@transient val itemFactors: DataFrame)
extends Model[ALSModel] with ALSModelParams with MLWritable {

import org.apache.spark.ml.recommendation.ALS.Rating

/** @group setParam */
@Since("1.4.0")
def setUserCol(value: String): this.type = set(userCol, value)
@@ -269,6 +272,44 @@ class ALSModel private[ml] (
predict(userFactors("features"), itemFactors("features")).as($(predictionCol)))
}

/**
* Recommends top items for all users.
*
* @param num how many items to return for every user.
* @return a DataFrame with two columns, `user` and `ratings`, where each row holds a user ID
*         and an array of [[Rating]] objects, each containing that same user ID, a recommended
*         item ID, and a score.
*/
@Since("2.1.0")
def recommendItemsForUsers(num: Int): DataFrame = {
val spark = userFactors.sparkSession
import spark.implicits._
toMLlibModel.recommendProductsForUsers(num).toDF("user", "ratings")
}
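A minimal usage sketch of the new API (illustrative only; `model` stands in for a fitted ALSModel such as the one built in the test further down):

// Illustrative usage, assuming `model` is a fitted ALSModel.
// Each row pairs a user id with an array of Rating structs for that user's top 3 items.
val topItems = model.recommendItemsForUsers(3)
topItems.show()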

/**
* Recommends top users for all items.
*
* @param num how many users to return for every item.
* @return a DataFrame with two columns, `item` and `ratings`, where each row holds an item ID
*         and an array of [[Rating]] objects, each containing that same item ID, a recommended
*         user ID, and a score.
*/
@Since("2.1.0")
def recommendUsersForItems(num: Int): DataFrame = {
val spark = userFactors.sparkSession
import spark.implicits._
toMLlibModel.recommendProductsForUsers(num).toDF("item", "ratings")
}
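If one row per recommendation is more convenient downstream, the array column can be flattened with `explode`. A hedged sketch, again assuming a fitted ALSModel named `model`; the struct field names `user`, `product`, and `rating` come from the mllib `Rating` case class:

// Illustrative only: one row per (item, recommended user, score).
import org.apache.spark.sql.functions.{col, explode}
val topUsers = model.recommendUsersForItems(3)
val flat = topUsers
  .select(col("item"), explode(col("ratings")).as("rec"))
  .select(col("item"), col("rec.user").as("user"), col("rec.rating").as("score"))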

private def toMLlibModel: MatrixFactorizationModel = {
val userFeatures = userFactors.select("id", "features").rdd
.map(r => (r.getInt(0), r.getSeq[Float](1).toArray.map(_.toDouble)))
val itemFeatures = itemFactors.select("id", "features").rdd
.map(r => (r.getInt(0), r.getSeq[Float](1).toArray.map(_.toDouble)))
new MatrixFactorizationModel(rank, userFeatures, itemFeatures)
}
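For reference, a sketch of the factor layout this conversion reads (illustrative values only; it assumes a SparkSession named `spark` is in scope): each of `userFactors` and `itemFactors` is a DataFrame with an integer `id` column and a `features` array of `rank` floats.

// Illustrative only: the shape toMLlibModel expects from the factor DataFrames (rank = 2 here).
import spark.implicits._
val sampleFactors = Seq(
  (0, Array(0.12f, -0.34f)),
  (1, Array(0.56f, 0.78f))
).toDF("id", "features")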

@Since("1.3.0")
override def transformSchema(schema: StructType): StructType = {
// user and item will be cast to Int
mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -450,6 +450,20 @@ class ALSSuite
implicitPrefs = true, seed = 0)
}

test("recommend for all") {
val spark = this.spark
import spark.implicits._
val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
val model = new ALS().fit(ratings.toDF())
val items = model.recommendItemsForUsers(2)
assert(items.count() == 4
&& items.select("ratings").rdd.collect().forall(_.getSeq[Rating[Int]](0).length == 2))

val users = model.recommendUsersForItems(2)
assert(users.count() == 4
&& users.select("ratings").rdd.collect().forall(_.getSeq[Rating[Int]](0).length == 2))
}

test("read/write") {
import ALSSuite._
val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)