diff --git a/mllib/pom.xml b/mllib/pom.xml index f0928e1268e43..b556df7a38f02 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -50,6 +50,11 @@ <artifactId>spark-sql_${scala.binary.version}</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-graphx_${scala.binary.version}</artifactId> + <version>${project.version}</version> + </dependency> <dependency> <groupId>org.eclipse.jetty</groupId> <artifactId>jetty-server</artifactId> diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala new file mode 100644 index 0000000000000..b74dc5f67d291 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -0,0 +1,750 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, sum => brzSum} + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.{HashPartitioner, Logging, Partitioner} +import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix} +import org.apache.spark.mllib.linalg.{DenseVector => SDV, SparseVector => SSV, Vector => SV} +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.KryoRegistrator +import org.apache.spark.storage.StorageLevel +import org.apache.spark.SparkContext._ +import org.apache.spark.util.random.XORShiftRandom + +import LDA._ +import LDAUtils._ + +import scala.collection.mutable.ArrayBuffer + +class LDA private[mllib]( + @transient private var corpus: Graph[VD, ED], + private val numTopics: Int, + private val numTerms: Int, + private var alpha: Double, + private var beta: Double, + private var alphaAS: Double, + private var storageLevel: StorageLevel) + extends Serializable with Logging { + + def this(docs: RDD[(DocId, SSV)], + numTopics: Int, + alpha: Double, + beta: Double, + alphaAS: Double, + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, + computedModel: Broadcast[LDAModel] = null) { + this(initializeCorpus(docs, numTopics, storageLevel, computedModel), + numTopics, docs.first()._2.size, alpha, beta, alphaAS, storageLevel) + } + + // scalastyle:off + /** + * Number of documents in the corpus. + */ + val numDocs = docVertices.count() + + /** + * Total number of tokens in the corpus (counting duplicates). + */ + val numTokens = corpus.edges.map(e => e.attr.size.toDouble).sum().toLong + + def setAlpha(alpha: Double): this.type = { + this.alpha = alpha + this + } + + def setBeta(beta: Double): this.type = { + this.beta = beta + this + } + + def setAlphaAS(alphaAS: Double): this.type = { + this.alphaAS = alphaAS + this + } + + def setStorageLevel(newStorageLevel: StorageLevel): this.type = { + this.storageLevel = newStorageLevel + this + } + + def setSeed(newSeed: Int): this.type = { + this.seed = newSeed + this + } + + def getCorpus
= corpus + + // scalastyle:on + + @transient private var seed = new Random().nextInt() + @transient private var innerIter = 1 + @transient private var totalTopicCounter: BDV[Count] = collectTotalTopicCounter(corpus) + + private def termVertices = corpus.vertices.filter(t => t._1 >= 0) + + private def docVertices = corpus.vertices.filter(t => t._1 < 0) + + private def checkpoint(corpus: Graph[VD, ED]): Unit = { + if (innerIter % 10 == 0 && corpus.edges.sparkContext.getCheckpointDir.isDefined) { + corpus.checkpoint() + } + } + + private def collectTotalTopicCounter(graph: Graph[VD, ED]): BDV[Count] = { + val globalTopicCounter = collectGlobalCounter(graph, numTopics) + assert(brzSum(globalTopicCounter) == numTokens) + globalTopicCounter + } + + private def gibbsSampling(): Unit = { + val sampledCorpus = sampleTokens(corpus, totalTopicCounter, innerIter + seed, + numTokens, numTopics, numTerms, alpha, alphaAS, beta) + sampledCorpus.persist(storageLevel) + + val counterCorpus = updateCounter(sampledCorpus, numTopics) + checkpoint(counterCorpus) + counterCorpus.persist(storageLevel) + // counterCorpus.vertices.count() + counterCorpus.edges.count() + totalTopicCounter = collectTotalTopicCounter(counterCorpus) + + corpus.edges.unpersist(false) + corpus.vertices.unpersist(false) + sampledCorpus.edges.unpersist(false) + sampledCorpus.vertices.unpersist(false) + corpus = counterCorpus + innerIter += 1 + } + + def saveModel(iter: Int = 1): LDAModel = { + var termTopicCounter: RDD[(VertexId, VD)] = null + for (iter <- 1 to iter) { + logInfo(s"Save TopicModel (Iteration $iter/$iter)") + var previousTermTopicCounter = termTopicCounter + gibbsSampling() + val newTermTopicCounter = termVertices + termTopicCounter = Option(termTopicCounter).map(_.join(newTermTopicCounter).map { + case (term, (a, b)) => + (term, a :+ b) + }).getOrElse(newTermTopicCounter) + + termTopicCounter.persist(storageLevel).count() + Option(previousTermTopicCounter).foreach(_.unpersist()) + previousTermTopicCounter = termTopicCounter + } + val model = LDAModel(numTopics, numTerms, alpha, beta) + termTopicCounter.collect().foreach { case (term, counter) => + model.merge(term.toInt, counter) + } + model.gtc :/= iter.toDouble + model.ttc.foreach { ttc => + ttc :/= iter.toDouble + ttc.compact() + } + model + } + + def runGibbsSampling(iterations: Int): Unit = { + for (iter <- 1 to iterations) { + // logInfo(s"Gibbs samplin perplexity $iter: ${perplexity}") + // logInfo(s"Gibbs sampling (Iteration $iter/$iterations)") + // val startedAt = System.nanoTime() + gibbsSampling() + // val endAt = System.nanoTime() + // val useTime = (endAt - startedAt) / 1e9 + // logInfo(s"Gibbs sampling use time $iter: $useTime") + } + } + + def mergeDuplicateTopic(threshold: Double = 0.95D): Map[Int, Int] = { + val rows = termVertices.map(t => t._2).map { bsv => + val length = bsv.length + val used = bsv.activeSize + val index = bsv.index.slice(0, used) + val data = bsv.data.slice(0, used).map(_.toDouble) + new SSV(length, index, data).asInstanceOf[SV] + } + val simMatrix = new RowMatrix(rows).columnSimilarities() + val minMap = simMatrix.entries.filter { case MatrixEntry(row, column, sim) => + sim > threshold && row != column + }.map { case MatrixEntry(row, column, sim) => + (column.toInt, row.toInt) + }.groupByKey().map { case (topic, simTopics) => + (topic, simTopics.min) + }.collect().toMap + if (minMap.size > 0) { + corpus = corpus.mapEdges(edges => { + edges.attr.map { topic => + minMap.get(topic).getOrElse(topic) + } + }) + corpus = 
updateCounter(corpus, numTopics) + } + minMap + } + + + // scalastyle:off + /** + * The probability of a word is the product of its distribution over topics and the topic distribution of the document that contains it: p(w)=\sum_{k}{p(k|d)*p(w|k)}= + * \sum_{k}{\frac{{n}_{kw}+{\beta }_{w}} {{n}_{k}+\bar{\beta }} \frac{{n}_{kd}+{\alpha }_{k}} {\sum{{n}_{kd}}+\bar{\alpha }}}= + * \sum_{k} \frac{{\alpha }_{k}{\beta }_{w} + {n}_{kw}{\alpha }_{k} + {n}_{kd}{\beta }_{w} + {n}_{kw}{n}_{kd}}{{n}_{k}+\bar{\beta }} \frac{1}{\sum{{n}_{kd}}+\bar{\alpha }} + * The perplexity is then \exp(-(\sum{\log(p(w))})/N), + * where N is the total number of tokens in the corpus. + */ + // scalastyle:on + def perplexity(): Double = { + val totalTopicCounter = this.totalTopicCounter + val numTopics = this.numTopics + val numTerms = this.numTerms + val alpha = this.alpha + val beta = this.beta + val totalSize = brzSum(totalTopicCounter) + var totalProb = 0D + + // \frac{{\alpha }_{k}{\beta }_{w}}{{n}_{k}+\bar{\beta }} + totalTopicCounter.activeIterator.foreach { case (topic, cn) => + totalProb += alpha * beta / (cn + numTerms * beta) + } + + val termProb = corpus.mapVertices { (vid, counter) => + val probDist = BSV.zeros[Double](numTopics) + if (vid >= 0) { + val termTopicCounter = counter + // \frac{{n}_{kw}{\alpha }_{k}}{{n}_{k}+\bar{\beta }} + termTopicCounter.activeIterator.foreach { case (topic, cn) => + probDist(topic) = cn * alpha / + (totalTopicCounter(topic) + numTerms * beta) + } + } else { + val docTopicCounter = counter + // \frac{{n}_{kd}{\beta }_{w}}{{n}_{k}+\bar{\beta }} + docTopicCounter.activeIterator.foreach { case (topic, cn) => + probDist(topic) = cn * beta / + (totalTopicCounter(topic) + numTerms * beta) + } + } + probDist.compact() + (counter, probDist) + }.mapTriplets { triplet => + val (termTopicCounter, termProb) = triplet.srcAttr + val (docTopicCounter, docProb) = triplet.dstAttr + val docSize = brzSum(docTopicCounter) + val docTermSize = triplet.attr.length + var prob = 0D + + // \frac{{n}_{kw}{n}_{kd}}{{n}_{k}+\bar{\beta}} + docTopicCounter.activeIterator.foreach { case (topic, cn) => + prob += cn * termTopicCounter(topic) / + (totalTopicCounter(topic) + numTerms * beta) + } + prob += brzSum(docProb) + brzSum(termProb) + totalProb + prob = prob / (docSize + numTopics * alpha) + + docTermSize * Math.log(prob) + }.edges.map(t => t.attr).sum() + + math.exp(-1 * termProb / totalSize) + } +} + +object LDA { + + private[mllib] type DocId = VertexId + private[mllib] type WordId = VertexId + private[mllib] type Count = Int + private[mllib] type ED = Array[Count] + private[mllib] type VD = BSV[Count] + + def train(docs: RDD[(DocId, SSV)], + numTopics: Int = 2048, + totalIter: Int = 150, + alpha: Double = 0.01, + beta: Double = 0.01, + alphaAS: Double = 0.1): LDAModel = { + require(totalIter > 0, "totalIter must be greater than 0") + val topicModeling = new LDA(docs, numTopics, alpha, beta, alphaAS) + topicModeling.runGibbsSampling(totalIter - 1) + topicModeling.saveModel(1) + } + + /** + * Trains an LDA model and saves it in LIBSVM-like format, one topic per line: + * topicID termID+1:counter termID+1:counter ..
+ */ + def trainAndSaveModel( + docs: RDD[(DocId, SSV)], + dir: String, + numTopics: Int = 2048, + totalIter: Int = 150, + alpha: Double = 0.01, + beta: Double = 0.01, + alphaAS: Double = 0.1): Unit = { + import org.apache.spark.mllib.regression.LabeledPoint + import org.apache.spark.mllib.util.MLUtils + import org.apache.spark.mllib.linalg.Vectors + val lda = new LDA(docs, numTopics, alpha, beta, alphaAS) + val numTerms = lda.numTerms + lda.runGibbsSampling(totalIter) + val rdd = lda.termVertices.flatMap { case (termId, counter) => + counter.activeIterator.map { case (topic, cn) => + val sv = BSV.zeros[Double](numTerms) + sv(termId.toInt) = cn.toDouble + (topic, sv) + } + }.reduceByKey { (a, b) => a + b }.map { case (topic, sv) => + LabeledPoint(topic.toDouble, Vectors.fromBreeze(sv)) + } + MLUtils.saveAsLibSVMFile(rdd, dir) + } + + def incrementalTrain(docs: RDD[(DocId, SSV)], + computedModel: LDAModel, + alphaAS: Double = 1, + totalIter: Int = 150): LDAModel = { + require(totalIter > 0, "totalIter is less than 0") + val numTopics = computedModel.ttc.size + val alpha = computedModel.alpha + val beta = computedModel.beta + + val broadcastModel = docs.context.broadcast(computedModel) + val topicModeling = new LDA(docs, numTopics, alpha, beta, alphaAS, + computedModel = broadcastModel) + broadcastModel.unpersist() + topicModeling.runGibbsSampling(totalIter - 1) + topicModeling.saveModel(1) + } + + private[mllib] def sampleTokens( + graph: Graph[VD, ED], + totalTopicCounter: BDV[Count], + innerIter: Long, + numTokens: Double, + numTopics: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): Graph[VD, ED] = { + val parts = graph.edges.partitions.size + val nweGraph = graph.mapTriplets( + (pid, iter) => { + val gen = new XORShiftRandom(parts * innerIter + pid) + // table is a per term data structure + // in GraphX, edges in a partition are clustered by source IDs (term id in this case) + // so, use below simple cache to avoid calculating table each time + val lastTable = (new ArrayBuffer[Int](numTopics.toInt), + new ArrayBuffer[Int](numTopics.toInt), + new ArrayBuffer[Double](numTopics.toInt)) + var lastVid = None.asInstanceOf[Option[VertexId]] + var lastWsum = 0.0 + val dv = tDense(totalTopicCounter, numTokens, numTerms, alpha, alphaAS, beta) + val dData = new Array[Double](numTopics.toInt) + val t = generateAlias(dv._2, dv._1) + val tSum = dv._1 + iter.map { + triplet => + val termId = triplet.srcId + val docId = triplet.dstId + val termTopicCounter = triplet.srcAttr + val docTopicCounter = triplet.dstAttr + val topics = triplet.attr + for (i <- 0 until topics.length) { + val currentTopic = topics(i) + docTopicCounter.synchronized { + termTopicCounter.synchronized { + dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, dData, + currentTopic, numTokens, numTerms, alpha, alphaAS, beta) + if (lastVid != Some(termId) || gen.nextDouble() < 1e-4) { + lastWsum = wordTable(lastTable, totalTopicCounter, termTopicCounter, + termId, numTokens, numTerms, alpha, alphaAS, beta) + lastVid = Some(termId) + } + val newTopic = tokenSampling(gen, t, tSum, lastTable, termTopicCounter, lastWsum, + docTopicCounter, dData, currentTopic) + + if (newTopic != currentTopic) { + topics(i) = newTopic + docTopicCounter(currentTopic) -= 1 + docTopicCounter(newTopic) += 1 + // if (docTopicCounter(currentTopic) == 0) docTopicCounter.compact() + + termTopicCounter(currentTopic) -= 1 + termTopicCounter(newTopic) += 1 + // if (termTopicCounter(currentTopic) == 0) termTopicCounter.compact() 
+ + totalTopicCounter(currentTopic) -= 1 + totalTopicCounter(newTopic) += 1 + } + } + } + } + + topics + } + }, TripletFields.All) + nweGraph + } + + private def updateCounter(graph: Graph[VD, ED], numTopics: Int): Graph[VD, ED] = { + val newCounter = graph.aggregateMessages[VD](ctx => { + val topics = ctx.attr + val vector = BSV.zeros[Count](numTopics) + for (topic <- topics) { + vector(topic) += 1 + } + ctx.sendToDst(vector) + ctx.sendToSrc(vector) + }, _ + _, TripletFields.EdgeOnly).mapValues(v => { + val used = v.used + if (v.index.length == used) { + v + } else { + val index = new Array[Int](used) + val data = new Array[Count](used) + Array.copy(v.index, 0, index, 0, used) + Array.copy(v.data, 0, data, 0, used) + new BSV[Count](index, data, numTopics) + } + }) + // GraphImpl.fromExistingRDDs(newCounter, graph.edges) + GraphImpl(newCounter, graph.edges) + } + + private def collectGlobalCounter(graph: Graph[VD, ED], numTopics: Int): BDV[Count] = { + graph.vertices.filter(t => t._1 >= 0).map(_._2). + aggregate(BDV.zeros[Count](numTopics))((a, b) => { + a :+= b + }, _ :+= _) + } + + private def initializeCorpus( + docs: RDD[(LDA.DocId, SSV)], + numTopics: Int, + storageLevel: StorageLevel, + computedModel: Broadcast[LDAModel] = null): Graph[VD, ED] = { + val edges = docs.mapPartitionsWithIndex((pid, iter) => { + val gen = new Random(pid) + var model: LDAModel = null + if (computedModel != null) model = computedModel.value + iter.flatMap { + case (docId, doc) => + val bsv = new BSV[Int](doc.indices, doc.values.map(_.toInt), doc.size) + initializeEdges(gen, bsv, docId, numTopics, model) + } + }) + edges.persist(storageLevel) + var corpus: Graph[VD, ED] = Graph.fromEdges(edges, null, storageLevel, storageLevel) + // degree-based hashing + val degrees = corpus.outerJoinVertices(corpus.degrees) { (vid, data, deg) => deg.getOrElse(0) } + val numPartitions = edges.partitions.size + val partitionStrategy = new DBHPartitioner(numPartitions) + val newEdges = degrees.triplets.map { e => + (partitionStrategy.getPartition(e), Edge(e.srcId, e.dstId, e.attr)) + }.partitionBy(new HashPartitioner(numPartitions)).map(_._2) + corpus = Graph.fromEdges(newEdges, null, storageLevel, storageLevel) + // end degree-based hashing + // corpus = corpus.partitionBy(PartitionStrategy.EdgePartition2D) + corpus = updateCounter(corpus, numTopics).cache() + corpus.vertices.count() + corpus.edges.count() + edges.unpersist() + corpus + } + + private def initializeEdges( + gen: Random, + doc: BSV[Int], + docId: DocId, + numTopics: Int, + computedModel: LDAModel = null): Array[Edge[ED]] = { + assert(docId >= 0) + val newDocId: DocId = -(docId + 1L) + val edges = if (computedModel == null) { + doc.activeIterator.filter(_._2 > 0).map { case (termId, counter) => + val topics = new Array[Int](counter) + for (i <- 0 until counter) { + topics(i) = gen.nextInt(numTopics) + } + Edge(termId, newDocId, topics) + }.toArray + } + else { + computedModel.setSeed(gen.nextInt()) + val tokens = computedModel.vector2Array(doc) + val topics = new Array[Int](tokens.length) + var docTopicCounter = computedModel.uniformDistSampler(tokens, topics) + for (t <- 0 until 15) { + docTopicCounter = computedModel.sampleTokens(docTopicCounter, + tokens, topics) + } + doc.activeIterator.filter(_._2 > 0).map { case (term, counter) => + val ev = topics.zipWithIndex.filter { case (topic, offset) => + term == tokens(offset) + }.map(_._1) + Edge(term, newDocId, ev) + }.toArray + } + assert(edges.length > 0) + edges + } + + // scalastyle:off
+ /** + * Here a Gibbs sampler is combined with a Metropolis-Hastings sampler, + * so the per-token sampling complexity is O(1). + * The Gibbs sampler uses formula (3) of the paper "Rethinking LDA: Why Priors Matter": + * \frac{{n}_{kw}^{-di}+{\beta }_{w}}{{n}_{k}^{-di}+\bar{\beta}} \frac{{n}_{kd}^{-di}+ \bar{\alpha} \frac{{n}_{k}^{-di} + \acute{\alpha}}{\sum{n}_{k} +\bar{\acute{\alpha}}}}{\sum{n}_{kd}^{-di} +\bar{\alpha}} + * = t + w + d + * t is the global (topic-only) part: + * t = \frac{{\beta }_{w} \bar{\alpha} ({n}_{k}^{-di} + \acute{\alpha})}{({n}_{k}^{-di}+\bar{\beta})(\sum{n}_{k}^{-di} +\bar{\acute{\alpha}})} + * w is the word-related part: + * w = \frac{{n}_{kw}^{-di} \bar{\alpha} ({n}_{k}^{-di} + \acute{\alpha})}{({n}_{k}^{-di}+\bar{\beta})(\sum{n}_{k}^{-di} +\bar{\acute{\alpha}})} + * d is the document-word part: + * d = \frac{{n}_{kd}^{-di}(\sum{n}_{k}^{-di} + \bar{\acute{\alpha}})({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta})(\sum{n}_{k}^{-di} +\bar{\acute{\alpha}})} + * = \frac{{n}_{kd}^{-di}({n}_{kw}^{-di}+{\beta}_{w})}{{n}_{k}^{-di}+\bar{\beta}} + * where + * \bar{\beta}=\sum_{w}{\beta}_{w} + * \bar{\alpha}=\sum_{k}{\alpha}_{k} + * \bar{\acute{\alpha}}=\sum_{k}\acute{\alpha} + * {n}_{kd} is the number of tokens in document d assigned to topic k + * {n}_{kw} is the number of tokens of word w assigned to topic k + * {n}_{k} is the number of tokens in the corpus assigned to topic k + * and the superscript -di means the current token is excluded from the counts. + */ + // scalastyle:on + private def tokenSampling( + gen: Random, + t: Table, + tSum: Double, + w: Table, + termTopicCounter: VD, + wSum: Double, + docTopicCounter: VD, + dData: Array[Double], + currentTopic: Int): Int = { + val index = docTopicCounter.index + val used = docTopicCounter.used + val dSum = dData(docTopicCounter.used - 1) + val distSum = tSum + wSum + dSum + val genSum = gen.nextDouble() * distSum + if (genSum < dSum) { + val dGenSum = gen.nextDouble() * dSum + val pos = binarySearchInterval(dData, dGenSum, 0, used, true) + index(pos) + } else if (genSum < (dSum + wSum)) { + sampleSV(gen, w, termTopicCounter, currentTopic) + } else { + sampleAlias(gen, t) + } + } + + + /** + * The decomposed formula for t is: + * t = \frac{{\beta }_{w} \bar{\alpha} ({n}_{k}^{-di} + \acute{\alpha})}{({n}_{k}^{-di}+\bar{\beta})(\sum{n}_{k}^{-di} +\bar{\acute{\alpha}})} + */ + private def tDense( + totalTopicCounter: BDV[Count], + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): (Double, BDV[Double]) = { + val numTopics = totalTopicCounter.length + val t = BDV.zeros[Double](numTopics) + val alphaSum = alpha * numTopics + val termSum = numTokens - 1D + alphaAS * numTopics + val betaSum = numTerms * beta + var sum = 0.0 + for (topic <- 0 until numTopics) { + val last = beta * alphaSum * (totalTopicCounter(topic) + alphaAS) / + ((totalTopicCounter(topic) + betaSum) * termSum) + t(topic) = last + sum += last + } + (sum, t) + } + + /** + * The decomposed formula for w is: + * w = \frac{{n}_{kw}^{-di} \bar{\alpha} ({n}_{k}^{-di} + \acute{\alpha})}{({n}_{k}^{-di}+\bar{\beta})(\sum{n}_{k}^{-di} +\bar{\acute{\alpha}})} + */ + private def wSparse( + totalTopicCounter: BDV[Count], + termTopicCounter: VD, + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): (Double, BSV[Double]) = { + val numTopics = totalTopicCounter.length + val alphaSum = alpha * numTopics + val termSum = numTokens - 1D + alphaAS * numTopics + val betaSum = numTerms * beta + val w = BSV.zeros[Double](numTopics) + var sum = 0.0 + termTopicCounter.activeIterator.filter(_._2 > 0).foreach { t => + val topic = t._1 + val count = t._2 + val last = count * alphaSum * (totalTopicCounter(topic) + alphaAS) / + ((totalTopicCounter(topic) + betaSum) * termSum) + w(topic) = last + sum += last + } + (sum, w) + }
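Editor's note: `tokenSampling` above first chooses one of the three components in proportion to `(tSum, wSum, dSum)` and then samples a topic within that component; `t` and `w` are drawn in O(1) from alias tables, while `d` is drawn by binary search over its cumulative array. The `Table` type and `generateAlias` / `sampleAlias` used here are defined in `LDAUtils` (in LDAModel.scala further down in this patch). For readers unfamiliar with the technique, a minimal stand-alone sketch of Walker's alias method is shown below; `AliasSketch`, `build` and `sample` are illustrative names only and are not part of this patch, which instead stores the table as three ArrayBuffers so it can be reused across terms.

```scala
import java.util.Random

// Minimal sketch of Walker's alias method: build in O(K), sample in O(1).
object AliasSketch {
  final case class Alias(prob: Array[Double], alias: Array[Int])

  /** Build the table from unnormalized weights over K outcomes. */
  def build(weights: Array[Double]): Alias = {
    val k = weights.length
    val sum = weights.sum
    val scaled = weights.map(_ * k / sum) // rescale so the mean bin mass is 1.0
    val prob = new Array[Double](k)
    val alias = new Array[Int](k)
    var small = List.empty[Int]
    var large = List.empty[Int]
    for (i <- 0 until k) if (scaled(i) < 1.0) small ::= i else large ::= i
    while (small.nonEmpty && large.nonEmpty) {
      val s = small.head; val l = large.head
      small = small.tail; large = large.tail
      prob(s) = scaled(s)
      alias(s) = l
      scaled(l) -= 1.0 - scaled(s) // the large bin donates the missing mass
      if (scaled(l) < 1.0) small ::= l else large ::= l
    }
    (small ++ large).foreach { i => prob(i) = 1.0; alias(i) = i }
    Alias(prob, alias)
  }

  /** Draw one index in O(1): pick a bin uniformly, then keep it or take its alias. */
  def sample(gen: Random, t: Alias): Int = {
    val bin = gen.nextInt(t.prob.length)
    if (gen.nextDouble() < t.prob(bin)) bin else t.alias(bin)
  }
}
```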
+ + /** + * The decomposed formula for d is: + * d = \frac{{n}_{kd}^{-di}(\sum{n}_{k}^{-di} + \bar{\acute{\alpha}})({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta})(\sum{n}_{k}^{-di} +\bar{\acute{\alpha}})} + * = \frac{{n}_{kd}^{-di}({n}_{kw}^{-di}+{\beta}_{w})}{{n}_{k}^{-di}+\bar{\beta}} + */ + private def dSparse( + totalTopicCounter: BDV[Count], + termTopicCounter: VD, + docTopicCounter: VD, + d: Array[Double], + currentTopic: Int, + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): Unit = { + val index = docTopicCounter.index + val data = docTopicCounter.data + val used = docTopicCounter.used + + // val termSum = numTokens - 1D + alphaAS * numTopics + val betaSum = numTerms * beta + var sum = 0.0 + for (i <- 0 until used) { + val topic = index(i) + val count = data(i) + val adjustment = if (currentTopic == topic) -1D else 0 + val last = (count + adjustment) * (termTopicCounter(topic) + adjustment + beta) / + (totalTopicCounter(topic) + adjustment + betaSum) + // val lastD = (count + adjustment) * termSum * (termTopicCounter(topic) + adjustment + beta) / + // ((totalTopicCounter(topic) + adjustment + betaSum) * termSum) + + sum += last + d(i) = sum + } + } + + private def wordTable( + table: Table, + totalTopicCounter: BDV[Count], + termTopicCounter: VD, + termId: VertexId, + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): Double = { + val sv = wSparse(totalTopicCounter, termTopicCounter, + numTokens, numTerms, alpha, alphaAS, beta) + generateAlias(sv._2, sv._1, Some(table)) + sv._1 + } + + private def sampleSV(gen: Random, table: Table, sv: VD, currentTopic: Int): Int = { + val docTopic = sampleAlias(gen, table) + if (docTopic == currentTopic) { + val svCounter = sv(currentTopic) + // Note: this handling is not strictly correct.
+ // If we sample the current token's own topic assignment, discard it and resample: + // svCounter == 1 && table._1.length > 1: currentTopic comes only from the current token, and the table contains other topics + // svCounter > 1 && gen.nextDouble() < 1.0 / svCounter: with probability 1/svCounter the sampled topic belongs to the current token itself + if ((svCounter == 1 && table._1.length > 1) || + (svCounter > 1 && gen.nextDouble() < 1.0 / svCounter)) { + return sampleSV(gen, table, sv, currentTopic) + } + } + docTopic + } + +} + +/** + * Degree-Based Hashing (DBH) partitioner, based on the paper: + * http://nips.cc/Conferences/2014/Program/event.php?ID=4569 + * @param partitions the number of partitions + */ +private class DBHPartitioner(partitions: Int) extends Partitioner { + val mixingPrime: Long = 1125899906842597L + + def numPartitions = partitions + + def getPartition(key: Any): Int = { + val edge = key.asInstanceOf[EdgeTriplet[Int, ED]] + val srcDeg = edge.srcAttr + val dstDeg = edge.dstAttr + val srcId = edge.srcId + val dstId = edge.dstId + val minId = if (srcDeg < dstDeg) srcId else dstId + getPartition(minId) + } + + def getPartition(idx: Int): PartitionID = { + (math.abs(idx * mixingPrime) % partitions).toInt + } + + def getPartition(idx: Long): PartitionID = { + (math.abs(idx * mixingPrime) % partitions).toInt + } + + override def equals(other: Any): Boolean = other match { + case h: DBHPartitioner => + h.numPartitions == numPartitions + case _ => + false + } + + override def hashCode: Int = numPartitions +} + +private[mllib] class LDAKryoRegistrator extends KryoRegistrator { + def registerClasses(kryo: com.esotericsoftware.kryo.Kryo) { + val gkr = new GraphKryoRegistrator + gkr.registerClasses(kryo) + + kryo.register(classOf[BSV[LDA.Count]]) + kryo.register(classOf[BSV[Double]]) + + kryo.register(classOf[BDV[LDA.Count]]) + kryo.register(classOf[BDV[Double]]) + + kryo.register(classOf[SV]) + kryo.register(classOf[SSV]) + kryo.register(classOf[SDV]) + + kryo.register(classOf[LDA.ED]) + kryo.register(classOf[LDA.VD]) + + kryo.register(classOf[Random]) + kryo.register(classOf[LDA]) + kryo.register(classOf[LDAModel]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala new file mode 100644 index 0000000000000..8be67cc479f38 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -0,0 +1,439 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.mllib.clustering + +import java.lang.ref.SoftReference +import java.util.Random +import java.util.{PriorityQueue => JPriorityQueue} + +import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV, +sum => brzSum, norm => brzNorm} + +import org.apache.spark.mllib.linalg.{Vectors, DenseVector => SDV, SparseVector => SSV} +import org.apache.spark.util.collection.AppendOnlyMap +import org.apache.spark.util.random.XORShiftRandom + +import LDAUtils._ + +import scala.collection.mutable.ArrayBuffer + +class LDAModel private[mllib]( + private[mllib] val gtc: BDV[Double], + private[mllib] val ttc: Array[BSV[Double]], + val alpha: Double, + val beta: Double, + val alphaAS: Double) extends Serializable { + + def this(topicCounts: SDV, topicTermCounts: Array[SSV], alpha: Double, beta: Double) { + this(new BDV[Double](topicCounts.toArray), topicTermCounts.map(t => + new BSV(t.indices, t.values, t.size)), alpha, beta, alpha) + } + + @transient private lazy val numTopics = gtc.size + @transient private lazy val numTerms = ttc.size + @transient private lazy val numTokens = brzSum(gtc) + @transient private lazy val betaSum = numTerms * beta + @transient private lazy val alphaSum = numTopics * alpha + @transient private lazy val termSum = numTokens + alphaAS * numTopics + + @transient private lazy val wordTableCache = + new AppendOnlyMap[Int, SoftReference[(Double, Table)]]() + @transient private lazy val (t, tSum) = { + val dv = tDense(gtc, numTokens, numTerms, alpha, alphaAS, beta) + (generateAlias(dv._2, dv._1), dv._1) + } + @transient private lazy val rand = new XORShiftRandom() + + def setSeed(seed: Long): Unit = { + rand.setSeed(seed) + } + + def globalTopicCounter = Vectors.fromBreeze(gtc) + + def topicTermCounter = ttc.map(t => Vectors.fromBreeze(t)) + + def inference( + doc: SSV, + totalIter: Int = 10, + burnIn: Int = 5): SSV = { + require(totalIter > burnIn, "totalIter is less than burnInIter") + require(totalIter > 0, "totalIter is less than 0") + require(burnIn > 0, "burnInIter is less than 0") + + val topicDist = BSV.zeros[Double](numTopics) + val tokens = vector2Array(new BSV[Int](doc.indices, doc.values.map(_.toInt), doc.size)) + val topics = new Array[Int](tokens.length) + + var docTopicCounter = uniformDistSampler(tokens, topics) + for (i <- 0 until totalIter) { + docTopicCounter = sampleTokens(docTopicCounter, tokens, topics) + if (i + burnIn >= totalIter) topicDist :+= docTopicCounter + } + + topicDist.compact() + topicDist :/= brzNorm(topicDist, 1) + Vectors.fromBreeze(topicDist).asInstanceOf[SSV] + } + + private[mllib] def vector2Array(vec: BV[Int]): Array[Int] = { + val docLen = brzSum(vec) + var offset = 0 + val sent = new Array[Int](docLen) + vec.activeIterator.foreach { case (term, cn) => + for (i <- 0 until cn) { + sent(offset) = term + offset += 1 + } + } + sent + } + + private[mllib] def uniformDistSampler( + tokens: Array[Int], + topics: Array[Int]): BSV[Double] = { + val docTopicCounter = BSV.zeros[Double](numTopics) + for (i <- 0 until tokens.length) { + val topic = uniformSampler(rand, numTopics) + topics(i) = topic + docTopicCounter(topic) += 1D + } + docTopicCounter + } + + private[mllib] def sampleTokens( + docTopicCounter: BSV[Double], + tokens: Array[Int], + topics: Array[Int]): BSV[Double] = { + for (i <- 0 until topics.length) { + val termId = tokens(i) + val currentTopic = topics(i) + val d = dSparse(gtc, ttc(termId), docTopicCounter, + currentTopic, numTokens, numTerms, alpha, alphaAS, beta) + + val (wSum, w) = 
wordTable(wordTableCache, gtc, ttc(termId), termId, + numTokens, numTerms, alpha, alphaAS, beta) + val newTopic = tokenSampling(rand, t, tSum, w, wSum, d) + if (newTopic != currentTopic) { + docTopicCounter(newTopic) += 1D + docTopicCounter(currentTopic) -= 1D + topics(i) = newTopic + if (docTopicCounter(currentTopic) == 0) { + docTopicCounter.compact() + } + } + } + docTopicCounter + } + + private def tokenSampling( + gen: Random, + t: Table, + tSum: Double, + w: Table, + wSum: Double, + d: BSV[Double]): Int = { + val index = d.index + val data = d.data + val used = d.used + val dSum = data(d.used - 1) + val distSum = tSum + wSum + dSum + val genSum = gen.nextDouble() * distSum + if (genSum < dSum) { + val dGenSum = gen.nextDouble() * dSum + val pos = binarySearchInterval(data, dGenSum, 0, used, true) + index(pos) + } else if (genSum < (dSum + wSum)) { + sampleAlias(gen, w) + } else { + sampleAlias(gen, t) + } + } + + + private def tDense( + totalTopicCounter: BDV[Double], + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): (Double, BDV[Double]) = { + val t = BDV.zeros[Double](numTopics) + var sum = 0.0 + for (topic <- 0 until numTopics) { + val last = beta * alphaSum * (totalTopicCounter(topic) + alphaAS) / + ((totalTopicCounter(topic) + betaSum) * termSum) + t(topic) = last + sum += last + } + (sum, t) + } + + private def wSparse( + totalTopicCounter: BDV[Double], + termTopicCounter: BSV[Double], + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): (Double, BSV[Double]) = { + val w = BSV.zeros[Double](numTopics) + var sum = 0.0 + termTopicCounter.activeIterator.foreach { t => + val topic = t._1 + val count = t._2 + val last = count * alphaSum * (totalTopicCounter(topic) + alphaAS) / + ((totalTopicCounter(topic) + betaSum) * termSum) + w(topic) = last + sum += last + } + (sum, w) + } + + private def dSparse( + totalTopicCounter: BDV[Double], + termTopicCounter: BSV[Double], + docTopicCounter: BSV[Double], + currentTopic: Int, + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): BSV[Double] = { + val numTopics = totalTopicCounter.length + // val termSum = numTokens - 1D + alphaAS * numTopics + val betaSum = numTerms * beta + val d = BSV.zeros[Double](numTopics) + var sum = 0.0 + docTopicCounter.activeIterator.foreach { t => + val topic = t._1 + val count = if (currentTopic == topic && t._2 != 1) t._2 - 1 else t._2 + // val last = count * termSum * (termTopicCounter(topic) + beta) / + // ((totalTopicCounter(topic) + betaSum) * termSum) + val last = count * (termTopicCounter(topic) + beta) / + (totalTopicCounter(topic) + betaSum) + sum += last + d(topic) = sum + } + d + } + + private def wordTable( + cacheMap: AppendOnlyMap[Int, SoftReference[(Double, Table)]], + totalTopicCounter: BDV[Double], + termTopicCounter: BSV[Double], + termId: Int, + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): (Double, Table) = { + if (termTopicCounter.used == 0) return (0.0, null) + var w = cacheMap(termId) + if (w == null || w.get() == null) { + val t = wSparse(totalTopicCounter, termTopicCounter, + numTokens, numTerms, alpha, alphaAS, beta) + w = new SoftReference((t._1, generateAlias(t._2, t._1))) + cacheMap.update(termId, w) + + } + w.get() + } + + private[mllib] def mergeOne(term: Int, topic: Int, inc: Int) = { + gtc(topic) += inc + ttc(term)(topic) += inc + this + } + + private[mllib] def merge(term: Int, counter: BV[Int]) = { + 
counter.activeIterator.foreach { case (topic, cn) => + mergeOne(term, topic, cn) + } + this + } + + private[mllib] def merge(other: LDAModel) = { + gtc :+= other.gtc + for (i <- 0 until ttc.length) { + ttc(i) :+= other.ttc(i) + } + this + } +} + +object LDAModel { + def apply(numTopics: Int, numTerms: Int, alpha: Double = 0.1, beta: Double = 0.01) = { + new LDAModel( + BDV.zeros[Double](numTopics), + (0 until numTerms).map(_ => BSV.zeros[Double](numTopics)).toArray, alpha, beta, alpha) + } +} + +private[mllib] object LDAUtils { + + type Table = (ArrayBuffer[Int], ArrayBuffer[Int], ArrayBuffer[Double]) + + @transient private lazy val tableOrdering = new scala.math.Ordering[(Int, Double)] { + override def compare(x: (Int, Double), y: (Int, Double)): Int = { + Ordering.Double.compare(x._2, y._2) + } + } + + @transient private lazy val tableReverseOrdering = tableOrdering.reverse + + def generateAlias(sv: BV[Double], sum: Double, tableCache:Option[Table] = None): Table = { + val used = sv.activeSize + val probs = sv.activeIterator.slice(0, used) + generateAlias(probs, used, sum, tableCache) + } + + def generateAlias( + probs: Iterator[(Int, Double)], + used: Int, + sum: Double, tableCache:Option[Table]): Table = { + val pMean = 1.0 / used + val table = tableCache.getOrElse(new ArrayBuffer[Int](used), + new ArrayBuffer[Int](used), + new ArrayBuffer[Double](used)) + // reset and resize table cache + for (i <- 0 until Math.min(used, table._1.length)) { + table._1(i) = 0 + table._2(i) = 0 + table._3(i) = 0.0 + } + if (used > table._1.length) { + val pad = 0 until (used - table._1.length) + table._1 ++= pad.map(_=>0) + table._2 ++= pad.map(_=>0) + table._3 ++= pad.map(_=>0.0) + } else { + table._1.reduceToSize(used) + table._2.reduceToSize(used) + table._3.reduceToSize(used) + } + + val lq = new JPriorityQueue[(Int, Double)](used, tableOrdering) + val hq = new JPriorityQueue[(Int, Double)](used, tableReverseOrdering) + + probs.slice(0, used).foreach { pair => + val i = pair._1 + val pi = pair._2 / sum + if (pi < pMean) { + lq.add((i, pi)) + } else { + hq.add((i, pi)) + } + } + + var offset = 0 + while (!lq.isEmpty & !hq.isEmpty) { + val (i, pi) = lq.remove() + val (h, ph) = hq.remove() + table._1(offset) = i + table._2(offset) = h + table._3(offset) = pi + val pd = ph - (pMean - pi) + if (pd >= pMean) { + hq.add((h, pd)) + } else { + lq.add((h, pd)) + } + offset += 1 + } + while (!hq.isEmpty) { + val (h, ph) = hq.remove() + assert(ph - pMean < 1e-8) + table._1(offset) = h + table._2(offset) = h + table._3(offset) = ph + offset += 1 + } + + while (!lq.isEmpty) { + val (i, pi) = lq.remove() + assert(pMean - pi < 1e-8) + table._1(offset) = i + table._2(offset) = i + table._3(offset) = pi + offset += 1 + } + table + } + + def sampleAlias(gen: Random, table: Table): Int = { + val l = table._1.length + val bin = gen.nextInt(l) + val p = table._3(bin) + if (l * p > gen.nextDouble()) { + table._1(bin) + } else { + table._2(bin) + } + } + + def uniformSampler(rand: Random, dimension: Int): Int = { + rand.nextInt(dimension) + } + + def binarySearchInterval( + index: Array[Double], + key: Double, + begin: Int, + end: Int, + greater: Boolean): Int = { + if (begin == end) { + return if (greater) end else begin - 1 + } + var b = begin + var e = end - 1 + + var mid: Int = (e + b) >> 1 + while (b <= e) { + mid = (e + b) >> 1 + val v = index(mid) + if (v < key) { + b = mid + 1 + } + else if (v > key) { + e = mid - 1 + } + else { + return mid + } + } + val v = index(mid) + mid = if ((greater && v >= key) || 
(!greater && v <= key)) { + mid + } + else if (greater) { + mid + 1 + } + else { + mid - 1 + } + + if (greater) { + if (mid < end) assert(index(mid) >= key) + if (mid > 0) assert(index(mid - 1) <= key) + } else { + if (mid > 0) assert(index(mid) <= key) + if (mid < end - 1) assert(index(mid + 1) >= key) + } + mid + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala new file mode 100644 index 0000000000000..b60ef0d3be11f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.scalatest.FunSuite + +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} +import breeze.stats.distributions.Poisson +import org.apache.spark.mllib.linalg.{Vectors, SparseVector => SSV} + +class LDASuite extends FunSuite with MLlibTestSparkContext { + + import LDASuite._ + + test("LDA || Gibbs sampling") { + val model = generateRandomLDAModel(numTopics, numTerms) + val corpus = sampleCorpus(model, numDocs, numTerms, numTopics) + + val data = sc.parallelize(corpus, 2) + val pps = new Array[Double](incrementalLearning) + val lda = new LDA(data, numTopics, alpha, beta, alphaAS) + var i = 0 + val startedAt = System.currentTimeMillis() + while (i < incrementalLearning) { + lda.runGibbsSampling(totalIterations) + pps(i) = lda.perplexity + i += 1 + } + + println((System.currentTimeMillis() - startedAt) / 1e3) + pps.foreach(println) + + val ppsDiff = pps.init.zip(pps.tail).map { case (lhs, rhs) => lhs - rhs} + assert(ppsDiff.count(_ > 0).toDouble / ppsDiff.size > 0.6) + assert(pps.head - pps.last > 0) + } + +} + +object LDASuite { + val numTopics = 5 + val numTerms = 1000 + val numDocs = 100 + val expectedDocLength = 300 + val alpha = 0.01 + val alphaAS = 1D + val beta = 0.01 + val totalIterations = 2 + val burnInIterations = 1 + val incrementalLearning = 10 + + /** + * Generate a random LDA model, i.e. the topic-term matrix. + */ + def generateRandomLDAModel(numTopics: Int, numTerms: Int): Array[BDV[Double]] = { + val model = new Array[BDV[Double]](numTopics) + val width = numTerms * 1.0 / numTopics + var topic = 0 + var i = 0 + while (topic < numTopics) { + val topicCentroid = width * (topic + 1) + model(topic) = BDV.zeros[Double](numTerms) + i = 0 + while (i < numTerms) { + // treat the term list as a circle, so the distance between the first one and the last one + // is 1, not n-1. 
+ val distance = Math.abs(topicCentroid - i) % (numTerms / 2) + // The probability decays with distance + model(topic)(i) = 1.0 / (1 + Math.abs(distance)) + i += 1 + } + topic += 1 + } + model + } + + /** + * Sample one document given the topic-term matrix. + */ + def ldaSampler( + model: Array[BDV[Double]], + topicDist: BDV[Double], + numTermsPerDoc: Int): Array[Int] = { + val samples = new Array[Int](numTermsPerDoc) + val rand = new Random() + (0 until numTermsPerDoc).foreach { i => + samples(i) = multinomialDistSampler( + rand, + model(multinomialDistSampler(rand, topicDist)) + ) + } + samples + } + + /** + * Sample a corpus (a collection of documents) from a given topic-term matrix. + */ + def sampleCorpus( + model: Array[BDV[Double]], + numDocs: Int, + numTerms: Int, + numTopics: Int): Array[(Long, SSV)] = { + (0 until numDocs).map { i => + val rand = new Random() + val numTermsPerDoc = Poisson.distribution(expectedDocLength).sample() + val numTopicsPerDoc = rand.nextInt(numTopics / 2) + 1 + val topicDist = BDV.zeros[Double](numTopics) + (0 until numTopicsPerDoc).foreach { _ => + topicDist(rand.nextInt(numTopics)) += 1 + } + val sv = BSV.zeros[Double](numTerms) + ldaSampler(model, topicDist, numTermsPerDoc).foreach { term => sv(term) += 1 } + (i.toLong, Vectors.fromBreeze(sv).asInstanceOf[SSV]) + }.toArray + } + + /** + * A multinomial distribution sampler that uses the roulette-wheel method to sample an index. + */ + def multinomialDistSampler(rand: Random, dist: BDV[Double]): Int = { + val distSum = rand.nextDouble() * breeze.linalg.sum[BDV[Double], Double](dist) + + def loop(index: Int, accum: Double): Int = { + if (index == dist.length) return dist.length - 1 + val sum = accum - dist(index) + if (sum <= 0) index else loop(index + 1, sum) + } + + loop(0, distSum) + } +}
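Editor's note: to illustrate the public surface this patch adds (`LDA.train` and `LDAModel.inference`), a hypothetical end-to-end driver might look roughly as follows. `LDAExample`, the input path, the hashing vocabulary, and all parameter values are illustrative assumptions and are not part of the patch; documents are `(docId, term-frequency vector)` pairs with non-negative doc ids, as the test suite above also builds them.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.{LDA, LDAModel}
import org.apache.spark.mllib.linalg.{Vectors, SparseVector => SSV}

object LDAExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "lda-example")
    val vocabSize = 10000

    // Build (docId, term-frequency vector) pairs; docIds must be >= 0.
    val docs = sc.textFile("/path/to/docs.txt").zipWithIndex().map { case (line, docId) =>
      val termCounts = line.split("\\s+")
        .map(term => (term.hashCode & Int.MaxValue) % vocabSize) // toy hashing vocabulary
        .groupBy(identity)
        .map { case (termId, hits) => (termId, hits.length.toDouble) }
        .toSeq
      (docId, Vectors.sparse(vocabSize, termCounts).asInstanceOf[SSV])
    }

    // Train with the collapsed Gibbs sampler added in this patch.
    val model: LDAModel = LDA.train(docs, numTopics = 64, totalIter = 50,
      alpha = 0.01, beta = 0.01, alphaAS = 0.1)

    // Infer the topic mixture of an unseen document with the trained model.
    val newDoc = Vectors.sparse(vocabSize, Seq((1, 2.0), (7, 1.0))).asInstanceOf[SSV]
    println(model.inference(newDoc, totalIter = 10, burnIn = 5))

    sc.stop()
  }
}
```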