
Commit 0bc114e

Revert "Modify describeTopics to take advantage of DataFrame serialization"
This reverts commit 6e3cf05.
1 parent 89cbd77 commit 0bc114e

File tree

2 files changed: 18 additions (+), 19 deletions (−)


mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala

Lines changed: 8 additions & 11 deletions
@@ -16,11 +16,11 @@
  */
 package org.apache.spark.mllib.api.python
 
+import scala.collection.JavaConverters
+
 import org.apache.spark.SparkContext
-import org.apache.spark.api.java.JavaSparkContext
 import org.apache.spark.mllib.clustering.LDAModel
 import org.apache.spark.mllib.linalg.Matrix
-import org.apache.spark.sql.{DataFrame, SQLContext}
 
 /**
  * Wrapper around LDAModel to provide helper methods in Python
@@ -31,17 +31,14 @@ private[python] class LDAModelWrapper(model: LDAModel) {
 
   def vocabSize(): Int = model.vocabSize
 
-  def describeTopics(jsc: JavaSparkContext): DataFrame = describeTopics(this.model.vocabSize, jsc)
+  def describeTopics(): java.util.List[Array[Any]] = describeTopics(this.model.vocabSize)
 
-  def describeTopics(maxTermsPerTopic: Int, jsc: JavaSparkContext): DataFrame = {
-    val sqlContext = new SQLContext(jsc.sc)
-    import sqlContext.implicits._
+  def describeTopics(maxTermsPerTopic: Int): java.util.List[Array[Any]] = {
 
-    // Since the return value of `describeTopics` is a little complicated,
-    // the return value is converted to `Row` to take advantage of DataFrame serialization.
-    val topics = model.describeTopics(maxTermsPerTopic)
-    val rdd = jsc.sc.parallelize(topics)
-    rdd.toDF("terms", "termWeights")
+    val seq = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) =>
+      Array.empty[Any] ++ terms ++ termWeights
+    }.toSeq
+    JavaConverters.seqAsJavaListConverter(seq).asJava
   }
 
   def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
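For context: after this revert, the Scala wrapper serializes each (terms, termWeights) tuple by flattening it into a single Array[Any] — term indices first, weights second — so the result can cross the Py4J boundary as a plain java.util.List instead of going through a SQLContext and DataFrame. A minimal Python sketch of that flattened layout (illustrative only; the actual flattening happens in the Scala code above):

# Hypothetical illustration of the per-topic wire format produced by the
# Scala wrapper: one flat list per topic, term indices first, weights second.
terms = [1, 5, 2]              # top term indices for one topic
termWeights = [0.4, 0.3, 0.1]  # the matching weights
flattened = list(terms) + list(termWeights)
assert flattened == [1, 5, 2, 0.4, 0.3, 0.1]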

python/pyspark/mllib/clustering.py

Lines changed: 10 additions & 8 deletions
@@ -741,18 +741,20 @@ def describeTopics(self, maxTermsPerTopic=None):
         WARNING: If vocabSize and k are large, this can return a large object!
         """
         if maxTermsPerTopic is None:
-            df = self.call("describeTopics", self._sc)
+            topics = self.call("describeTopics")
         else:
-            df = self.call("describeTopics", maxTermsPerTopic, self._sc)
+            topics = self.call("describeTopics", maxTermsPerTopic)
 
+        # Converts the result to make the format similar to Scala.
+        # The returned value is mixed up with topics and topic weights.
         converted = []
-        rows = df.collect()
-        for row in df.collect():
-            terms = row["terms"]
-            termWeights = row["termWeights"]
-            if len(terms) != len(termWeights):
+        for elms in [list(elms) for elms in topics]:
+            half_len = int(len(elms) / 2)
+            topics = elms[:half_len]
+            topicWeights = elms[(-1 * half_len):]
+            if len(topics) != len(topicWeights):
                 raise TypeError("Something wrong with a return value: %s" % (topics))
-            converted.append((terms, termWeights))
+            converted.append((topics, topicWeights))
         return converted
 
     @classmethod
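On the Python side, the revert restores the mirror-image conversion: each flat list coming back over Py4J is split at its midpoint to recover the (topics, topicWeights) pair. A small self-contained sketch of that split, assuming — as the Scala wrapper guarantees — that every list holds equally many term indices and weights:

def split_topic(elms):
    # The flattened list is [term_1, ..., term_n, weight_1, ..., weight_n],
    # so the midpoint separates term indices from weights.
    half_len = len(elms) // 2
    terms, weights = elms[:half_len], elms[half_len:]
    if len(terms) != len(weights):  # can only happen for odd-length input
        raise TypeError("Something wrong with a return value: %s" % elms)
    return terms, weights

print(split_topic([1, 5, 2, 0.4, 0.3, 0.1]))  # ([1, 5, 2], [0.4, 0.3, 0.1])

Note that for even-length input, the slice elms[(-1 * half_len):] used in clustering.py and elms[half_len:] above select the same elements. The trade made by this revert is a simpler dependency surface — no SQLContext is needed inside the wrapper — at the cost of this manual flatten-and-split step.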
