
Commit 4692769

yu-iskw authored and jkbradley committed
[SPARK-6259] [MLLIB] Python API for LDA
I implemented the Python API for LDA. I didn't implement a method for `LDAModel.describeTopics()`, because it's a little hard to implement now; documentation and an example for it would fit better in another issue.

TODO: `LDAModel.describeTopics()` must also be implemented in Python, but it would be better handled in another issue. Implementing it is a little hard, since the return value of `describeTopics` in Scala consists of Tuple classes.

Author: Yu ISHIKAWA <[email protected]>

Closes apache#6791 from yu-iskw/SPARK-6259 and squashes the following commits:

6855f59 [Yu ISHIKAWA] LDA inherits object
28bd165 [Yu ISHIKAWA] Change the place of testing code
d7a332a [Yu ISHIKAWA] Remove the doc comment about the optimizer's default value
083e226 [Yu ISHIKAWA] Add the comment about the supported values and the default value of `optimizer`
9f8bed8 [Yu ISHIKAWA] Simplify casting
faa9764 [Yu ISHIKAWA] Add some comments for the LDA parameters
98f645a [Yu ISHIKAWA] Remove the interface for `describeTopics` because it is not implemented
57ac03d [Yu ISHIKAWA] Remove the unnecessary import in Python unit testing
73412c3 [Yu ISHIKAWA] Fix the typo
2278829 [Yu ISHIKAWA] Fix the indentation
39514ec [Yu ISHIKAWA] Modify how to cast the input data
8117e18 [Yu ISHIKAWA] Fix the validation problems by `lint-scala`
77fd1b7 [Yu ISHIKAWA] Not use LabeledPoint
68f0653 [Yu ISHIKAWA] Support some parameters for `ALS.train()` in Python
25ef2ac [Yu ISHIKAWA] Resolve conflicts with rebasing
1 parent c6b1a9e commit 4692769
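
For reference, a minimal usage sketch of the Python API this commit adds, adapted from the doctest included in the patch (it assumes an active SparkContext `sc`; the two-document corpus and k=2 are illustrative):

    from pyspark.mllib.linalg import Vectors, SparseVector
    from pyspark.mllib.clustering import LDA

    # Each document is a two-element list: [document id, term-count vector].
    corpus = sc.parallelize([
        [1, Vectors.dense([0.0, 1.0])],
        [2, SparseVector(2, {0: 1.0})],
    ])
    model = LDA.train(corpus, k=2)
    print(model.vocabSize())     # 2 terms in the vocabulary
    print(model.topicsMatrix())  # vocabSize x k array of term weights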

2 files changed: +98 -1 lines changed

mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala

Lines changed: 33 additions & 0 deletions
@@ -502,6 +502,39 @@ private[python] class PythonMLLibAPI extends Serializable {
     new MatrixFactorizationModelWrapper(model)
   }
 
+  /**
+   * Java stub for Python mllib LDA.run()
+   */
+  def trainLDAModel(
+      data: JavaRDD[java.util.List[Any]],
+      k: Int,
+      maxIterations: Int,
+      docConcentration: Double,
+      topicConcentration: Double,
+      seed: java.lang.Long,
+      checkpointInterval: Int,
+      optimizer: String): LDAModel = {
+    val algo = new LDA()
+      .setK(k)
+      .setMaxIterations(maxIterations)
+      .setDocConcentration(docConcentration)
+      .setTopicConcentration(topicConcentration)
+      .setCheckpointInterval(checkpointInterval)
+      .setOptimizer(optimizer)
+
+    if (seed != null) algo.setSeed(seed)
+
+    val documents = data.rdd.map(_.asScala.toArray).map { r =>
+      r(0) match {
+        case i: java.lang.Integer => (i.toLong, r(1).asInstanceOf[Vector])
+        case i: java.lang.Long => (i.toLong, r(1).asInstanceOf[Vector])
+        case _ => throw new IllegalArgumentException("input values contains invalid type value.")
+      }
+    }
+    algo.run(documents)
+  }
+
+
   /**
    * Java stub for Python mllib FPGrowth.train(). This stub returns a handle
    * to the Java object instead of the content of the Java object. Extra care
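
A note on the pattern match in trainLDAModel above: each row arrives from Python as a java.util.List[Any], and Py4J maps Python integers to java.lang.Integer or java.lang.Long depending on their type and magnitude, so the stub accepts both and normalizes the document id to Long. A hedged Python-side illustration of the expected row shape (names are illustrative; assumes an active SparkContext `sc`):

    from pyspark.mllib.linalg import Vectors, SparseVector

    rows = [
        [0, Vectors.dense([1.0, 0.0, 2.0])],   # small id: crosses Py4J as java.lang.Integer
        [2 ** 40, SparseVector(3, {1: 3.0})],  # id beyond 32 bits: crosses as java.lang.Long
    ]
    rdd = sc.parallelize(rows)  # valid input for LDA.train(), which calls this stub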

python/pyspark/mllib/clustering.py

Lines changed: 65 additions & 1 deletion
@@ -31,13 +31,15 @@
 from pyspark.rdd import RDD, ignore_unicode_prefix
 from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py
 from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector
+from pyspark.mllib.regression import LabeledPoint
 from pyspark.mllib.stat.distribution import MultivariateGaussian
 from pyspark.mllib.util import Saveable, Loader, inherit_doc, JavaLoader, JavaSaveable
 from pyspark.streaming import DStream
 
 __all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture',
            'PowerIterationClusteringModel', 'PowerIterationClustering',
-           'StreamingKMeans', 'StreamingKMeansModel']
+           'StreamingKMeans', 'StreamingKMeansModel',
+           'LDA', 'LDAModel']
 
 
 @inherit_doc
@@ -563,6 +565,68 @@ def predictOnValues(self, dstream):
     return dstream.mapValues(lambda x: self._model.predict(x))
 
 
+class LDAModel(JavaModelWrapper):
+
+    """ A clustering model derived from the LDA method.
+
+    Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
+    Terminology
+    - "word" = "term": an element of the vocabulary
+    - "token": instance of a term appearing in a document
+    - "topic": multinomial distribution over words representing some concept
+    References:
+    - Original LDA paper (journal version):
+      Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
+
+    >>> from pyspark.mllib.linalg import Vectors
+    >>> from numpy.testing import assert_almost_equal
+    >>> data = [
+    ...     [1, Vectors.dense([0.0, 1.0])],
+    ...     [2, SparseVector(2, {0: 1.0})],
+    ... ]
+    >>> rdd = sc.parallelize(data)
+    >>> model = LDA.train(rdd, k=2)
+    >>> model.vocabSize()
+    2
+    >>> topics = model.topicsMatrix()
+    >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]])
+    >>> assert_almost_equal(topics, topics_expect, 1)
+    """
+
+    def topicsMatrix(self):
+        """Inferred topics, where each topic is represented by a distribution over terms."""
+        return self.call("topicsMatrix").toArray()
+
+    def vocabSize(self):
+        """Vocabulary size (number of terms or terms in the vocabulary)"""
+        return self.call("vocabSize")
+
+
+class LDA(object):
+
+    @classmethod
+    def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0,
+              topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"):
+        """Train a LDA model.
+
+        :param rdd:                 RDD of data points
+        :param k:                   Number of clusters you want
+        :param maxIterations:       Number of iterations. Default to 20
+        :param docConcentration:    Concentration parameter (commonly named "alpha")
+            for the prior placed on documents' distributions over topics ("theta").
+        :param topicConcentration:  Concentration parameter (commonly named "beta" or "eta")
+            for the prior placed on topics' distributions over terms.
+        :param seed:                Random Seed
+        :param checkpointInterval:  Period (in iterations) between checkpoints.
+        :param optimizer:           LDAOptimizer used to perform the actual calculation.
+            Currently "em", "online" are supported. Default to "em".
+        """
+        model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations,
+                              docConcentration, topicConcentration, seed,
+                              checkpointInterval, optimizer)
+        return LDAModel(model)
+
+
 def _test():
     import doctest
     import pyspark.mllib.clustering
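
Since `describeTopics()` is deferred to a follow-up issue, a rough equivalent can be computed on the Python side from `topicsMatrix()`. The helper below is a hypothetical sketch, not part of this commit: it ranks each topic's terms by weight with numpy, loosely mirroring the (term indices, term weights) tuples the Scala `describeTopics` returns.

    import numpy as np

    def describe_topics_approx(model, maxTermsPerTopic=5):
        # Hypothetical stand-in for the unimplemented LDAModel.describeTopics().
        topics = model.topicsMatrix()      # vocabSize x k array of term weights
        described = []
        for j in range(topics.shape[1]):   # one column per topic
            weights = topics[:, j]
            top = np.argsort(weights)[::-1][:maxTermsPerTopic]
            described.append((top.tolist(), (weights[top] / weights.sum()).tolist()))
        return described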
