Skip to content

Commit 2ff0e79

Browse files
yu-iskwdavies
authored andcommitted
[SPARK-8467] [MLLIB] [PYSPARK] Add LDAModel.describeTopics() in Python
Could jkbradley and davies review it? - Create a wrapper class: `LDAModelWrapper` for `LDAModel`. Because we can't deal with the return value of`describeTopics` in Scala from pyspark directly. `Array[(Array[Int], Array[Double])]` is too complicated to convert it. - Add `loadLDAModel` in `PythonMLlibAPI`. Since `LDAModel` in Scala is an abstract class and we need to call `load` of `DistributedLDAModel`. [[SPARK-8467] Add LDAModel.describeTopics() in Python - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-8467) Author: Yu ISHIKAWA <[email protected]> Closes #8643 from yu-iskw/SPARK-8467-2.
1 parent 7f74190 commit 2ff0e79

File tree

3 files changed

+75
-17
lines changed

3 files changed

+75
-17
lines changed
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
package org.apache.spark.mllib.api.python
18+
19+
import scala.collection.JavaConverters
20+
21+
import org.apache.spark.SparkContext
22+
import org.apache.spark.mllib.clustering.LDAModel
23+
import org.apache.spark.mllib.linalg.Matrix
24+
25+
/**
26+
* Wrapper around LDAModel to provide helper methods in Python
27+
*/
28+
private[python] class LDAModelWrapper(model: LDAModel) {
29+
30+
def topicsMatrix(): Matrix = model.topicsMatrix
31+
32+
def vocabSize(): Int = model.vocabSize
33+
34+
def describeTopics(): Array[Byte] = describeTopics(this.model.vocabSize)
35+
36+
def describeTopics(maxTermsPerTopic: Int): Array[Byte] = {
37+
val topics = model.describeTopics(maxTermsPerTopic).map { case (terms, termWeights) =>
38+
val jTerms = JavaConverters.seqAsJavaListConverter(terms).asJava
39+
val jTermWeights = JavaConverters.seqAsJavaListConverter(termWeights).asJava
40+
Array[Any](jTerms, jTermWeights)
41+
}
42+
SerDe.dumps(JavaConverters.seqAsJavaListConverter(topics).asJava)
43+
}
44+
45+
def save(sc: SparkContext, path: String): Unit = model.save(sc, path)
46+
}

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ private[python] class PythonMLLibAPI extends Serializable {
517517
topicConcentration: Double,
518518
seed: java.lang.Long,
519519
checkpointInterval: Int,
520-
optimizer: String): LDAModel = {
520+
optimizer: String): LDAModelWrapper = {
521521
val algo = new LDA()
522522
.setK(k)
523523
.setMaxIterations(maxIterations)
@@ -535,7 +535,16 @@ private[python] class PythonMLLibAPI extends Serializable {
535535
case _ => throw new IllegalArgumentException("input values contains invalid type value.")
536536
}
537537
}
538-
algo.run(documents)
538+
val model = algo.run(documents)
539+
new LDAModelWrapper(model)
540+
}
541+
542+
/**
543+
* Load a LDA model
544+
*/
545+
def loadLDAModel(jsc: JavaSparkContext, path: String): LDAModelWrapper = {
546+
val model = DistributedLDAModel.load(jsc.sc, path)
547+
new LDAModelWrapper(model)
539548
}
540549

541550

python/pyspark/mllib/clustering.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ def predictOnValues(self, dstream):
671671
return dstream.mapValues(lambda x: self._model.predict(x))
672672

673673

674-
class LDAModel(JavaModelWrapper):
674+
class LDAModel(JavaModelWrapper, JavaSaveable, Loader):
675675

676676
""" A clustering model derived from the LDA method.
677677
@@ -691,9 +691,14 @@ class LDAModel(JavaModelWrapper):
691691
... [2, SparseVector(2, {0: 1.0})],
692692
... ]
693693
>>> rdd = sc.parallelize(data)
694-
>>> model = LDA.train(rdd, k=2)
694+
>>> model = LDA.train(rdd, k=2, seed=1)
695695
>>> model.vocabSize()
696696
2
697+
>>> model.describeTopics()
698+
[([1, 0], [0.5..., 0.49...]), ([0, 1], [0.5..., 0.49...])]
699+
>>> model.describeTopics(1)
700+
[([1], [0.5...]), ([0], [0.5...])]
701+
697702
>>> topics = model.topicsMatrix()
698703
>>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]])
699704
>>> assert_almost_equal(topics, topics_expect, 1)
@@ -724,18 +729,17 @@ def vocabSize(self):
724729
"""Vocabulary size (number of terms or terms in the vocabulary)"""
725730
return self.call("vocabSize")
726731

727-
@since('1.5.0')
728-
def save(self, sc, path):
729-
"""Save the LDAModel on to disk.
732+
@since('1.6.0')
733+
def describeTopics(self, maxTermsPerTopic=None):
734+
"""Return the topics described by weighted terms.
730735
731-
:param sc: SparkContext
732-
:param path: str, path to where the model needs to be stored.
736+
WARNING: If vocabSize and k are large, this can return a large object!
733737
"""
734-
if not isinstance(sc, SparkContext):
735-
raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
736-
if not isinstance(path, basestring):
737-
raise TypeError("path should be a basestring, got type %s" % type(path))
738-
self._java_model.save(sc._jsc.sc(), path)
738+
if maxTermsPerTopic is None:
739+
topics = self.call("describeTopics")
740+
else:
741+
topics = self.call("describeTopics", maxTermsPerTopic)
742+
return topics
739743

740744
@classmethod
741745
@since('1.5.0')
@@ -749,9 +753,8 @@ def load(cls, sc, path):
749753
raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
750754
if not isinstance(path, basestring):
751755
raise TypeError("path should be a basestring, got type %s" % type(path))
752-
java_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load(
753-
sc._jsc.sc(), path)
754-
return cls(java_model)
756+
model = callMLlibFunc("loadLDAModel", sc, path)
757+
return LDAModel(model)
755758

756759

757760
class LDA(object):

0 commit comments

Comments
 (0)