diff --git a/mllib/pom.xml b/mllib/pom.xml
index f0928e1268e43..b556df7a38f02 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -50,6 +50,11 @@
spark-sql_${scala.binary.version}
${project.version}
+
+ org.apache.spark
+ spark-graphx_${scala.binary.version}
+ ${project.version}
+
org.eclipse.jetty
jetty-server
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
new file mode 100644
index 0000000000000..b74dc5f67d291
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -0,0 +1,750 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import java.util.Random
+
+import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, sum => brzSum}
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.impl.GraphImpl
+import org.apache.spark.{HashPartitioner, Logging, Partitioner}
+import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix}
+import org.apache.spark.mllib.linalg.{DenseVector => SDV, SparseVector => SSV, Vector => SV}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.KryoRegistrator
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.SparkContext._
+import org.apache.spark.util.random.XORShiftRandom
+
+import LDA._
+import LDAUtils._
+
+import scala.collection.mutable.ArrayBuffer
+
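+/**
+ * Latent Dirichlet Allocation (LDA) trained with a combined Gibbs / Metropolis-Hastings
+ * sampler on a GraphX bipartite term-document graph. Term vertices have non-negative IDs,
+ * document vertices have negative IDs, and each edge carries the topic assignments of the
+ * tokens it represents.
+ *
+ * A minimal usage sketch (the corpus RDD `docs` of (docId, term-count vector) pairs below
+ * is hypothetical and not part of this file):
+ * {{{
+ *   // docs: RDD[(Long, SparseVector)], one term-count vector per document
+ *   val model = LDA.train(docs, numTopics = 100, totalIter = 150,
+ *     alpha = 0.01, beta = 0.01, alphaAS = 0.1)
+ *   val topicDist = model.inference(newDoc) // topic distribution of an unseen document
+ * }}}
+ */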
+class LDA private[mllib](
+ @transient private var corpus: Graph[VD, ED],
+ private val numTopics: Int,
+ private val numTerms: Int,
+ private var alpha: Double,
+ private var beta: Double,
+ private var alphaAS: Double,
+ private var storageLevel: StorageLevel)
+ extends Serializable with Logging {
+
+ def this(docs: RDD[(DocId, SSV)],
+ numTopics: Int,
+ alpha: Double,
+ beta: Double,
+ alphaAS: Double,
+ storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK,
+ computedModel: Broadcast[LDAModel] = null) {
+ this(initializeCorpus(docs, numTopics, storageLevel, computedModel),
+ numTopics, docs.first()._2.size, alpha, beta, alphaAS, storageLevel)
+ }
+
+ // scalastyle:off
+ /**
+ * The number of documents in the corpus.
+ */
+ val numDocs = docVertices.count()
+
+ /**
+ * The total number of tokens in the corpus (counting duplicates).
+ */
+ val numTokens = corpus.edges.map(e => e.attr.size.toDouble).sum().toLong
+
+ def setAlpha(alpha: Double): this.type = {
+ this.alpha = alpha
+ this
+ }
+
+ def setBeta(beta: Double): this.type = {
+ this.beta = beta
+ this
+ }
+
+ def setAlphaAS(alphaAS: Double): this.type = {
+ this.alphaAS = alphaAS
+ this
+ }
+
+ def setStorageLevel(newStorageLevel: StorageLevel): this.type = {
+ this.storageLevel = newStorageLevel
+ this
+ }
+
+ def setSeed(newSeed: Int): this.type = {
+ this.seed = newSeed
+ this
+ }
+
+ def getCorpus = corpus
+
+ // scalastyle:on
+
+ @transient private var seed = new Random().nextInt()
+ @transient private var innerIter = 1
+ @transient private var totalTopicCounter: BDV[Count] = collectTotalTopicCounter(corpus)
+
+ private def termVertices = corpus.vertices.filter(t => t._1 >= 0)
+
+ private def docVertices = corpus.vertices.filter(t => t._1 < 0)
+
+ private def checkpoint(corpus: Graph[VD, ED]): Unit = {
+ if (innerIter % 10 == 0 && corpus.edges.sparkContext.getCheckpointDir.isDefined) {
+ corpus.checkpoint()
+ }
+ }
+
+ private def collectTotalTopicCounter(graph: Graph[VD, ED]): BDV[Count] = {
+ val globalTopicCounter = collectGlobalCounter(graph, numTopics)
+ assert(brzSum(globalTopicCounter) == numTokens)
+ globalTopicCounter
+ }
+
+ private def gibbsSampling(): Unit = {
+ val sampledCorpus = sampleTokens(corpus, totalTopicCounter, innerIter + seed,
+ numTokens, numTopics, numTerms, alpha, alphaAS, beta)
+ sampledCorpus.persist(storageLevel)
+
+ val counterCorpus = updateCounter(sampledCorpus, numTopics)
+ checkpoint(counterCorpus)
+ counterCorpus.persist(storageLevel)
+ // counterCorpus.vertices.count()
+ counterCorpus.edges.count()
+ totalTopicCounter = collectTotalTopicCounter(counterCorpus)
+
+ corpus.edges.unpersist(false)
+ corpus.vertices.unpersist(false)
+ sampledCorpus.edges.unpersist(false)
+ sampledCorpus.vertices.unpersist(false)
+ corpus = counterCorpus
+ innerIter += 1
+ }
+
+ def saveModel(iter: Int = 1): LDAModel = {
+ var termTopicCounter: RDD[(VertexId, VD)] = null
+ for (i <- 1 to iter) {
+ logInfo(s"Save TopicModel (Iteration $i/$iter)")
+ val previousTermTopicCounter = termTopicCounter
+ gibbsSampling()
+ val newTermTopicCounter = termVertices
+ termTopicCounter = Option(termTopicCounter).map(_.join(newTermTopicCounter).map {
+ case (term, (a, b)) =>
+ (term, a :+ b)
+ }).getOrElse(newTermTopicCounter)
+
+ termTopicCounter.persist(storageLevel).count()
+ Option(previousTermTopicCounter).foreach(_.unpersist())
+ }
+ val model = LDAModel(numTopics, numTerms, alpha, beta)
+ termTopicCounter.collect().foreach { case (term, counter) =>
+ model.merge(term.toInt, counter)
+ }
+ model.gtc :/= iter.toDouble
+ model.ttc.foreach { ttc =>
+ ttc :/= iter.toDouble
+ ttc.compact()
+ }
+ model
+ }
+
+ def runGibbsSampling(iterations: Int): Unit = {
+ for (iter <- 1 to iterations) {
+ // logInfo(s"Gibbs samplin perplexity $iter: ${perplexity}")
+ // logInfo(s"Gibbs sampling (Iteration $iter/$iterations)")
+ // val startedAt = System.nanoTime()
+ gibbsSampling()
+ // val endAt = System.nanoTime()
+ // val useTime = (endAt - startedAt) / 1e9
+ // logInfo(s"Gibbs sampling use time $iter: $useTime")
+ }
+ }
+
+ def mergeDuplicateTopic(threshold: Double = 0.95D): Map[Int, Int] = {
+ val rows = termVertices.map(t => t._2).map { bsv =>
+ val length = bsv.length
+ val used = bsv.activeSize
+ val index = bsv.index.slice(0, used)
+ val data = bsv.data.slice(0, used).map(_.toDouble)
+ new SSV(length, index, data).asInstanceOf[SV]
+ }
+ val simMatrix = new RowMatrix(rows).columnSimilarities()
+ val minMap = simMatrix.entries.filter { case MatrixEntry(row, column, sim) =>
+ sim > threshold && row != column
+ }.map { case MatrixEntry(row, column, sim) =>
+ (column.toInt, row.toInt)
+ }.groupByKey().map { case (topic, simTopics) =>
+ (topic, simTopics.min)
+ }.collect().toMap
+ if (minMap.nonEmpty) {
+ corpus = corpus.mapEdges(edges => {
+ edges.attr.map { topic =>
+ minMap.getOrElse(topic, topic)
+ }
+ })
+ corpus = updateCounter(corpus, numTopics)
+ }
+ minMap
+ }
+
+
+ // scalastyle:off
+ /**
+ * The likelihood of a word is the product of its distribution over all topics and the topic
+ * distribution of the document containing it: p(w)=\sum_{k}{p(k|d)*p(w|k)}=
+ * \sum_{k}{\frac{{n}_{kw}+{\beta }_{w}} {{n}_{k}+\bar{\beta }} \frac{{n}_{kd}+{\alpha }_{k}} {\sum{{n}_{k}}+\bar{\alpha }}}=
+ * \sum_{k} \frac{{\alpha }_{k}{\beta }_{w} + {n}_{kw}{\alpha }_{k} + {n}_{kd}{\beta }_{w} + {n}_{kw}{n}_{kd}}{{n}_{k}+\bar{\beta }} \frac{1}{\sum{{n}_{k}}+\bar{\alpha }}
+ * The perplexity is \exp^{-(\sum{\log(p(w))})/N},
+ * where N is the number of tokens in the corpus.
+ */
+ // scalastyle:on
+ def perplexity(): Double = {
+ val totalTopicCounter = this.totalTopicCounter
+ val numTopics = this.numTopics
+ val numTerms = this.numTerms
+ val alpha = this.alpha
+ val beta = this.beta
+ val totalSize = brzSum(totalTopicCounter)
+ var totalProb = 0D
+
+ // \frac{{\alpha }_{k}{\beta }_{w}}{{n}_{k}+\bar{\beta }}
+ totalTopicCounter.activeIterator.foreach { case (topic, cn) =>
+ totalProb += alpha * beta / (cn + numTerms * beta)
+ }
+
+ val termProb = corpus.mapVertices { (vid, counter) =>
+ val probDist = BSV.zeros[Double](numTopics)
+ if (vid >= 0) {
+ val termTopicCounter = counter
+ // \frac{{n}_{kw}{\alpha }_{k}}{{n}_{k}+\bar{\beta }}
+ termTopicCounter.activeIterator.foreach { case (topic, cn) =>
+ probDist(topic) = cn * alpha /
+ (totalTopicCounter(topic) + numTerms * beta)
+ }
+ } else {
+ val docTopicCounter = counter
+ // \frac{{n}_{kd}{\beta }_{w}}{{n}_{k}+\bar{\beta }}
+ docTopicCounter.activeIterator.foreach { case (topic, cn) =>
+ probDist(topic) = cn * beta /
+ (totalTopicCounter(topic) + numTerms * beta)
+ }
+ }
+ probDist.compact()
+ (counter, probDist)
+ }.mapTriplets { triplet =>
+ val (termTopicCounter, termProb) = triplet.srcAttr
+ val (docTopicCounter, docProb) = triplet.dstAttr
+ val docSize = brzSum(docTopicCounter)
+ val docTermSize = triplet.attr.length
+ var prob = 0D
+
+ // \frac{{n}_{kw}{n}_{kd}}{{n}_{k}+\bar{\beta}}
+ docTopicCounter.activeIterator.foreach { case (topic, cn) =>
+ prob += cn * termTopicCounter(topic) /
+ (totalTopicCounter(topic) + numTerms * beta)
+ }
+ prob += brzSum(docProb) + brzSum(termProb) + totalProb
+ prob += prob / (docSize + numTopics * alpha)
+
+ docTermSize * Math.log(prob)
+ }.edges.map(t => t.attr).sum()
+
+ math.exp(-1 * termProb / totalSize)
+ }
+}
+
+object LDA {
+
+ private[mllib] type DocId = VertexId
+ private[mllib] type WordId = VertexId
+ private[mllib] type Count = Int
+ private[mllib] type ED = Array[Count]
+ private[mllib] type VD = BSV[Count]
+
+ def train(docs: RDD[(DocId, SSV)],
+ numTopics: Int = 2048,
+ totalIter: Int = 150,
+ alpha: Double = 0.01,
+ beta: Double = 0.01,
+ alphaAS: Double = 0.1): LDAModel = {
+ require(totalIter > 0, "totalIter must be greater than 0")
+ val topicModeling = new LDA(docs, numTopics, alpha, beta, alphaAS)
+ topicModeling.runGibbsSampling(totalIter - 1)
+ topicModeling.saveModel(1)
+ }
+
+ /**
+ * Train an LDA model and save the topic-term counts to `dir` in LibSVM format,
+ * one line per topic: topicID termID+1:counter termID+1:counter ..
+ */
+ def trainAndSaveModel(
+ docs: RDD[(DocId, SSV)],
+ dir: String,
+ numTopics: Int = 2048,
+ totalIter: Int = 150,
+ alpha: Double = 0.01,
+ beta: Double = 0.01,
+ alphaAS: Double = 0.1): Unit = {
+ import org.apache.spark.mllib.regression.LabeledPoint
+ import org.apache.spark.mllib.util.MLUtils
+ import org.apache.spark.mllib.linalg.Vectors
+ val lda = new LDA(docs, numTopics, alpha, beta, alphaAS)
+ val numTerms = lda.numTerms
+ lda.runGibbsSampling(totalIter)
+ val rdd = lda.termVertices.flatMap { case (termId, counter) =>
+ counter.activeIterator.map { case (topic, cn) =>
+ val sv = BSV.zeros[Double](numTerms)
+ sv(termId.toInt) = cn.toDouble
+ (topic, sv)
+ }
+ }.reduceByKey { (a, b) => a + b }.map { case (topic, sv) =>
+ LabeledPoint(topic.toDouble, Vectors.fromBreeze(sv))
+ }
+ MLUtils.saveAsLibSVMFile(rdd, dir)
+ }
+
+ def incrementalTrain(docs: RDD[(DocId, SSV)],
+ computedModel: LDAModel,
+ alphaAS: Double = 1,
+ totalIter: Int = 150): LDAModel = {
+ require(totalIter > 0, "totalIter must be greater than 0")
+ val numTopics = computedModel.ttc.size
+ val alpha = computedModel.alpha
+ val beta = computedModel.beta
+
+ val broadcastModel = docs.context.broadcast(computedModel)
+ val topicModeling = new LDA(docs, numTopics, alpha, beta, alphaAS,
+ computedModel = broadcastModel)
+ broadcastModel.unpersist()
+ topicModeling.runGibbsSampling(totalIter - 1)
+ topicModeling.saveModel(1)
+ }
+
+ private[mllib] def sampleTokens(
+ graph: Graph[VD, ED],
+ totalTopicCounter: BDV[Count],
+ innerIter: Long,
+ numTokens: Double,
+ numTopics: Double,
+ numTerms: Double,
+ alpha: Double,
+ alphaAS: Double,
+ beta: Double): Graph[VD, ED] = {
+ val parts = graph.edges.partitions.size
+ val newGraph = graph.mapTriplets(
+ (pid, iter) => {
+ val gen = new XORShiftRandom(parts * innerIter + pid)
+ // The table is a per-term data structure.
+ // In GraphX, the edges within a partition are clustered by source ID (the term ID here),
+ // so the simple cache below avoids recomputing the table for every edge.
+ val lastTable = (new ArrayBuffer[Int](numTopics.toInt),
+ new ArrayBuffer[Int](numTopics.toInt),
+ new ArrayBuffer[Double](numTopics.toInt))
+ var lastVid = None.asInstanceOf[Option[VertexId]]
+ var lastWsum = 0.0
+ val dv = tDense(totalTopicCounter, numTokens, numTerms, alpha, alphaAS, beta)
+ val dData = new Array[Double](numTopics.toInt)
+ val t = generateAlias(dv._2, dv._1)
+ val tSum = dv._1
+ iter.map {
+ triplet =>
+ val termId = triplet.srcId
+ val docId = triplet.dstId
+ val termTopicCounter = triplet.srcAttr
+ val docTopicCounter = triplet.dstAttr
+ val topics = triplet.attr
+ for (i <- 0 until topics.length) {
+ val currentTopic = topics(i)
+ docTopicCounter.synchronized {
+ termTopicCounter.synchronized {
+ dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, dData,
+ currentTopic, numTokens, numTerms, alpha, alphaAS, beta)
+ if (lastVid != Some(termId) || gen.nextDouble() < 1e-4) {
+ lastWsum = wordTable(lastTable, totalTopicCounter, termTopicCounter,
+ termId, numTokens, numTerms, alpha, alphaAS, beta)
+ lastVid = Some(termId)
+ }
+ val newTopic = tokenSampling(gen, t, tSum, lastTable, termTopicCounter, lastWsum,
+ docTopicCounter, dData, currentTopic)
+
+ if (newTopic != currentTopic) {
+ topics(i) = newTopic
+ docTopicCounter(currentTopic) -= 1
+ docTopicCounter(newTopic) += 1
+ // if (docTopicCounter(currentTopic) == 0) docTopicCounter.compact()
+
+ termTopicCounter(currentTopic) -= 1
+ termTopicCounter(newTopic) += 1
+ // if (termTopicCounter(currentTopic) == 0) termTopicCounter.compact()
+
+ totalTopicCounter(currentTopic) -= 1
+ totalTopicCounter(newTopic) += 1
+ }
+ }
+ }
+ }
+
+ topics
+ }
+ }, TripletFields.All)
+ newGraph
+ }
+
+ private def updateCounter(graph: Graph[VD, ED], numTopics: Int): Graph[VD, ED] = {
+ val newCounter = graph.aggregateMessages[VD](ctx => {
+ val topics = ctx.attr
+ val vector = BSV.zeros[Count](numTopics)
+ for (topic <- topics) {
+ vector(topic) += 1
+ }
+ ctx.sendToDst(vector)
+ ctx.sendToSrc(vector)
+ }, _ + _, TripletFields.EdgeOnly).mapValues(v => {
+ val used = v.used
+ if (v.index.length == used) {
+ v
+ } else {
+ val index = new Array[Int](used)
+ val data = new Array[Count](used)
+ Array.copy(v.index, 0, index, 0, used)
+ Array.copy(v.data, 0, data, 0, used)
+ new BSV[Count](index, data, numTopics)
+ }
+ })
+ // GraphImpl.fromExistingRDDs(newCounter, graph.edges)
+ GraphImpl(newCounter, graph.edges)
+ }
+
+ private def collectGlobalCounter(graph: Graph[VD, ED], numTopics: Int): BDV[Count] = {
+ graph.vertices.filter(t => t._1 >= 0).map(_._2).
+ aggregate(BDV.zeros[Count](numTopics))((a, b) => {
+ a :+= b
+ }, _ :+= _)
+ }
+
+ private def initializeCorpus(
+ docs: RDD[(LDA.DocId, SSV)],
+ numTopics: Int,
+ storageLevel: StorageLevel,
+ computedModel: Broadcast[LDAModel] = null): Graph[VD, ED] = {
+ val edges = docs.mapPartitionsWithIndex((pid, iter) => {
+ val gen = new Random(pid)
+ var model: LDAModel = null
+ if (computedModel != null) model = computedModel.value
+ iter.flatMap {
+ case (docId, doc) =>
+ val bsv = new BSV[Int](doc.indices, doc.values.map(_.toInt), doc.size)
+ initializeEdges(gen, bsv, docId, numTopics, model)
+ }
+ })
+ edges.persist(storageLevel)
+ var corpus: Graph[VD, ED] = Graph.fromEdges(edges, null, storageLevel, storageLevel)
+ // degree-based hashing
+ val degrees = corpus.outerJoinVertices(corpus.degrees) { (vid, data, deg) => deg.getOrElse(0) }
+ val numPartitions = edges.partitions.size
+ val partitionStrategy = new DBHPartitioner(numPartitions)
+ val newEdges = degrees.triplets.map { e =>
+ (partitionStrategy.getPartition(e), Edge(e.srcId, e.dstId, e.attr))
+ }.partitionBy(new HashPartitioner(numPartitions)).map(_._2)
+ corpus = Graph.fromEdges(newEdges, null, storageLevel, storageLevel)
+ // end degree-based hashing
+ // corpus = corpus.partitionBy(PartitionStrategy.EdgePartition2D)
+ corpus = updateCounter(corpus, numTopics).cache()
+ corpus.vertices.count()
+ corpus.edges.count()
+ edges.unpersist()
+ corpus
+ }
+
+ private def initializeEdges(
+ gen: Random,
+ doc: BSV[Int],
+ docId: DocId,
+ numTopics: Int,
+ computedModel: LDAModel = null): Array[Edge[ED]] = {
+ assert(docId >= 0)
+ val newDocId: DocId = -(docId + 1L)
+ val edges = if (computedModel == null) {
+ doc.activeIterator.filter(_._2 > 0).map { case (termId, counter) =>
+ val topics = new Array[Int](counter)
+ for (i <- 0 until counter) {
+ topics(i) = gen.nextInt(numTopics)
+ }
+ Edge(termId, newDocId, topics)
+ }.toArray
+ }
+ else {
+ computedModel.setSeed(gen.nextInt())
+ val tokens = computedModel.vector2Array(doc)
+ val topics = new Array[Int](tokens.length)
+ var docTopicCounter = computedModel.uniformDistSampler(tokens, topics)
+ for (t <- 0 until 15) {
+ docTopicCounter = computedModel.sampleTokens(docTopicCounter,
+ tokens, topics)
+ }
+ doc.activeIterator.filter(_._2 > 0).map { case (term, counter) =>
+ val ev = topics.zipWithIndex.filter { case (topic, offset) =>
+ term == tokens(offset)
+ }.map(_._1)
+ Edge(term, newDocId, ev)
+ }.toArray
+ }
+ assert(edges.length > 0)
+ edges
+ }
+
+ // scalastyle:off
+ /**
+ * This combines a Gibbs sampler with a Metropolis-Hastings sampler,
+ * so that each sampling step costs O(1).
+ * The Gibbs sampler uses formula (3) of the paper "Rethinking LDA: Why Priors Matter":
+ * \frac{{n}_{kw}^{-di}+{\beta }_{w}}{{n}_{k}^{-di}+\bar{\beta}} \frac{{n}_{kd}^{-di}+ \bar{\alpha} \frac{{n}_{k}^{-di} + \acute{\alpha}}{\sum{n}_{k} +\bar{\acute{\alpha}}}}{\sum{n}_{kd}^{-di} +\bar{\alpha}}
+ * = t + w + d
+ * t is the global (corpus-wide) part:
+ * t = \frac{{\beta }_{w} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} ) } {({n}_{k}^{-di}+\bar{\beta}) ({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
+ * w is the word-specific part:
+ * w = \frac{ {n}_{kw}^{-di} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} )}{({n}_{k}^{-di}+\bar{\beta})({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
+ * d is the document-word product:
+ * d = \frac{{n}_{kd}^{-di}({\sum{n}_{k}^{-di} + \bar{\acute{\alpha}}})({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta})({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
+ * = \frac{{n}_{kd}^{-di}({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta})}
+ * where
+ * \bar{\beta}=\sum_{w}{\beta}_{w}
+ * \bar{\alpha}=\sum_{k}{\alpha}_{k}
+ * \bar{\acute{\alpha}}=\sum_{k}\acute{\alpha}_{k}
+ * {n}_{kd} is the number of tokens in document d assigned to topic k
+ * {n}_{kw} is the number of tokens of word w assigned to topic k
+ * {n}_{k} is the number of tokens in the corpus assigned to topic k
+ * the superscript -di means the count excludes the topic assignment of the current token
+ */
+ // scalastyle:on
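+ // Draws a new topic for the current token: the total mass is split into the three buckets
+ // d, w and t described above; a bucket is chosen with probability proportional to its mass,
+ // and a topic is then drawn inside it, d via binary search over the cumulative array dData,
+ // w and t via their alias tables.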
+ private def tokenSampling(
+ gen: Random,
+ t: Table,
+ tSum: Double,
+ w: Table,
+ termTopicCounter: VD,
+ wSum: Double,
+ docTopicCounter: VD,
+ dData: Array[Double],
+ currentTopic: Int): Int = {
+ val index = docTopicCounter.index
+ val used = docTopicCounter.used
+ val dSum = dData(docTopicCounter.used - 1)
+ val distSum = tSum + wSum + dSum
+ val genSum = gen.nextDouble() * distSum
+ if (genSum < dSum) {
+ val dGenSum = gen.nextDouble() * dSum
+ val pos = binarySearchInterval(dData, dGenSum, 0, used, true)
+ index(pos)
+ } else if (genSum < (dSum + wSum)) {
+ sampleSV(gen, w, termTopicCounter, currentTopic)
+ } else {
+ sampleAlias(gen, t)
+ }
+ }
+
+
+ /**
+ * The decomposed global term:
+ * t = \frac{{\beta }_{w} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} ) } {({n}_{k}^{-di}+\bar{\beta}) ({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
+ */
+ private def tDense(
+ totalTopicCounter: BDV[Count],
+ numTokens: Double,
+ numTerms: Double,
+ alpha: Double,
+ alphaAS: Double,
+ beta: Double): (Double, BDV[Double]) = {
+ val numTopics = totalTopicCounter.length
+ val t = BDV.zeros[Double](numTopics)
+ val alphaSum = alpha * numTopics
+ val termSum = numTokens - 1D + alphaAS * numTopics
+ val betaSum = numTerms * beta
+ var sum = 0.0
+ for (topic <- 0 until numTopics) {
+ val last = beta * alphaSum * (totalTopicCounter(topic) + alphaAS) /
+ ((totalTopicCounter(topic) + betaSum) * termSum)
+ t(topic) = last
+ sum += last
+ }
+ (sum, t)
+ }
+
+ /**
+ * The decomposed word-specific term:
+ * w = \frac{ {n}_{kw}^{-di} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} )}{({n}_{k}^{-di}+\bar{\beta}) ({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
+ */
+ private def wSparse(
+ totalTopicCounter: BDV[Count],
+ termTopicCounter: VD,
+ numTokens: Double,
+ numTerms: Double,
+ alpha: Double,
+ alphaAS: Double,
+ beta: Double): (Double, BSV[Double]) = {
+ val numTopics = totalTopicCounter.length
+ val alphaSum = alpha * numTopics
+ val termSum = numTokens - 1D + alphaAS * numTopics
+ val betaSum = numTerms * beta
+ val w = BSV.zeros[Double](numTopics)
+ var sum = 0.0
+ termTopicCounter.activeIterator.filter(_._2 > 0).foreach { t =>
+ val topic = t._1
+ val count = t._2
+ val last = count * alphaSum * (totalTopicCounter(topic) + alphaAS) /
+ ((totalTopicCounter(topic) + betaSum) * termSum)
+ w(topic) = last
+ sum += last
+ }
+ (sum, w)
+ }
+
+ /**
+ * The decomposed document-word term:
+ * d = \frac{{n}_{kd}^{-di}({\sum{n}_{k}^{-di} + \bar{\acute{\alpha}}})({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta})({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})}
+ * = \frac{{n}_{kd}^{-di}({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta})}
+ */
+ private def dSparse(
+ totalTopicCounter: BDV[Count],
+ termTopicCounter: VD,
+ docTopicCounter: VD,
+ d: Array[Double],
+ currentTopic: Int,
+ numTokens: Double,
+ numTerms: Double,
+ alpha: Double,
+ alphaAS: Double,
+ beta: Double): Unit = {
+ val index = docTopicCounter.index
+ val data = docTopicCounter.data
+ val used = docTopicCounter.used
+
+ // val termSum = numTokens - 1D + alphaAS * numTopics
+ val betaSum = numTerms * beta
+ var sum = 0.0
+ for (i <- 0 until used) {
+ val topic = index(i)
+ val count = data(i)
+ val adjustment = if (currentTopic == topic) -1D else 0
+ val last = (count + adjustment) * (termTopicCounter(topic) + adjustment + beta) /
+ (totalTopicCounter(topic) + adjustment + betaSum)
+ // val lastD = (count + adjustment) * termSum * (termTopicCounter(topic) + adjustment + beta) /
+ // ((totalTopicCounter(topic) + adjustment + betaSum) * termSum)
+
+ sum += last
+ d(i) = sum
+ }
+ }
+
+ private def wordTable(
+ table: Table,
+ totalTopicCounter: BDV[Count],
+ termTopicCounter: VD,
+ termId: VertexId,
+ numTokens: Double,
+ numTerms: Double,
+ alpha: Double,
+ alphaAS: Double,
+ beta: Double): Double = {
+ val sv = wSparse(totalTopicCounter, termTopicCounter,
+ numTokens, numTerms, alpha, alphaAS, beta)
+ generateAlias(sv._2, sv._1, Some(table))
+ sv._1
+ }
+
+ private def sampleSV(gen: Random, table: Table, sv: VD, currentTopic: Int): Int = {
+ val docTopic = sampleAlias(gen, table)
+ if (docTopic == currentTopic) {
+ val svCounter = sv(currentTopic)
+ // This handling is not entirely correct:
+ // if the current token's own topic is sampled, it is discarded and we resample.
+ // svCounter == 1 && table._1.length > 1: the sampled topic comes only from this token, but the table also contains other tokens
+ // svCounter > 1 && gen.nextDouble() < 1.0 / svCounter: the sampled topic belongs to the current token with probability 1/svCounter
+ if ((svCounter == 1 && table._1.length > 1) ||
+ (svCounter > 1 && gen.nextDouble() < 1.0 / svCounter)) {
+ return sampleSV(gen, table, sv, currentTopic)
+ }
+ }
+ docTopic
+ }
+
+}
+
+/**
+ * Degree-Based Hashing partitioner: each edge is assigned to the partition of its
+ * lower-degree endpoint. See the paper:
+ * http://nips.cc/Conferences/2014/Program/event.php?ID=4569
+ * @param partitions the number of partitions
+ */
+private class DBHPartitioner(partitions: Int) extends Partitioner {
+ val mixingPrime: Long = 1125899906842597L
+
+ def numPartitions = partitions
+
+ def getPartition(key: Any): Int = {
+ val edge = key.asInstanceOf[EdgeTriplet[Int, ED]]
+ val srcDeg = edge.srcAttr
+ val dstDeg = edge.dstAttr
+ val srcId = edge.srcId
+ val dstId = edge.dstId
+ val minId = if (srcDeg < dstDeg) srcId else dstId
+ getPartition(minId)
+ }
+
+ def getPartition(idx: Int): PartitionID = {
+ (math.abs(idx * mixingPrime) % partitions).toInt
+ }
+
+ def getPartition(idx: Long): PartitionID = {
+ (math.abs(idx * mixingPrime) % partitions).toInt
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case h: DBHPartitioner =>
+ h.numPartitions == numPartitions
+ case _ =>
+ false
+ }
+
+ override def hashCode: Int = numPartitions
+}
+
+private[mllib] class LDAKryoRegistrator extends KryoRegistrator {
+ def registerClasses(kryo: com.esotericsoftware.kryo.Kryo) {
+ val gkr = new GraphKryoRegistrator
+ gkr.registerClasses(kryo)
+
+ kryo.register(classOf[BSV[LDA.Count]])
+ kryo.register(classOf[BSV[Double]])
+
+ kryo.register(classOf[BDV[LDA.Count]])
+ kryo.register(classOf[BDV[Double]])
+
+ kryo.register(classOf[SV])
+ kryo.register(classOf[SSV])
+ kryo.register(classOf[SDV])
+
+ kryo.register(classOf[LDA.ED])
+ kryo.register(classOf[LDA.VD])
+
+ kryo.register(classOf[Random])
+ kryo.register(classOf[LDA])
+ kryo.register(classOf[LDAModel])
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
new file mode 100644
index 0000000000000..8be67cc479f38
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -0,0 +1,439 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import java.lang.ref.SoftReference
+import java.util.Random
+import java.util.{PriorityQueue => JPriorityQueue}
+
+import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV,
+sum => brzSum, norm => brzNorm}
+
+import org.apache.spark.mllib.linalg.{Vectors, DenseVector => SDV, SparseVector => SSV}
+import org.apache.spark.util.collection.AppendOnlyMap
+import org.apache.spark.util.random.XORShiftRandom
+
+import LDAUtils._
+
+import scala.collection.mutable.ArrayBuffer
+
+class LDAModel private[mllib](
+ private[mllib] val gtc: BDV[Double],
+ private[mllib] val ttc: Array[BSV[Double]],
+ val alpha: Double,
+ val beta: Double,
+ val alphaAS: Double) extends Serializable {
+
+ def this(topicCounts: SDV, topicTermCounts: Array[SSV], alpha: Double, beta: Double) {
+ this(new BDV[Double](topicCounts.toArray), topicTermCounts.map(t =>
+ new BSV(t.indices, t.values, t.size)), alpha, beta, alpha)
+ }
+
+ @transient private lazy val numTopics = gtc.size
+ @transient private lazy val numTerms = ttc.size
+ @transient private lazy val numTokens = brzSum(gtc)
+ @transient private lazy val betaSum = numTerms * beta
+ @transient private lazy val alphaSum = numTopics * alpha
+ @transient private lazy val termSum = numTokens + alphaAS * numTopics
+
+ @transient private lazy val wordTableCache =
+ new AppendOnlyMap[Int, SoftReference[(Double, Table)]]()
+ @transient private lazy val (t, tSum) = {
+ val dv = tDense(gtc, numTokens, numTerms, alpha, alphaAS, beta)
+ (generateAlias(dv._2, dv._1), dv._1)
+ }
+ @transient private lazy val rand = new XORShiftRandom()
+
+ def setSeed(seed: Long): Unit = {
+ rand.setSeed(seed)
+ }
+
+ def globalTopicCounter = Vectors.fromBreeze(gtc)
+
+ def topicTermCounter = ttc.map(t => Vectors.fromBreeze(t))
+
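+ /**
+ * Infer the topic distribution of a single (unseen) document by Gibbs sampling against the
+ * fixed topic-term counts of this model. The result is the L1-normalized topic distribution
+ * accumulated over the final sampling iterations.
+ *
+ * A minimal sketch (the document vector `doc` below is hypothetical):
+ * {{{
+ *   // doc: term-count vector of the new document, indexed by termId
+ *   val topicDist = model.inference(doc, totalIter = 10, burnIn = 5)
+ * }}}
+ */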
+ def inference(
+ doc: SSV,
+ totalIter: Int = 10,
+ burnIn: Int = 5): SSV = {
+ require(totalIter > burnIn, "totalIter must be greater than burnIn")
+ require(totalIter > 0, "totalIter must be greater than 0")
+ require(burnIn > 0, "burnIn must be greater than 0")
+
+ val topicDist = BSV.zeros[Double](numTopics)
+ val tokens = vector2Array(new BSV[Int](doc.indices, doc.values.map(_.toInt), doc.size))
+ val topics = new Array[Int](tokens.length)
+
+ var docTopicCounter = uniformDistSampler(tokens, topics)
+ for (i <- 0 until totalIter) {
+ docTopicCounter = sampleTokens(docTopicCounter, tokens, topics)
+ if (i + burnIn >= totalIter) topicDist :+= docTopicCounter
+ }
+
+ topicDist.compact()
+ topicDist :/= brzNorm(topicDist, 1)
+ Vectors.fromBreeze(topicDist).asInstanceOf[SSV]
+ }
+
+ private[mllib] def vector2Array(vec: BV[Int]): Array[Int] = {
+ val docLen = brzSum(vec)
+ var offset = 0
+ val sent = new Array[Int](docLen)
+ vec.activeIterator.foreach { case (term, cn) =>
+ for (i <- 0 until cn) {
+ sent(offset) = term
+ offset += 1
+ }
+ }
+ sent
+ }
+
+ private[mllib] def uniformDistSampler(
+ tokens: Array[Int],
+ topics: Array[Int]): BSV[Double] = {
+ val docTopicCounter = BSV.zeros[Double](numTopics)
+ for (i <- 0 until tokens.length) {
+ val topic = uniformSampler(rand, numTopics)
+ topics(i) = topic
+ docTopicCounter(topic) += 1D
+ }
+ docTopicCounter
+ }
+
+ private[mllib] def sampleTokens(
+ docTopicCounter: BSV[Double],
+ tokens: Array[Int],
+ topics: Array[Int]): BSV[Double] = {
+ for (i <- 0 until topics.length) {
+ val termId = tokens(i)
+ val currentTopic = topics(i)
+ val d = dSparse(gtc, ttc(termId), docTopicCounter,
+ currentTopic, numTokens, numTerms, alpha, alphaAS, beta)
+
+ val (wSum, w) = wordTable(wordTableCache, gtc, ttc(termId), termId,
+ numTokens, numTerms, alpha, alphaAS, beta)
+ val newTopic = tokenSampling(rand, t, tSum, w, wSum, d)
+ if (newTopic != currentTopic) {
+ docTopicCounter(newTopic) += 1D
+ docTopicCounter(currentTopic) -= 1D
+ topics(i) = newTopic
+ if (docTopicCounter(currentTopic) == 0) {
+ docTopicCounter.compact()
+ }
+ }
+ }
+ docTopicCounter
+ }
+
+ private def tokenSampling(
+ gen: Random,
+ t: Table,
+ tSum: Double,
+ w: Table,
+ wSum: Double,
+ d: BSV[Double]): Int = {
+ val index = d.index
+ val data = d.data
+ val used = d.used
+ val dSum = data(d.used - 1)
+ val distSum = tSum + wSum + dSum
+ val genSum = gen.nextDouble() * distSum
+ if (genSum < dSum) {
+ val dGenSum = gen.nextDouble() * dSum
+ val pos = binarySearchInterval(data, dGenSum, 0, used, true)
+ index(pos)
+ } else if (genSum < (dSum + wSum)) {
+ sampleAlias(gen, w)
+ } else {
+ sampleAlias(gen, t)
+ }
+ }
+
+
+ private def tDense(
+ totalTopicCounter: BDV[Double],
+ numTokens: Double,
+ numTerms: Double,
+ alpha: Double,
+ alphaAS: Double,
+ beta: Double): (Double, BDV[Double]) = {
+ val t = BDV.zeros[Double](numTopics)
+ var sum = 0.0
+ for (topic <- 0 until numTopics) {
+ val last = beta * alphaSum * (totalTopicCounter(topic) + alphaAS) /
+ ((totalTopicCounter(topic) + betaSum) * termSum)
+ t(topic) = last
+ sum += last
+ }
+ (sum, t)
+ }
+
+ private def wSparse(
+ totalTopicCounter: BDV[Double],
+ termTopicCounter: BSV[Double],
+ numTokens: Double,
+ numTerms: Double,
+ alpha: Double,
+ alphaAS: Double,
+ beta: Double): (Double, BSV[Double]) = {
+ val w = BSV.zeros[Double](numTopics)
+ var sum = 0.0
+ termTopicCounter.activeIterator.foreach { t =>
+ val topic = t._1
+ val count = t._2
+ val last = count * alphaSum * (totalTopicCounter(topic) + alphaAS) /
+ ((totalTopicCounter(topic) + betaSum) * termSum)
+ w(topic) = last
+ sum += last
+ }
+ (sum, w)
+ }
+
+ private def dSparse(
+ totalTopicCounter: BDV[Double],
+ termTopicCounter: BSV[Double],
+ docTopicCounter: BSV[Double],
+ currentTopic: Int,
+ numTokens: Double,
+ numTerms: Double,
+ alpha: Double,
+ alphaAS: Double,
+ beta: Double): BSV[Double] = {
+ val numTopics = totalTopicCounter.length
+ // val termSum = numTokens - 1D + alphaAS * numTopics
+ val betaSum = numTerms * beta
+ val d = BSV.zeros[Double](numTopics)
+ var sum = 0.0
+ docTopicCounter.activeIterator.foreach { t =>
+ val topic = t._1
+ val count = if (currentTopic == topic && t._2 != 1) t._2 - 1 else t._2
+ // val last = count * termSum * (termTopicCounter(topic) + beta) /
+ // ((totalTopicCounter(topic) + betaSum) * termSum)
+ val last = count * (termTopicCounter(topic) + beta) /
+ (totalTopicCounter(topic) + betaSum)
+ sum += last
+ d(topic) = sum
+ }
+ d
+ }
+
+ private def wordTable(
+ cacheMap: AppendOnlyMap[Int, SoftReference[(Double, Table)]],
+ totalTopicCounter: BDV[Double],
+ termTopicCounter: BSV[Double],
+ termId: Int,
+ numTokens: Double,
+ numTerms: Double,
+ alpha: Double,
+ alphaAS: Double,
+ beta: Double): (Double, Table) = {
+ if (termTopicCounter.used == 0) return (0.0, null)
+ var w = cacheMap(termId)
+ if (w == null || w.get() == null) {
+ val t = wSparse(totalTopicCounter, termTopicCounter,
+ numTokens, numTerms, alpha, alphaAS, beta)
+ w = new SoftReference((t._1, generateAlias(t._2, t._1)))
+ cacheMap.update(termId, w)
+ }
+ w.get()
+ }
+
+ private[mllib] def mergeOne(term: Int, topic: Int, inc: Int) = {
+ gtc(topic) += inc
+ ttc(term)(topic) += inc
+ this
+ }
+
+ private[mllib] def merge(term: Int, counter: BV[Int]) = {
+ counter.activeIterator.foreach { case (topic, cn) =>
+ mergeOne(term, topic, cn)
+ }
+ this
+ }
+
+ private[mllib] def merge(other: LDAModel) = {
+ gtc :+= other.gtc
+ for (i <- 0 until ttc.length) {
+ ttc(i) :+= other.ttc(i)
+ }
+ this
+ }
+}
+
+object LDAModel {
+ def apply(numTopics: Int, numTerms: Int, alpha: Double = 0.1, beta: Double = 0.01) = {
+ new LDAModel(
+ BDV.zeros[Double](numTopics),
+ (0 until numTerms).map(_ => BSV.zeros[Double](numTopics)).toArray, alpha, beta, alpha)
+ }
+}
+
+private[mllib] object LDAUtils {
+
+ type Table = (ArrayBuffer[Int], ArrayBuffer[Int], ArrayBuffer[Double])
+
+ @transient private lazy val tableOrdering = new scala.math.Ordering[(Int, Double)] {
+ override def compare(x: (Int, Double), y: (Int, Double)): Int = {
+ Ordering.Double.compare(x._2, y._2)
+ }
+ }
+
+ @transient private lazy val tableReverseOrdering = tableOrdering.reverse
+
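+ /**
+ * Builds an alias table (Walker's alias method) from a sparse probability vector whose active
+ * values sum to `sum`. Per bin, the three buffers hold the "low" index, the alias ("high")
+ * index, and the probability mass of the low index, so that sampleAlias can draw from the
+ * distribution in O(1) time. An existing table may be passed via `tableCache` to be reset
+ * and reused.
+ */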
+ def generateAlias(sv: BV[Double], sum: Double, tableCache: Option[Table] = None): Table = {
+ val used = sv.activeSize
+ val probs = sv.activeIterator.slice(0, used)
+ generateAlias(probs, used, sum, tableCache)
+ }
+
+ def generateAlias(
+ probs: Iterator[(Int, Double)],
+ used: Int,
+ sum: Double,
+ tableCache: Option[Table]): Table = {
+ val pMean = 1.0 / used
+ val table = tableCache.getOrElse((new ArrayBuffer[Int](used),
+ new ArrayBuffer[Int](used),
+ new ArrayBuffer[Double](used)))
+ // reset and resize table cache
+ for (i <- 0 until Math.min(used, table._1.length)) {
+ table._1(i) = 0
+ table._2(i) = 0
+ table._3(i) = 0.0
+ }
+ if (used > table._1.length) {
+ val pad = 0 until (used - table._1.length)
+ table._1 ++= pad.map(_ => 0)
+ table._2 ++= pad.map(_ => 0)
+ table._3 ++= pad.map(_ => 0.0)
+ } else {
+ table._1.reduceToSize(used)
+ table._2.reduceToSize(used)
+ table._3.reduceToSize(used)
+ }
+
+ val lq = new JPriorityQueue[(Int, Double)](used, tableOrdering)
+ val hq = new JPriorityQueue[(Int, Double)](used, tableReverseOrdering)
+
+ probs.slice(0, used).foreach { pair =>
+ val i = pair._1
+ val pi = pair._2 / sum
+ if (pi < pMean) {
+ lq.add((i, pi))
+ } else {
+ hq.add((i, pi))
+ }
+ }
+
+ var offset = 0
+ while (!lq.isEmpty && !hq.isEmpty) {
+ val (i, pi) = lq.remove()
+ val (h, ph) = hq.remove()
+ table._1(offset) = i
+ table._2(offset) = h
+ table._3(offset) = pi
+ val pd = ph - (pMean - pi)
+ if (pd >= pMean) {
+ hq.add((h, pd))
+ } else {
+ lq.add((h, pd))
+ }
+ offset += 1
+ }
+ while (!hq.isEmpty) {
+ val (h, ph) = hq.remove()
+ assert(ph - pMean < 1e-8)
+ table._1(offset) = h
+ table._2(offset) = h
+ table._3(offset) = ph
+ offset += 1
+ }
+
+ while (!lq.isEmpty) {
+ val (i, pi) = lq.remove()
+ assert(pMean - pi < 1e-8)
+ table._1(offset) = i
+ table._2(offset) = i
+ table._3(offset) = pi
+ offset += 1
+ }
+ table
+ }
+
+ def sampleAlias(gen: Random, table: Table): Int = {
+ val l = table._1.length
+ val bin = gen.nextInt(l)
+ val p = table._3(bin)
+ if (l * p > gen.nextDouble()) {
+ table._1(bin)
+ } else {
+ table._2(bin)
+ }
+ }
+
+ def uniformSampler(rand: Random, dimension: Int): Int = {
+ rand.nextInt(dimension)
+ }
+
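+ /**
+ * Binary search over a sorted array. With greater = true it returns the smallest position in
+ * [begin, end) whose value is >= key, which is how the samplers locate a bucket in a
+ * cumulative-sum array; with greater = false it returns the largest position whose value
+ * is <= key.
+ */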
+ def binarySearchInterval(
+ index: Array[Double],
+ key: Double,
+ begin: Int,
+ end: Int,
+ greater: Boolean): Int = {
+ if (begin == end) {
+ return if (greater) end else begin - 1
+ }
+ var b = begin
+ var e = end - 1
+
+ var mid: Int = (e + b) >> 1
+ while (b <= e) {
+ mid = (e + b) >> 1
+ val v = index(mid)
+ if (v < key) {
+ b = mid + 1
+ }
+ else if (v > key) {
+ e = mid - 1
+ }
+ else {
+ return mid
+ }
+ }
+ val v = index(mid)
+ mid = if ((greater && v >= key) || (!greater && v <= key)) {
+ mid
+ }
+ else if (greater) {
+ mid + 1
+ }
+ else {
+ mid - 1
+ }
+
+ if (greater) {
+ if (mid < end) assert(index(mid) >= key)
+ if (mid > 0) assert(index(mid - 1) <= key)
+ } else {
+ if (mid > 0) assert(index(mid) <= key)
+ if (mid < end - 1) assert(index(mid + 1) >= key)
+ }
+ mid
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
new file mode 100644
index 0000000000000..b60ef0d3be11f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import java.util.Random
+
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.scalatest.FunSuite
+
+import breeze.linalg.{DenseVector => BDV, SparseVector => BSV}
+import breeze.stats.distributions.Poisson
+import org.apache.spark.mllib.linalg.{Vectors, SparseVector => SSV}
+
+class LDASuite extends FunSuite with MLlibTestSparkContext {
+
+ import LDASuite._
+
+ test("LDA || Gibbs sampling") {
+ val model = generateRandomLDAModel(numTopics, numTerms)
+ val corpus = sampleCorpus(model, numDocs, numTerms, numTopics)
+
+ val data = sc.parallelize(corpus, 2)
+ val pps = new Array[Double](incrementalLearning)
+ val lda = new LDA(data, numTopics, alpha, beta, alphaAS)
+ var i = 0
+ val startedAt = System.currentTimeMillis()
+ while (i < incrementalLearning) {
+ lda.runGibbsSampling(totalIterations)
+ pps(i) = lda.perplexity
+ i += 1
+ }
+
+ println((System.currentTimeMillis() - startedAt) / 1e3)
+ pps.foreach(println)
+
+ val ppsDiff = pps.init.zip(pps.tail).map { case (lhs, rhs) => lhs - rhs}
+ assert(ppsDiff.count(_ > 0).toDouble / ppsDiff.size > 0.6)
+ assert(pps.head - pps.last > 0)
+ }
+
+}
+
+object LDASuite {
+ val numTopics = 5
+ val numTerms = 1000
+ val numDocs = 100
+ val expectedDocLength = 300
+ val alpha = 0.01
+ val alphaAS = 1D
+ val beta = 0.01
+ val totalIterations = 2
+ val burnInIterations = 1
+ val incrementalLearning = 10
+
+ /**
+ * Generate a random LDA model, i.e. the topic-term matrix.
+ */
+ def generateRandomLDAModel(numTopics: Int, numTerms: Int): Array[BDV[Double]] = {
+ val model = new Array[BDV[Double]](numTopics)
+ val width = numTerms * 1.0 / numTopics
+ var topic = 0
+ var i = 0
+ while (topic < numTopics) {
+ val topicCentroid = width * (topic + 1)
+ model(topic) = BDV.zeros[Double](numTerms)
+ i = 0
+ while (i < numTerms) {
+ // treat the term list as a circle, so the distance between the first one and the last one
+ // is 1, not n-1.
+ val distance = Math.abs(topicCentroid - i) % (numTerms / 2)
+ // The probability decays with the distance.
+ model(topic)(i) = 1.0 / (1 + Math.abs(distance))
+ i += 1
+ }
+ topic += 1
+ }
+ model
+ }
+
+ /**
+ * Sample one document given the topic-term matrix.
+ */
+ def ldaSampler(
+ model: Array[BDV[Double]],
+ topicDist: BDV[Double],
+ numTermsPerDoc: Int): Array[Int] = {
+ val samples = new Array[Int](numTermsPerDoc)
+ val rand = new Random()
+ (0 until numTermsPerDoc).foreach { i =>
+ samples(i) = multinomialDistSampler(
+ rand,
+ model(multinomialDistSampler(rand, topicDist))
+ )
+ }
+ samples
+ }
+
+ /**
+ * Sample corpus (many documents) from a given topic-term matrix.
+ */
+ def sampleCorpus(
+ model: Array[BDV[Double]],
+ numDocs: Int,
+ numTerms: Int,
+ numTopics: Int): Array[(Long, SSV)] = {
+ (0 until numDocs).map { i =>
+ val rand = new Random()
+ val numTermsPerDoc = Poisson.distribution(expectedDocLength).sample()
+ val numTopicsPerDoc = rand.nextInt(numTopics / 2) + 1
+ val topicDist = BDV.zeros[Double](numTopics)
+ (0 until numTopicsPerDoc).foreach { _ =>
+ topicDist(rand.nextInt(numTopics)) += 1
+ }
+ val sv = BSV.zeros[Double](numTerms)
+ ldaSampler(model, topicDist, numTermsPerDoc).foreach { term => sv(term) += 1}
+ (i.toLong, Vectors.fromBreeze(sv).asInstanceOf[SSV])
+ }.toArray
+ }
+
+ /**
+ * A multinomial distribution sampler that uses the roulette-wheel method to sample an Int.
+ */
+ def multinomialDistSampler(rand: Random, dist: BDV[Double]): Int = {
+ val distSum = rand.nextDouble() * breeze.linalg.sum[BDV[Double], Double](dist)
+
+ def loop(index: Int, accum: Double): Int = {
+ if (index == dist.length) return dist.length - 1
+ val sum = accum - dist(index)
+ if (sum <= 0) index else loop(index + 1, sum)
+ }
+
+ loop(0, distSum)
+ }
+}