
Commit 6a7e3d1

Author: Your Name

Commit message: no longer needing to cause serialization costs

1 parent b0680db

File tree

1 file changed: +8 additions, -14 deletions

  • mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala


mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala

Lines changed: 8 additions & 14 deletions
@@ -379,20 +379,14 @@ class ALSModel private[ml] (
     // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
     val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
     val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
-      .toDF(srcOutputColumn, "recommendations")
-
-    // There is some performance hit from converting the (Int, Float) tuples to
-    // (dstOutputColumn: Int, rating: Float) structs using .rdd. Need SPARK-16483 for a fix.
-    val schema = new StructType()
-      .add(srcOutputColumn, IntegerType)
-      .add("recommendations",
-        ArrayType(
-          StructType(
-            StructField(dstOutputColumn, IntegerType, nullable = false) ::
-            StructField("rating", FloatType, nullable = false) ::
-            Nil
-          )))
-    recs.sparkSession.createDataFrame(recs.rdd, schema)
+      .toDF("id", "recommendations")
+
+    val arrayType = ArrayType(
+      new StructType()
+        .add(dstOutputColumn, IntegerType)
+        .add("rating", FloatType)
+    )
+    recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType)
   }
 }
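Why the new version is cheaper: the deleted code converted the Dataset to an RDD and back through createDataFrame just to rename the aggregated (Int, Float) tuple fields, which forces every row to be deserialized from Spark's internal format and re-serialized (the performance hit the removed comment attributes to .rdd, pending SPARK-16483). Casting the array column to the desired struct type renames the fields positionally inside Catalyst, with no RDD round trip. Below is a minimal, self-contained sketch of the same cast pattern; the SparkSession setup, the toy data, and the "userId"/"itemId" names are illustrative assumptions, not part of the commit.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{ArrayType, FloatType, IntegerType, StructType}

object CastRenameSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("cast-rename-sketch") // illustrative app name
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Toy stand-in for the aggregated output: tuple elements surface in the
    // DataFrame as structs with anonymous field names (_1, _2).
    val recs = Seq(
      (0, Seq((10, 0.9f), (11, 0.8f))),
      (1, Seq((12, 0.7f)))
    ).toDF("id", "recommendations")

    // Target element type: same field types, meaningful names. Casting a
    // struct renames its fields positionally, so no RDD round trip is needed.
    val arrayType = ArrayType(
      new StructType()
        .add("itemId", IntegerType) // plays the role of dstOutputColumn
        .add("rating", FloatType)
    )
    val renamed = recs.select($"id" as "userId", $"recommendations" cast arrayType)

    renamed.printSchema() // recommendations: array<struct<itemId:int,rating:float>>
    renamed.show(truncate = false)
    spark.stop()
  }
}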
