@@ -16,11 +16,11 @@
  */
 package org.apache.spark.mllib.api.python
 
-import scala.collection.JavaConverters
-
 import org.apache.spark.SparkContext
+import org.apache.spark.api.java.JavaSparkContext
 import org.apache.spark.mllib.clustering.LDAModel
 import org.apache.spark.mllib.linalg.Matrix
+import org.apache.spark.sql.{DataFrame, SQLContext}
 
 /**
  * Wrapper around LDAModel to provide helper methods in Python
@@ -31,14 +31,17 @@ private[python] class LDAModelWrapper(model: LDAModel) {
 
   def vocabSize(): Int = model.vocabSize
 
-  def describeTopics(): java.util.List[Array[Any]] = describeTopics(this.model.vocabSize)
+  def describeTopics(jsc: JavaSparkContext): DataFrame = describeTopics(this.model.vocabSize, jsc)
 
-  def describeTopics(maxTermsPerTopic: Int): java.util.List[Array[Any]] = {
+  def describeTopics(maxTermsPerTopic: Int, jsc: JavaSparkContext): DataFrame = {
+    val sqlContext = new SQLContext(jsc.sc)
+    import sqlContext.implicits._
 
-    val seq = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) =>
-      Array.empty[Any] ++ terms ++ termWeights
-    }.toSeq
-    JavaConverters.seqAsJavaListConverter(seq).asJava
+    // Since the return value of `describeTopics` is a little complicated,
+    // it is converted to `Row` to take advantage of DataFrame serialization.
+    val topics = model.describeTopics(maxTermsPerTopic)
+    val rdd = jsc.sc.parallelize(topics)
+    rdd.toDF("terms", "termWeights")
   }
 
   def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
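
For reference, a minimal sketch of how the new DataFrame-returning describeTopics could be exercised from Scala. The SparkContext setup, the pre-trained `ldaModel` value, and the direct construction of LDAModelWrapper are assumptions for illustration (the class is private[python], so user code cannot actually instantiate it outside this package):

// Minimal sketch, assuming a trained `ldaModel: LDAModel` is in scope and
// that we are inside org.apache.spark.mllib.api.python (the wrapper is
// private[python]); not part of this change.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.api.java.JavaSparkContext

val sc = new SparkContext(
  new SparkConf().setMaster("local[2]").setAppName("lda-describe-topics"))
val jsc = new JavaSparkContext(sc)

val wrapper = new LDAModelWrapper(ldaModel)
val df = wrapper.describeTopics(maxTermsPerTopic = 5, jsc)

// Each row holds one topic: an array of term indices in the "terms" column
// and the matching weights in the "termWeights" column.
df.collect().foreach { row =>
  val terms = row.getSeq[Int](0)
  val weights = row.getSeq[Double](1)
  println(terms.zip(weights).map { case (t, w) => s"$t -> $w" }.mkString(", "))
}

Returning a DataFrame instead of a java.util.List lets the Python side deserialize the result through Spark SQL's existing Row machinery rather than hand-rolled conversions.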