Commit 7237c36
[SPARK-8467][MLlib][PySpark] Add LDAModel.describeTopics() in Python
1 parent 74ba952

3 files changed: +98 -20 lines changed
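
A minimal usage sketch of the new API from the PySpark side (hypothetical
standalone script; assumes a local Spark installation, and the corpus and
app name are made up for illustration):

    from pyspark import SparkContext
    from pyspark.mllib.clustering import LDA
    from pyspark.mllib.linalg import Vectors

    sc = SparkContext(appName="describeTopicsSketch")

    # Tiny corpus: (document id, term-count vector) pairs over a
    # two-term vocabulary.
    corpus = sc.parallelize([
        [1, Vectors.dense([0.0, 1.0])],
        [2, Vectors.dense([1.0, 0.0])],
    ])
    model = LDA.train(corpus, k=2)

    # Each topic comes back as a (terms, termWeights) pair, with term
    # indices sorted by descending weight.
    for terms, weights in model.describeTopics(maxTermsPerTopic=2):
        print(terms, weights)

    sc.stop()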
mllib/src/main/scala/org/apache/spark/mllib/api/python/LDAModelWrapper.scala (new file)

Lines changed: 46 additions & 0 deletions

@@ -0,0 +1,46 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.api.python

import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.LDAModel
import org.apache.spark.mllib.linalg.Matrix

/**
 * Wrapper around LDAModel to provide helper methods in Python
 */
private[python] class LDAModelWrapper(model: LDAModel) {

  def topicsMatrix(): Matrix = model.topicsMatrix

  def vocabSize(): Int = model.vocabSize

  def describeTopics(): java.util.List[Array[Any]] = describeTopics(this.model.vocabSize)

  def describeTopics(maxTermsPerTopic: Int): java.util.List[Array[Any]] = {
    import scala.collection.JavaConversions._

    // Flatten each (terms, termWeights) pair into a single Array[Any];
    // the Python side splits each array back in half.
    val javaList: java.util.List[Array[Any]] =
      model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) =>
        Array.empty[Any] ++ terms ++ termWeights
      }.toList
    javaList
  }

  def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
}
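
The wrapper flattens each (terms, termWeights) pair into one Array[Any],
presumably because a flat Java array crosses the Py4J gateway more readily
than a Scala tuple; the Python side then splits each array back in half. A
toy sketch of that decoding (pure Python, no JVM; the numbers are invented):

    flat = [3, 1, 0, 0.5, 0.3, 0.2]   # term indices followed by weights
    half = len(flat) // 2
    terms, weights = flat[:half], flat[half:]
    assert len(terms) == len(weights)  # two aligned lists of length 3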

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 11 additions & 2 deletions

@@ -517,7 +517,7 @@ private[python] class PythonMLLibAPI extends Serializable {
       topicConcentration: Double,
       seed: java.lang.Long,
       checkpointInterval: Int,
-      optimizer: String): LDAModel = {
+      optimizer: String): LDAModelWrapper = {
     val algo = new LDA()
       .setK(k)
       .setMaxIterations(maxIterations)
@@ -535,7 +535,16 @@ private[python] class PythonMLLibAPI extends Serializable {
         case _ => throw new IllegalArgumentException("input values contains invalid type value.")
       }
     }
-    algo.run(documents)
+    val model = algo.run(documents)
+    new LDAModelWrapper(model)
+  }
+
+  /**
+   * Load an LDA model.
+   */
+  def loadLDAModel(jsc: JavaSparkContext, path: String): LDAModelWrapper = {
+    val model = DistributedLDAModel.load(jsc.sc, path)
+    new LDAModelWrapper(model)
   }
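
With loadLDAModel routed through the Python API, save and load now
round-trip from PySpark. A hedged sketch, reusing sc and model from the
example above (the temporary path is arbitrary):

    import shutil
    import tempfile

    from pyspark.mllib.clustering import LDAModel

    path = tempfile.mkdtemp()
    model.save(sc, path)                  # delegates to LDAModelWrapper.save
    same_model = LDAModel.load(sc, path)  # calls loadLDAModel on the JVM side
    assert same_model.vocabSize() == model.vocabSize()
    shutil.rmtree(path)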

python/pyspark/mllib/clustering.py

Lines changed: 41 additions & 18 deletions

@@ -667,7 +667,7 @@ def predictOnValues(self, dstream):
         return dstream.mapValues(lambda x: self._model.predict(x))
 
 
-class LDAModel(JavaModelWrapper):
+class LDAModel(JavaModelWrapper, JavaSaveable, Loader):
 
     """ A clustering model derived from the LDA method.
 
@@ -690,6 +690,21 @@ class LDAModel(JavaModelWrapper):
     >>> model = LDA.train(rdd, k=2)
     >>> model.vocabSize()
     2
+    >>> topics = model.describeTopics()
+    >>> len(topics)
+    2
+    >>> len(list(topics[0])[0])
+    2
+    >>> len(list(topics[0])[1])
+    2
+    >>> topics = model.describeTopics(1)
+    >>> len(topics)
+    2
+    >>> len(list(topics[0])[0])
+    1
+    >>> len(list(topics[0])[1])
+    1
+
     >>> topics = model.topicsMatrix()
     >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]])
     >>> assert_almost_equal(topics, topics_expect, 1)
@@ -720,18 +735,27 @@ def vocabSize(self):
         """Vocabulary size (number of terms or terms in the vocabulary)"""
         return self.call("vocabSize")
 
-    @since('1.5.0')
-    def save(self, sc, path):
-        """Save the LDAModel on to disk.
+    def describeTopics(self, maxTermsPerTopic=None):
+        """Return the topics described by weighted terms.
 
-        :param sc: SparkContext
-        :param path: str, path to where the model needs to be stored.
+        WARNING: If vocabSize and k are large, this can return a large object!
         """
-        if not isinstance(sc, SparkContext):
-            raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
-        if not isinstance(path, basestring):
-            raise TypeError("path should be a basestring, got type %s" % type(path))
-        self._java_model.save(sc._jsc.sc(), path)
+        if maxTermsPerTopic is None:
+            topics = self.call("describeTopics")
+        else:
+            topics = self.call("describeTopics", maxTermsPerTopic)
+
+        # Convert the result to match the Scala format: each element is a
+        # flat array holding the terms followed by the topic weights.
+        converted = []
+        for elms in [list(elms) for elms in topics]:
+            half_len = int(len(elms) / 2)
+            topics = elms[:half_len]
+            topicWeights = elms[(-1 * half_len):]
+            if len(topics) != len(topicWeights):
+                raise TypeError("Something is wrong with the return value: %s" % topics)
+            converted.append((topics, topicWeights))
+        return converted
 
     @classmethod
     @since('1.5.0')
@@ -745,9 +769,8 @@ def load(cls, sc, path):
             raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
         if not isinstance(path, basestring):
             raise TypeError("path should be a basestring, got type %s" % type(path))
-        java_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load(
-            sc._jsc.sc(), path)
-        return cls(java_model)
+        wrapper_model = callMLlibFunc("loadLDAModel", sc, path)
+        return LDAModel(wrapper_model)
 
 
 class LDA(object):
@@ -773,10 +796,10 @@ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0,
         :param optimizer: LDAOptimizer used to perform the actual calculation.
             Currently "em", "online" are supported. Default to "em".
         """
-        model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations,
-                              docConcentration, topicConcentration, seed,
-                              checkpointInterval, optimizer)
-        return LDAModel(model)
+        wrapper_model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations,
+                                      docConcentration, topicConcentration, seed,
+                                      checkpointInterval, optimizer)
+        return LDAModel(wrapper_model)
 
 
 def _test():
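
For intuition, on the two-document doctest corpus above the converted value
has this shape (the weights are invented; only the structure is meaningful):

    # Shape of model.describeTopics() for the k=2, two-term corpus.
    topics = [
        ([1, 0], [0.51, 0.49]),   # topic 0: term indices, then weights
        ([0, 1], [0.51, 0.49]),   # topic 1
    ]
    terms, weights = topics[0]
    assert len(terms) == len(weights) == 2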
