@@ -379,20 +379,14 @@ class ALSModel private[ml] (
379379 // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output.
380380 val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2))
381381 val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn)
382- .toDF(srcOutputColumn, "recommendations")
383-
384- // There is some performance hit from converting the (Int, Float) tuples to
385- // (dstOutputColumn: Int, rating: Float) structs using .rdd. Need SPARK-16483 for a fix.
386- val schema = new StructType()
387- .add(srcOutputColumn, IntegerType)
388- .add("recommendations",
389- ArrayType(
390- StructType(
391- StructField(dstOutputColumn, IntegerType, nullable = false) ::
392- StructField("rating", FloatType, nullable = false) ::
393- Nil
394- )))
395- recs.sparkSession.createDataFrame(recs.rdd, schema)
382+ .toDF("id", "recommendations")
383+
384+ val arrayType = ArrayType(
385+ new StructType()
386+ .add(dstOutputColumn, IntegerType)
387+ .add("rating", FloatType)
388+ )
389+ recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType)
396390 }
397391}
398392
0 commit comments