
Commit 6e3cf05

Modify describeTopics to take advantage of DataFrame serialization

1 parent f10574e · commit 6e3cf05

File tree

2 files changed: +19 −18 lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala

Lines changed: 11 additions & 8 deletions

```diff
@@ -16,11 +16,11 @@
  */
 package org.apache.spark.mllib.api.python
 
-import scala.collection.JavaConverters
-
 import org.apache.spark.SparkContext
+import org.apache.spark.api.java.JavaSparkContext
 import org.apache.spark.mllib.clustering.LDAModel
 import org.apache.spark.mllib.linalg.Matrix
+import org.apache.spark.sql.{DataFrame, SQLContext}
 
 /**
  * Wrapper around LDAModel to provide helper methods in Python
@@ -31,14 +31,17 @@ private[python] class LDAModelWrapper(model: LDAModel) {
 
   def vocabSize(): Int = model.vocabSize
 
-  def describeTopics(): java.util.List[Array[Any]] = describeTopics(this.model.vocabSize)
+  def describeTopics(jsc: JavaSparkContext): DataFrame = describeTopics(this.model.vocabSize, jsc)
 
-  def describeTopics(maxTermsPerTopic: Int): java.util.List[Array[Any]] = {
+  def describeTopics(maxTermsPerTopic: Int, jsc: JavaSparkContext): DataFrame = {
+    val sqlContext = new SQLContext(jsc.sc)
+    import sqlContext.implicits._
 
-    val seq = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) =>
-      Array.empty[Any] ++ terms ++ termWeights
-    }.toSeq
-    JavaConverters.seqAsJavaListConverter(seq).asJava
+    // Since the return value of `describeTopics` is a little complicated,
+    // it is converted to `Row`s to take advantage of DataFrame serialization.
+    val topics = model.describeTopics(maxTermsPerTopic)
+    val rdd = jsc.sc.parallelize(topics)
+    rdd.toDF("terms", "termWeights")
   }
 
   def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
```
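Conceptually, the wrapper now ships each `(terms, termWeights)` pair to Python as a named DataFrame row instead of a flattened Java list. A minimal PySpark sketch of the same round trip, using hypothetical topic data and a local SparkContext (none of these values come from the commit):

```python
from pyspark import SparkContext
from pyspark.sql import SQLContext

sc = SparkContext("local[2]", "describeTopics-sketch")
sqlContext = SQLContext(sc)

# Hypothetical output of LDAModel.describeTopics: one (terms, termWeights)
# pair per topic, with term indices and their weights kept together.
topics = [([0, 1, 2], [0.5, 0.3, 0.2]),
          ([2, 0, 1], [0.6, 0.25, 0.15])]

# Mirror of the Scala side's rdd.toDF("terms", "termWeights"): name the two
# columns and let DataFrame serialization carry the nested arrays to Python.
df = sqlContext.createDataFrame(topics, ["terms", "termWeights"])
for row in df.collect():
    print(row["terms"], row["termWeights"])

sc.stop()
```

Keeping terms and weights in separate named columns avoids the old split-a-flat-array-in-half decoding on the Python side.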

python/pyspark/mllib/clustering.py

Lines changed: 8 additions & 10 deletions

```diff
@@ -741,20 +741,18 @@ def describeTopics(self, maxTermsPerTopic=None):
         WARNING: If vocabSize and k are large, this can return a large object!
         """
         if maxTermsPerTopic is None:
-            topics = self.call("describeTopics")
+            df = self.call("describeTopics", self._sc)
         else:
-            topics = self.call("describeTopics", maxTermsPerTopic)
+            df = self.call("describeTopics", maxTermsPerTopic, self._sc)
 
-        # Converts the result to make the format similar to Scala.
-        # The returned value is mixed up with topics and topic weights.
         converted = []
-        for elms in [list(elms) for elms in topics]:
-            half_len = int(len(elms) / 2)
-            topics = elms[:half_len]
-            topicWeights = elms[(-1 * half_len):]
-            if len(topics) != len(topicWeights):
+        rows = df.collect()
+        for row in rows:
+            terms = row["terms"]
+            termWeights = row["termWeights"]
+            if len(terms) != len(termWeights):
                 raise TypeError("Something wrong with a return value: %s" % (topics))
-            converted.append((topics, topicWeights))
+            converted.append((terms, termWeights))
         return converted
 
     @classmethod
```
