From 4d318cf9ab769dc9ed60c4cf914c8711e78e421d Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Thu, 5 Feb 2015 12:53:18 +0800 Subject: [PATCH 01/21] =?UTF-8?q?=E6=B7=BB=E5=8A=A0LDA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mllib/pom.xml | 5 + .../apache/spark/mllib/clustering/LDA.scala | 621 ++++++++++++++++++ .../spark/mllib/clustering/LDAModel.scala | 393 +++++++++++ .../spark/mllib/clustering/LDASuite.scala | 149 +++++ 4 files changed, 1168 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala diff --git a/mllib/pom.xml b/mllib/pom.xml index f0928e1268e43..b556df7a38f02 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -50,6 +50,11 @@ spark-sql_${scala.binary.version} ${project.version} + + org.apache.spark + spark-graphx_${scala.binary.version} + ${project.version} + org.eclipse.jetty jetty-server diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala new file mode 100644 index 0000000000000..efa7991425742 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -0,0 +1,621 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.mllib.clustering + +import java.lang.ref.SoftReference +import java.util.Random + +import breeze.collection.mutable.OpenAddressHashArray +import breeze.linalg.{DenseVector => BDV, HashVector => BHV, +SparseVector => BSV, sum => brzSum} + +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.graphx._ +import org.apache.spark.Logging +import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix} +import org.apache.spark.mllib.linalg.{DenseVector => SDV, SparseVector => SSV, Vector => SV} +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer.KryoRegistrator +import org.apache.spark.storage.StorageLevel +import org.apache.spark.SparkContext._ +import org.apache.spark.util.collection.AppendOnlyMap +import org.apache.spark.util.random.XORShiftRandom +import org.apache.spark.util.Utils + +import LDA._ +import LDAUtils._ + +class LDA private[mllib]( + @transient var corpus: Graph[VD, ED], + val numTopics: Int, + val numTerms: Int, + val alpha: Double, + val beta: Double, + val alphaAS: Double, + @transient val storageLevel: StorageLevel) + extends Serializable with Logging { + + def this(docs: RDD[(DocId, SSV)], + numTopics: Int, + alpha: Double, + beta: Double, + alphaAS: Double, + storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK, + computedModel: Broadcast[LDAModel] = null) { + this(initializeCorpus(docs, numTopics, storageLevel, computedModel), + numTopics, docs.first()._2.size, alpha, beta, alphaAS, storageLevel) + } + + // scalastyle:off + /** + * 语料库文档数 + */ + val numDocs = docVertices.count() + + /** + * 语料库总的词数(包含重复) + */ + val numTokens = corpus.edges.map(e => e.attr.size.toDouble).sum().toLong + // scalastyle:on + + @transient private val sc = corpus.vertices.context + @transient private val seed = new Random().nextInt() + @transient private var innerIter = 1 + @transient private var totalTopicCounter: BDV[Count] = collectTotalTopicCounter(corpus) + + private def termVertices = corpus.vertices.filter(t => t._1 >= 0) + + private def docVertices = corpus.vertices.filter(t => t._1 < 0) + + private def checkpoint(): Unit = { + if (innerIter % 10 == 0 && sc.getCheckpointDir.isDefined) { + val edges = corpus.edges.map(t => t) + edges.checkpoint() + val newCorpus: Graph[VD, ED] = Graph.fromEdges(edges, null, + storageLevel, storageLevel) + corpus = updateCounter(newCorpus, numTopics).persist(storageLevel) + } + } + + private def collectTotalTopicCounter(graph: Graph[VD, ED]): BDV[Count] = { + val globalTopicCounter = collectGlobalCounter(graph, numTopics) + assert(brzSum(globalTopicCounter) == numTokens) + globalTopicCounter + } + + private def gibbsSampling(): Unit = { + val sampledCorpus = sampleTokens(corpus, totalTopicCounter, innerIter + seed, + numTokens, numTopics, numTerms, alpha, alphaAS, beta) + sampledCorpus.persist(storageLevel) + + val counterCorpus = updateCounter(sampledCorpus, numTopics) + counterCorpus.persist(storageLevel) + counterCorpus.vertices.count() + counterCorpus.edges.count() + totalTopicCounter = collectTotalTopicCounter(counterCorpus) + + corpus.edges.unpersist(false) + corpus.vertices.unpersist(false) + sampledCorpus.edges.unpersist(false) + sampledCorpus.vertices.unpersist(false) + corpus = counterCorpus + + checkpoint() + innerIter += 1 + } + + def saveModel(burnInIter: Int): LDAModel = { + var termTopicCounter: RDD[(VertexId, VD)] = null + for (iter <- 1 to burnInIter) { + logInfo(s"Save TopicModel (Iteration $iter/$burnInIter)") + var previousTermTopicCounter = termTopicCounter 
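+ // Each burn-in iteration below runs one more round of Gibbs sampling and merges the
+ // resulting term-topic counts into the running total; the accumulated counts are divided
+ // by burnInIter after the loop, so the saved model is an average over the burn-in samples.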
+ gibbsSampling() + val newTermTopicCounter = termVertices + termTopicCounter = Option(termTopicCounter).map(_.join(newTermTopicCounter).map { + case (term, (a, b)) => + val c = new BHV(a) + new BHV(b) + (term, c.array) + }).getOrElse(newTermTopicCounter) + + termTopicCounter.cache().count() + Option(previousTermTopicCounter).foreach(_.unpersist()) + previousTermTopicCounter = termTopicCounter + } + val model = LDAModel(numTopics, numTerms, alpha, beta) + termTopicCounter.collect().foreach { case (term, counter) => + model.merge(term.toInt, new BHV(counter)) + } + model.gtc :/= burnInIter.toDouble + model.ttc.foreach { ttc => + ttc :/= burnInIter.toDouble + ttc.compact() + } + model + } + + def runGibbsSampling(iterations: Int): Unit = { + for (iter <- 1 to iterations) { + println(s"perplexity $iter: ${perplexity()}") + logInfo(s"Start Gibbs sampling (Iteration $iter/$iterations)") + gibbsSampling() + } + } + + def mergeDuplicateTopic(threshold: Double = 0.95D): Map[Int, Int] = { + val rows = termVertices.map(t => t._2).map { bsv => + val length = bsv.length + val used = bsv.activeSize + val index = bsv.index.slice(0, used) + val data = bsv.data.slice(0, used).map(_.toDouble) + new SSV(length, index, data).asInstanceOf[SV] + } + val simMatrix = new RowMatrix(rows).columnSimilarities() + val minMap = simMatrix.entries.filter { case MatrixEntry(row, column, sim) => + sim > threshold && row != column + }.map { case MatrixEntry(row, column, sim) => + (column.toInt, row.toInt) + }.groupByKey().map { case (topic, simTopics) => + (topic, simTopics.min) + }.collect().toMap + if (minMap.size > 0) { + corpus = corpus.mapEdges(edges => { + edges.attr.map { topic => + minMap.get(topic).getOrElse(topic) + } + }) + corpus = updateCounter(corpus, numTopics) + } + minMap + } + + + // scalastyle:off + /** + * 词在所有主题分布和该词所在文本的主题分布乘积: p(w)=\sum_{k}{p(k|d)*p(w|k)}= + * \sum_{k}{\frac{{n}_{kw}+{\beta }_{w}} {{n}_{k}+\bar{\beta }} \frac{{n}_{kd}+{\alpha }_{k}} {\sum{{n}_{k}}+\bar{\alpha }}}= + * \sum_{k} \frac{{\alpha }_{k}{\beta }_{w} + {n}_{kw}{\alpha }_{k} + {n}_{kd}{\beta }_{w} + {n}_{kw}{n}_{kd}}{{n}_{k}+\bar{\beta }} \frac{1}{\sum{{n}_{k}}+\bar{\alpha }}} + * \exp^{-(\sum{\log(p(w))})/N} + * N为语料库包含的token数 + */ + // scalastyle:on + def perplexity(): Double = { + val totalTopicCounter = this.totalTopicCounter + val numTopics = this.numTopics + val numTerms = this.numTerms + val alpha = this.alpha + val beta = this.beta + val totalSize = brzSum(totalTopicCounter) + var totalProb = 0D + + // \frac{{\alpha }_{k}{\beta }_{w}}{{n}_{k}+\bar{\beta }} + totalTopicCounter.activeIterator.foreach { case (topic, cn) => + totalProb += alpha * beta / (cn + numTerms * beta) + } + + val termProb = corpus.mapVertices { (vid, counter) => + val probDist = BSV.zeros[Double](numTopics) + if (vid >= 0) { + val termTopicCounter = new BHV(counter) + // \frac{{n}_{kw}{\alpha }_{k}}{{n}_{k}+\bar{\beta }} + termTopicCounter.activeIterator.foreach { case (topic, cn) => + probDist(topic) = cn * alpha / + (totalTopicCounter(topic) + numTerms * beta) + } + } else { + val docTopicCounter = new BHV(counter) + // \frac{{n}_{kd}{\beta }_{w}}{{n}_{k}+\bar{\beta }} + docTopicCounter.activeIterator.foreach { case (topic, cn) => + probDist(topic) = cn * beta / + (totalTopicCounter(topic) + numTerms * beta) + } + } + probDist.compact() + (counter, probDist) + }.mapTriplets { triplet => + val (termTopicCounter, termProb) = triplet.srcAttr + val (docTopicCounter, docProb) = triplet.dstAttr + val docSize = brzSum(new BHV(docTopicCounter)) + val 
docTermSize = triplet.attr.length + var prob = 0D + + // \frac{{n}_{kw}{n}_{kd}}{{n}_{k}+\bar{\beta}} + docTopicCounter.activeIterator.foreach { case (topic, cn) => + prob += cn * termTopicCounter(topic) / + (totalTopicCounter(topic) + numTerms * beta) + } + prob += brzSum(docProb) + brzSum(termProb) + totalProb + prob += prob / (docSize + numTopics * alpha) + + docTermSize * Math.log(prob) + }.edges.map(t => t.attr).sum() + + math.exp(-1 * termProb / totalSize) + } +} + +object LDA { + + private[mllib] type DocId = VertexId + private[mllib] type WordId = VertexId + private[mllib] type Count = Int + private[mllib] type ED = Array[Count] + private[mllib] type VD = OpenAddressHashArray[Int] + + def train(docs: RDD[(DocId, SSV)], + numTopics: Int = 2048, + totalIter: Int = 150, + burnIn: Int = 5, + alpha: Double = 0.1, + beta: Double = 0.01, + alphaAS: Double = 0.1): LDAModel = { + require(totalIter > burnIn, "totalIter is less than burnIn") + require(totalIter > 0, "totalIter is less than 0") + require(burnIn > 0, "burnIn is less than 0") + val topicModeling = new LDA(docs, numTopics, alpha, beta, alphaAS) + topicModeling.runGibbsSampling(totalIter - burnIn) + topicModeling.saveModel(burnIn) + } + + def incrementalTrain(docs: RDD[(DocId, SSV)], + computedModel: LDAModel, + alphaAS: Double = 1, + totalIter: Int = 150, + burnIn: Int = 5): LDAModel = { + require(totalIter > burnIn, "totalIter is less than burnIn") + require(totalIter > 0, "totalIter is less than 0") + require(burnIn > 0, "burnIn is less than 0") + val numTopics = computedModel.ttc.size + val alpha = computedModel.alpha + val beta = computedModel.beta + + val broadcastModel = docs.context.broadcast(computedModel) + val topicModeling = new LDA(docs, numTopics, alpha, beta, alphaAS, + computedModel = broadcastModel) + broadcastModel.unpersist() + topicModeling.runGibbsSampling(totalIter - burnIn) + topicModeling.saveModel(burnIn) + } + + private[mllib] def sampleTokens( + graph: Graph[VD, ED], + totalTopicCounter: BDV[Count], + innerIter: Long, + numTokens: Double, + numTopics: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): Graph[VD, ED] = { + val parts = graph.edges.partitions.size + val nweGraph = graph.mapTriplets( + (pid, iter) => { + val gen = new XORShiftRandom(parts * innerIter + pid) + val wordTableCache = new AppendOnlyMap[VertexId, SoftReference[(Double, Table)]]() + var t: Table = null + var tSum: Double = 0.0 + + iter.map { + triplet => + val termId = triplet.srcId + val docId = triplet.dstId + val termTopicCounter = triplet.srcAttr + val docTopicCounter = triplet.dstAttr + val topics = triplet.attr + for (i <- 0 until topics.length) { + val currentTopic = topics(i) + if (t == null || gen.nextDouble() < 1e-6) { + val dv = tDense(totalTopicCounter, numTokens, numTerms, alpha, alphaAS, beta) + t = generateAlias(dv._2, dv._1) + tSum = dv._1 + } + + val (dSum, d) = docTopicCounter.synchronized { + termTopicCounter.synchronized { + docTable(totalTopicCounter, termTopicCounter, docTopicCounter, + currentTopic, numTokens, numTerms, alpha, alphaAS, beta) + } + } + val (wSum, w) = termTopicCounter.synchronized { + wordTable(gen, wordTableCache, totalTopicCounter, + termTopicCounter, termId, numTokens, numTerms, alpha, alphaAS, beta) + } + val newTopic = docTopicCounter.synchronized { + termTopicCounter.synchronized { + tokenSampling(gen, t, tSum, w, wSum, d, dSum) + } + } + + if (newTopic != currentTopic) { + docTopicCounter.synchronized { + docTopicCounter(currentTopic) -= 1 + 
docTopicCounter(newTopic) += 1 + } + termTopicCounter.synchronized { + termTopicCounter(currentTopic) -= 1 + termTopicCounter(newTopic) += 1 + } + + totalTopicCounter(currentTopic) -= 1 + totalTopicCounter(newTopic) += 1 + + topics(i) = newTopic + } + + } + + topics + } + }, TripletFields.All) + nweGraph + } + + private def updateCounter(graph: Graph[VD, ED], numTopics: Int): Graph[VD, ED] = { + val newCounter = graph.aggregateMessages[BSV[Count]](ctx => { + val topics = ctx.attr + val vector = BSV.zeros[Count](numTopics) + for (topic <- topics) { + vector(topic) += 1 + } + ctx.sendToDst(vector) + ctx.sendToSrc(vector) + }, _ + _, TripletFields.EdgeOnly).mapValues { a => + val b = new VD(a.length) + a.activeIterator.foreach { t => + b(t._1) = t._2 + } + b + } + graph.joinVertices(newCounter)((_, _, nc) => nc) + } + + private def collectGlobalCounter(graph: Graph[VD, ED], numTopics: Int): BDV[Count] = { + graph.vertices.filter(t => t._1 >= 0).map(_._2). + aggregate(BDV.zeros[Count](numTopics))((a, b) => { + a :+= new BHV(b) + }, _ :+= _) + } + + private def initializeCorpus( + docs: RDD[(LDA.DocId, SSV)], + numTopics: Int, + storageLevel: StorageLevel, + computedModel: Broadcast[LDAModel] = null): Graph[VD, ED] = { + val edges = docs.mapPartitionsWithIndex((pid, iter) => { + val gen = new Random(pid) + var model: LDAModel = null + if (computedModel != null) model = computedModel.value + iter.flatMap { + case (docId, doc) => + initializeEdges(gen, new BSV[Int](doc.indices, doc.values.map(_.toInt), doc.size), + docId, numTopics, model) + } + }) + var corpus: Graph[VD, ED] = Graph.fromEdges(edges, null, storageLevel, storageLevel) + corpus.partitionBy(PartitionStrategy.EdgePartition1D) + corpus = updateCounter(corpus, numTopics).cache() + corpus.vertices.count() + corpus + } + + private def initializeEdges( + gen: Random, + doc: BSV[Int], + docId: DocId, + numTopics: Int, + computedModel: LDAModel = null): Array[Edge[ED]] = { + assert(docId >= 0) + val newDocId: DocId = -(docId + 1L) + if (computedModel == null) { + doc.activeIterator.map { case (termId, counter) => + val ev = (0 until counter).map { i => + gen.nextInt(numTopics) + }.toArray + Edge(termId, newDocId, ev) + }.toArray + } + else { + computedModel.setSeed(gen.nextInt()) + val tokens = computedModel.vector2Array(doc) + val topics = new Array[Int](tokens.length) + var docTopicCounter = computedModel.uniformDistSampler(tokens, topics) + for (t <- 0 until 15) { + docTopicCounter = computedModel.sampleTokens(docTopicCounter, + tokens, topics) + } + doc.activeIterator.map { case (term, counter) => + val ev = topics.zipWithIndex.filter { case (topic, offset) => + term == tokens(offset) + }.map(_._1) + Edge(term, newDocId, ev) + }.toArray + } + } + + // scalastyle:off + /** + * 这里组合使用 Gibbs sampler 和 Metropolis Hastings sampler + * 每次采样的复杂度为: O(1) + * 使用 Gibbs sampler 采样论文 Rethinking LDA: Why Priors Matter 公式(3) + * \frac{{n}_{kw}^{-di}+{\beta }_{w}}{{n}_{k}^{-di}+\bar{\beta}} \frac{{n}_{kd} ^{-di}+ \bar{\alpha} \frac{{n}_{k}^{-di} + \acute{\alpha}}{\sum{n}_{k} +\bar{\acute{\alpha}}}}{\sum{n}_{kd}^{-di} +\bar{\alpha}} + * 其中 + * \bar{\beta}=\sum_{w}{\beta}_{w} + * \bar{\alpha}=\sum_{k}{\alpha}_{k} + * \bar{\acute{\alpha}}=\bar{\acute{\alpha}}=\sum_{k}\acute{\alpha} + * {n}_{kd} 是文档d中主题为k的tokens数 + * {n}_{kw} 词中主题为k的tokens数 + * {n}_{k} 是语料库中主题为k的tokens数 + */ + // scalastyle:on + def tokenSampling( + gen: Random, + t: Table, + tSum: Double, + w: Table, + wSum: Double, + d: Table, + dSum: Double): Int = { + val distSum = tSum + wSum + dSum + 
val genSum = gen.nextDouble() * distSum + if (genSum < dSum) { + sampleAlias(gen, d) + } else if (genSum < (dSum + wSum)) { + sampleAlias(gen, w) + } else { + sampleAlias(gen, t) + } + } + + + private def tDense( + totalTopicCounter: BDV[Count], + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): (Double, BDV[Double]) = { + val numTopics = totalTopicCounter.length + val t = BDV.zeros[Double](numTopics) + val alphaSum = alpha * numTopics + val termSum = numTokens - 1D + alphaAS * numTopics + val betaSum = numTerms * beta + var sum = 0.0 + for (topic <- 0 until numTopics) { + val last = beta * alphaSum * (totalTopicCounter(topic) + alphaAS) / + ((totalTopicCounter(topic) + betaSum) * termSum) + t(topic) = last + sum += last + } + (sum, t) + } + + private def wSparse( + totalTopicCounter: BDV[Count], + termTopicCounter: VD, + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): (Double, BSV[Double]) = { + val numTopics = totalTopicCounter.length + val alphaSum = alpha * numTopics + val termSum = numTokens - 1D + alphaAS * numTopics + val betaSum = numTerms * beta + val w = BSV.zeros[Double](numTopics) + var sum = 0.0 + termTopicCounter.activeIterator.foreach { t => + val topic = t._1 + val count = t._2 + val last = count * alphaSum * (totalTopicCounter(topic) + alphaAS) / + ((totalTopicCounter(topic) + betaSum) * termSum) + w(topic) = last + sum += last + } + (sum, w) + } + + private def dSparse( + totalTopicCounter: BDV[Count], + termTopicCounter: VD, + docTopicCounter: VD, + currentTopic: Int, + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): (Double, BSV[Double]) = { + val numTopics = totalTopicCounter.length + // val termSum = numTokens - 1D + alphaAS * numTopics + val betaSum = numTerms * beta + val d = BSV.zeros[Double](numTopics) + var sum = 0.0 + docTopicCounter.activeIterator.foreach { t => + val topic = t._1 + val count = if (currentTopic == topic && t._2 != 1) t._2 - 1 else t._2 + // val last = count * termSum * (termTopicCounter(topic) + beta) / + // ((totalTopicCounter(topic) + betaSum) * termSum) + val last = count * (termTopicCounter(topic) + beta) / + (totalTopicCounter(topic) + betaSum) + d(topic) = last + sum += last + } + (sum, d) + } + + private def wordTable( + gen: Random, + cacheMap: AppendOnlyMap[VertexId, SoftReference[(Double, Table)]], + totalTopicCounter: BDV[Count], + termTopicCounter: VD, + termId: VertexId, + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): (Double, Table) = { + var w = cacheMap(termId) + if (w == null || w.get() == null || gen.nextDouble() < 1e-5) { + termTopicCounter.synchronized { + val t = wSparse(totalTopicCounter, termTopicCounter, + numTokens, numTerms, alpha, alphaAS, beta) + w = new SoftReference((t._1, generateAlias(t._2, t._1))) + cacheMap.update(termId, w) + } + } + w.get() + } + + private def docTable( + totalTopicCounter: BDV[Count], + termTopicCounter: VD, + docTopicCounter: VD, + currentTopic: Int, + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): (Double, Table) = { + val d = dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, + currentTopic, numTokens, numTerms, alpha, alphaAS, beta) + (d._1, generateAlias(d._2, d._1)) + } + +} + +class LDAKryoRegistrator extends KryoRegistrator { + def registerClasses(kryo: com.esotericsoftware.kryo.Kryo) { + val gkr = new GraphKryoRegistrator + gkr.registerClasses(kryo) + + 
kryo.register(classOf[BSV[LDA.Count]]) + kryo.register(classOf[BSV[Double]]) + + kryo.register(classOf[BDV[LDA.Count]]) + kryo.register(classOf[BDV[Double]]) + + kryo.register(classOf[SV]) + kryo.register(classOf[SSV]) + kryo.register(classOf[SDV]) + + kryo.register(classOf[LDA.ED]) + kryo.register(classOf[LDA.VD]) + + kryo.register(classOf[Random]) + kryo.register(classOf[LDA]) + kryo.register(classOf[LDAModel]) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala new file mode 100644 index 0000000000000..e6f64a033b83e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.lang.ref.SoftReference +import java.util.Random +import java.util.{PriorityQueue => JPriorityQueue} + +import org.apache.spark.util.random.XORShiftRandom + +import scala.reflect.ClassTag + +import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV, +sum => brzSum, norm => brzNorm} + +import org.apache.spark.mllib.linalg.{Vectors, DenseVector => SDV, SparseVector => SSV} +import org.apache.spark.util.collection.AppendOnlyMap + +import LDAUtils._ + +class LDAModel private[mllib]( + private[mllib] val gtc: BDV[Double], + private[mllib] val ttc: Array[BSV[Double]], + val alpha: Double, + val beta: Double, + val alphaAS: Double) extends Serializable { + + def this(topicCounts: SDV, topicTermCounts: Array[SSV], alpha: Double, beta: Double) { + this(new BDV[Double](topicCounts.toArray), topicTermCounts.map(t => + new BSV(t.indices, t.values, t.size)), alpha, beta, alpha) + } + + @transient private lazy val numTopics = gtc.size + @transient private lazy val numTerms = ttc.size + @transient private lazy val numTokens = brzSum(gtc) + @transient private lazy val betaSum = numTerms * beta + @transient private lazy val alphaSum = numTopics * alpha + @transient private lazy val termSum = numTokens + alphaAS * numTopics + + @transient private lazy val wordTableCache = + new AppendOnlyMap[Int, SoftReference[(Double, Table)]]() + @transient private lazy val (t, tSum) = { + val dv = tDense(gtc, numTokens, numTerms, alpha, alphaAS, beta) + (generateAlias(dv._2, dv._1), dv._1) + } + @transient private lazy val rand = new XORShiftRandom() + + def setSeed(seed: Long): Unit = { + rand.setSeed(seed) + } + + def globalTopicCounter = Vectors.fromBreeze(gtc) + + def topicTermCounter = ttc.map(t => Vectors.fromBreeze(t)) + + def inference( + doc: SSV, + totalIter: Int = 10, + burnIn: Int = 5): SSV = { + require(totalIter > burnIn, "totalIter is less than burnInIter") + require(totalIter > 0, "totalIter is less than 
0") + require(burnIn > 0, "burnInIter is less than 0") + + val topicDist = BSV.zeros[Double](numTopics) + val tokens = vector2Array(new BSV[Int](doc.indices, doc.values.map(_.toInt), doc.size)) + val topics = new Array[Int](tokens.length) + + var docTopicCounter = uniformDistSampler(tokens, topics) + for (i <- 0 until totalIter) { + docTopicCounter = sampleTokens(docTopicCounter, tokens, topics) + if (i + burnIn >= totalIter) topicDist :+= docTopicCounter + } + + topicDist.compact() + topicDist :/= brzNorm(topicDist, 1) + Vectors.fromBreeze(topicDist).asInstanceOf[SSV] + } + + private[mllib] def vector2Array(vec: BV[Int]): Array[Int] = { + val docLen = brzSum(vec) + var offset = 0 + val sent = new Array[Int](docLen) + vec.activeIterator.foreach { case (term, cn) => + for (i <- 0 until cn) { + sent(offset) = term + offset += 1 + } + } + sent + } + + private[mllib] def uniformDistSampler( + tokens: Array[Int], + topics: Array[Int]): BSV[Double] = { + val docTopicCounter = BSV.zeros[Double](numTopics) + for (i <- 0 until tokens.length) { + val topic = uniformSampler(rand, numTopics) + topics(i) = topic + docTopicCounter(topic) += 1D + } + docTopicCounter + } + + private[mllib] def sampleTokens( + docTopicCounter: BSV[Double], + tokens: Array[Int], + topics: Array[Int]): BSV[Double] = { + for (i <- 0 until topics.length) { + val termId = tokens(i) + val currentTopic = topics(i) + val (dSum, d) = docTable(gtc, ttc(termId), docTopicCounter, + currentTopic, numTokens, numTerms, alpha, alphaAS, beta) + + val (wSum, w) = wordTable(wordTableCache, gtc, ttc(termId), termId, + numTokens, numTerms, alpha, alphaAS, beta) + val newTopic = tokenSampling(rand, t, tSum, w, wSum, d, dSum) + if (newTopic != currentTopic) { + docTopicCounter(newTopic) += 1D + docTopicCounter(currentTopic) -= 1D + topics(i) = newTopic + if (docTopicCounter(currentTopic) == 0) { + docTopicCounter.compact() + } + } + } + docTopicCounter + } + + private def tokenSampling( + gen: Random, + t: Table, + tSum: Double, + w: Table, + wSum: Double, + d: Table, + dSum: Double): Int = { + val distSum = tSum + wSum + dSum + val genSum = gen.nextDouble() * distSum + if (genSum < dSum) { + sampleAlias(gen, d) + } else if (genSum < (dSum + wSum)) { + sampleAlias(gen, w) + } else { + sampleAlias(gen, t) + } + } + + + private def tDense( + totalTopicCounter: BDV[Double], + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): (Double, BDV[Double]) = { + val t = BDV.zeros[Double](numTopics) + var sum = 0.0 + for (topic <- 0 until numTopics) { + val last = beta * alphaSum * (totalTopicCounter(topic) + alphaAS) / + ((totalTopicCounter(topic) + betaSum) * termSum) + t(topic) = last + sum += last + } + (sum, t) + } + + private def wSparse( + totalTopicCounter: BDV[Double], + termTopicCounter: BSV[Double], + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): (Double, BSV[Double]) = { + val w = BSV.zeros[Double](numTopics) + var sum = 0.0 + termTopicCounter.activeIterator.foreach { t => + val topic = t._1 + val count = t._2 + val last = count * alphaSum * (totalTopicCounter(topic) + alphaAS) / + ((totalTopicCounter(topic) + betaSum) * termSum) + w(topic) = last + sum += last + } + (sum, w) + } + + private def dSparse( + totalTopicCounter: BDV[Double], + termTopicCounter: BSV[Double], + docTopicCounter: BSV[Double], + currentTopic: Int, + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): (Double, BSV[Double]) = { + val d = 
BSV.zeros[Double](numTopics) + var sum = 0.0 + docTopicCounter.activeIterator.foreach { t => + val topic = t._1 + val count = if (currentTopic == topic && t._2 != 1) t._2 - 1 else t._2 + val last = count * (termTopicCounter(topic) + beta) / + (totalTopicCounter(topic) + betaSum) + d(topic) = last + sum += last + } + (sum, d) + } + + private def wordTable( + cacheMap: AppendOnlyMap[Int, SoftReference[(Double, Table)]], + totalTopicCounter: BDV[Double], + termTopicCounter: BSV[Double], + termId: Int, + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): (Double, Table) = { + var w = cacheMap(termId) + if (w == null || w.get() == null) { + val t = wSparse(totalTopicCounter, termTopicCounter, + numTokens, numTerms, alpha, alphaAS, beta) + w = new SoftReference((t._1, generateAlias(t._2, t._1))) + cacheMap.update(termId, w) + + } + w.get() + } + + private def docTable( + totalTopicCounter: BDV[Double], + termTopicCounter: BSV[Double], + docTopicCounter: BSV[Double], + currentTopic: Int, + numTokens: Double, + numTerms: Double, + alpha: Double, + alphaAS: Double, + beta: Double): (Double, Table) = { + val d = dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, + currentTopic, numTokens, numTerms, alpha, alphaAS, beta) + (d._1, generateAlias(d._2, d._1)) + } + + private[mllib] def mergeOne(term: Int, topic: Int, inc: Int) = { + gtc(topic) += inc + ttc(term)(topic) += inc + this + } + + private[mllib] def merge(term: Int, counter: BV[Int]) = { + counter.activeIterator.foreach { case (topic, cn) => + mergeOne(term, topic, cn) + } + this + } + + private[mllib] def merge(other: LDAModel) = { + gtc :+= other.gtc + for (i <- 0 until ttc.length) { + ttc(i) :+= other.ttc(i) + } + this + } +} + +object LDAModel { + def apply(numTopics: Int, numTerms: Int, alpha: Double = 0.1, beta: Double = 0.01) = { + new LDAModel( + BDV.zeros[Double](numTopics), + (0 until numTerms).map(_ => BSV.zeros[Double](numTopics)).toArray, alpha, beta, alpha) + } +} + +private[mllib] object LDAUtils { + + type Table = Array[(Int, Int, Double)] + + @transient private lazy val tableOrdering = new scala.math.Ordering[(Int, Double)] { + override def compare(x: (Int, Double), y: (Int, Double)): Int = { + Ordering.Double.compare(x._2, y._2) + } + } + + @transient private lazy val tableReverseOrdering = tableOrdering.reverse + + def generateAlias(sv: BV[Double], sum: Double): Table = { + val used = sv.activeSize + val probs = sv.activeIterator.slice(0, used) + generateAlias(probs, used, sum) + } + + def generateAlias( + probs: Iterator[(Int, Double)], + used: Int, + sum: Double): Table = { + val pMean = 1.0 / used + val table = new Table(used) + + val lq = new JPriorityQueue[(Int, Double)](used, tableOrdering) + val hq = new JPriorityQueue[(Int, Double)](used, tableReverseOrdering) + + probs.slice(0, used).foreach { pair => + val i = pair._1 + val pi = pair._2 / sum + if (pi < pMean) { + lq.add((i, pi)) + } else { + hq.add((i, pi)) + } + } + + var offset = 0 + while (!lq.isEmpty & !hq.isEmpty) { + val (i, pi) = lq.remove() + val (h, ph) = hq.remove() + table(offset) = (i, h, pi) + val pd = ph - (pMean - pi) + if (pd >= pMean) { + hq.add((h, pd)) + } else { + lq.add((h, pd)) + } + offset += 1 + } + while (!hq.isEmpty) { + val (h, ph) = hq.remove() + assert(ph - pMean < 1e-8) + table(offset) = (h, h, ph) + offset += 1 + } + + while (!lq.isEmpty) { + val (i, pi) = lq.remove() + assert(pMean - pi < 1e-8) + table(offset) = (i, i, pi) + offset += 1 + } + + // 测试代码 随即抽样一个样本验证其概率 + // val (di, 
dp) = probs(Utils.random.nextInt(used)) + // val ds = table.map { t => + // if (t._1 == di) { + // if (t._2 == t._1) { + // pMean + // } else { + // t._3 + // } + // } else if (t._2 == di) { + // pMean - t._3 + // } else { + // 0.0 + // } + // }.sum + // assert((ds - dp).abs < 1e-4) + + table + } + + def sampleAlias(gen: Random, table: Table): Int = { + val l = table.length + val bin = gen.nextInt(l) + val i = table(bin)._1 + val h = table(bin)._2 + val p = table(bin)._3 + if (l * p > gen.nextDouble()) { + i + } else { + h + } + } + + def uniformSampler(rand: Random, dimension: Int): Int = { + rand.nextInt(dimension) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala new file mode 100644 index 0000000000000..b60ef0d3be11f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.scalatest.FunSuite + +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV} +import breeze.stats.distributions.Poisson +import org.apache.spark.mllib.linalg.{Vectors, SparseVector => SSV} + +class LDASuite extends FunSuite with MLlibTestSparkContext { + + import LDASuite._ + + test("LDA || Gibbs sampling") { + val model = generateRandomLDAModel(numTopics, numTerms) + val corpus = sampleCorpus(model, numDocs, numTerms, numTopics) + + val data = sc.parallelize(corpus, 2) + val pps = new Array[Double](incrementalLearning) + val lda = new LDA(data, numTopics, alpha, beta, alphaAS) + var i = 0 + val startedAt = System.currentTimeMillis() + while (i < incrementalLearning) { + lda.runGibbsSampling(totalIterations) + pps(i) = lda.perplexity + i += 1 + } + + println((System.currentTimeMillis() - startedAt) / 1e3) + pps.foreach(println) + + val ppsDiff = pps.init.zip(pps.tail).map { case (lhs, rhs) => lhs - rhs} + assert(ppsDiff.count(_ > 0).toDouble / ppsDiff.size > 0.6) + assert(pps.head - pps.last > 0) + } + +} + +object LDASuite { + val numTopics = 5 + val numTerms = 1000 + val numDocs = 100 + val expectedDocLength = 300 + val alpha = 0.01 + val alphaAS = 1D + val beta = 0.01 + val totalIterations = 2 + val burnInIterations = 1 + val incrementalLearning = 10 + + /** + * Generate a random LDA model, i.e. the topic-term matrix. 
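+ * Each topic's weight for a term is 1 / (1 + distance), where distance is the term's
+ * circular distance from the topic's centroid, so probability mass decays away from the centroid.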
+ */ + def generateRandomLDAModel(numTopics: Int, numTerms: Int): Array[BDV[Double]] = { + val model = new Array[BDV[Double]](numTopics) + val width = numTerms * 1.0 / numTopics + var topic = 0 + var i = 0 + while (topic < numTopics) { + val topicCentroid = width * (topic + 1) + model(topic) = BDV.zeros[Double](numTerms) + i = 0 + while (i < numTerms) { + // treat the term list as a circle, so the distance between the first one and the last one + // is 1, not n-1. + val distance = Math.abs(topicCentroid - i) % (numTerms / 2) + // Possibility is decay along with distance + model(topic)(i) = 1.0 / (1 + Math.abs(distance)) + i += 1 + } + topic += 1 + } + model + } + + /** + * Sample one document given the topic-term matrix. + */ + def ldaSampler( + model: Array[BDV[Double]], + topicDist: BDV[Double], + numTermsPerDoc: Int): Array[Int] = { + val samples = new Array[Int](numTermsPerDoc) + val rand = new Random() + (0 until numTermsPerDoc).foreach { i => + samples(i) = multinomialDistSampler( + rand, + model(multinomialDistSampler(rand, topicDist)) + ) + } + samples + } + + /** + * Sample corpus (many documents) from a given topic-term matrix. + */ + def sampleCorpus( + model: Array[BDV[Double]], + numDocs: Int, + numTerms: Int, + numTopics: Int): Array[(Long, SSV)] = { + (0 until numDocs).map { i => + val rand = new Random() + val numTermsPerDoc = Poisson.distribution(expectedDocLength).sample() + val numTopicsPerDoc = rand.nextInt(numTopics / 2) + 1 + val topicDist = BDV.zeros[Double](numTopics) + (0 until numTopicsPerDoc).foreach { _ => + topicDist(rand.nextInt(numTopics)) += 1 + } + val sv = BSV.zeros[Double](numTerms) + ldaSampler(model, topicDist, numTermsPerDoc).foreach { term => sv(term) += 1} + (i.toLong, Vectors.fromBreeze(sv).asInstanceOf[SSV]) + }.toArray + } + + /** + * A multinomial distribution sampler, using roulette method to sample an Int back. 
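+ * Scales a uniform draw by the total mass of `dist`, then walks the entries subtracting
+ * each weight until the remaining mass is exhausted and returns that index.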
+ */ + def multinomialDistSampler(rand: Random, dist: BDV[Double]): Int = { + val distSum = rand.nextDouble() * breeze.linalg.sum[BDV[Double], Double](dist) + + def loop(index: Int, accum: Double): Int = { + if (index == dist.length) return dist.length - 1 + val sum = accum - dist(index) + if (sum <= 0) index else loop(index + 1, sum) + } + + loop(0, distSum) + } +} From 36dc60b954775800812e01d2a4ac546f9ba42aab Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Tue, 10 Feb 2015 20:42:57 +0800 Subject: [PATCH 02/21] =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=80=A7=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../apache/spark/mllib/clustering/LDA.scala | 71 ++++++--------- .../spark/mllib/clustering/LDAModel.scala | 88 ++++++++++++++----- 2 files changed, 90 insertions(+), 69 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index efa7991425742..3a7e65976349b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -20,9 +20,7 @@ package org.apache.spark.mllib.clustering import java.lang.ref.SoftReference import java.util.Random -import breeze.collection.mutable.OpenAddressHashArray -import breeze.linalg.{DenseVector => BDV, HashVector => BHV, -SparseVector => BSV, sum => brzSum} +import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, sum => brzSum} import org.apache.spark.broadcast.Broadcast import org.apache.spark.graphx._ @@ -35,7 +33,6 @@ import org.apache.spark.storage.StorageLevel import org.apache.spark.SparkContext._ import org.apache.spark.util.collection.AppendOnlyMap import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.util.Utils import LDA._ import LDAUtils._ @@ -86,8 +83,7 @@ class LDA private[mllib]( if (innerIter % 10 == 0 && sc.getCheckpointDir.isDefined) { val edges = corpus.edges.map(t => t) edges.checkpoint() - val newCorpus: Graph[VD, ED] = Graph.fromEdges(edges, null, - storageLevel, storageLevel) + val newCorpus: Graph[VD, ED] = Graph.fromEdges(edges, null, storageLevel, storageLevel) corpus = updateCounter(newCorpus, numTopics).persist(storageLevel) } } @@ -128,8 +124,7 @@ class LDA private[mllib]( val newTermTopicCounter = termVertices termTopicCounter = Option(termTopicCounter).map(_.join(newTermTopicCounter).map { case (term, (a, b)) => - val c = new BHV(a) + new BHV(b) - (term, c.array) + (term, a :+ b) }).getOrElse(newTermTopicCounter) termTopicCounter.cache().count() @@ -138,7 +133,7 @@ class LDA private[mllib]( } val model = LDAModel(numTopics, numTerms, alpha, beta) termTopicCounter.collect().foreach { case (term, counter) => - model.merge(term.toInt, new BHV(counter)) + model.merge(term.toInt, counter) } model.gtc :/= burnInIter.toDouble model.ttc.foreach { ttc => @@ -210,14 +205,14 @@ class LDA private[mllib]( val termProb = corpus.mapVertices { (vid, counter) => val probDist = BSV.zeros[Double](numTopics) if (vid >= 0) { - val termTopicCounter = new BHV(counter) + val termTopicCounter = counter // \frac{{n}_{kw}{\alpha }_{k}}{{n}_{k}+\bar{\beta }} termTopicCounter.activeIterator.foreach { case (topic, cn) => probDist(topic) = cn * alpha / (totalTopicCounter(topic) + numTerms * beta) } } else { - val docTopicCounter = new BHV(counter) + val docTopicCounter = counter // \frac{{n}_{kd}{\beta }_{w}}{{n}_{k}+\bar{\beta }} docTopicCounter.activeIterator.foreach { case (topic, cn) => 
probDist(topic) = cn * beta / @@ -229,7 +224,7 @@ class LDA private[mllib]( }.mapTriplets { triplet => val (termTopicCounter, termProb) = triplet.srcAttr val (docTopicCounter, docProb) = triplet.dstAttr - val docSize = brzSum(new BHV(docTopicCounter)) + val docSize = brzSum(docTopicCounter) val docTermSize = triplet.attr.length var prob = 0D @@ -254,7 +249,7 @@ object LDA { private[mllib] type WordId = VertexId private[mllib] type Count = Int private[mllib] type ED = Array[Count] - private[mllib] type VD = OpenAddressHashArray[Int] + private[mllib] type VD = BSV[Count] def train(docs: RDD[(DocId, SSV)], numTopics: Int = 2048, @@ -324,9 +319,9 @@ object LDA { tSum = dv._1 } - val (dSum, d) = docTopicCounter.synchronized { + val d = docTopicCounter.synchronized { termTopicCounter.synchronized { - docTable(totalTopicCounter, termTopicCounter, docTopicCounter, + dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, currentTopic, numTokens, numTerms, alpha, alphaAS, beta) } } @@ -336,7 +331,7 @@ object LDA { } val newTopic = docTopicCounter.synchronized { termTopicCounter.synchronized { - tokenSampling(gen, t, tSum, w, wSum, d, dSum) + tokenSampling(gen, t, tSum, w, wSum, d) } } @@ -365,7 +360,7 @@ object LDA { } private def updateCounter(graph: Graph[VD, ED], numTopics: Int): Graph[VD, ED] = { - val newCounter = graph.aggregateMessages[BSV[Count]](ctx => { + val newCounter = graph.aggregateMessages[VD](ctx => { val topics = ctx.attr val vector = BSV.zeros[Count](numTopics) for (topic <- topics) { @@ -373,20 +368,14 @@ object LDA { } ctx.sendToDst(vector) ctx.sendToSrc(vector) - }, _ + _, TripletFields.EdgeOnly).mapValues { a => - val b = new VD(a.length) - a.activeIterator.foreach { t => - b(t._1) = t._2 - } - b - } + }, _ + _, TripletFields.EdgeOnly) graph.joinVertices(newCounter)((_, _, nc) => nc) } private def collectGlobalCounter(graph: Graph[VD, ED], numTopics: Int): BDV[Count] = { graph.vertices.filter(t => t._1 >= 0).map(_._2). 
aggregate(BDV.zeros[Count](numTopics))((a, b) => { - a :+= new BHV(b) + a :+= b }, _ :+= _) } @@ -467,12 +456,17 @@ object LDA { tSum: Double, w: Table, wSum: Double, - d: Table, - dSum: Double): Int = { + d: BSV[Double]): Int = { + val index = d.index + val data = d.data + val used = d.used + val dSum = data(d.used - 1) val distSum = tSum + wSum + dSum val genSum = gen.nextDouble() * distSum if (genSum < dSum) { - sampleAlias(gen, d) + val dGenSum = gen.nextDouble() * dSum + val pos = binarySearchInterval(data, dGenSum, 0, used, true) + index(pos) } else if (genSum < (dSum + wSum)) { sampleAlias(gen, w) } else { @@ -537,7 +531,7 @@ object LDA { numTerms: Double, alpha: Double, alphaAS: Double, - beta: Double): (Double, BSV[Double]) = { + beta: Double): BSV[Double] = { val numTopics = totalTopicCounter.length // val termSum = numTokens - 1D + alphaAS * numTopics val betaSum = numTerms * beta @@ -550,10 +544,10 @@ object LDA { // ((totalTopicCounter(topic) + betaSum) * termSum) val last = count * (termTopicCounter(topic) + beta) / (totalTopicCounter(topic) + betaSum) - d(topic) = last sum += last + d(topic) = sum } - (sum, d) + d } private def wordTable( @@ -579,21 +573,6 @@ object LDA { w.get() } - private def docTable( - totalTopicCounter: BDV[Count], - termTopicCounter: VD, - docTopicCounter: VD, - currentTopic: Int, - numTokens: Double, - numTerms: Double, - alpha: Double, - alphaAS: Double, - beta: Double): (Double, Table) = { - val d = dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, - currentTopic, numTokens, numTerms, alpha, alphaAS, beta) - (d._1, generateAlias(d._2, d._1)) - } - } class LDAKryoRegistrator extends KryoRegistrator { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index e6f64a033b83e..05fd49f5ae162 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -123,12 +123,12 @@ class LDAModel private[mllib]( for (i <- 0 until topics.length) { val termId = tokens(i) val currentTopic = topics(i) - val (dSum, d) = docTable(gtc, ttc(termId), docTopicCounter, + val d = dSparse(gtc, ttc(termId), docTopicCounter, currentTopic, numTokens, numTerms, alpha, alphaAS, beta) val (wSum, w) = wordTable(wordTableCache, gtc, ttc(termId), termId, numTokens, numTerms, alpha, alphaAS, beta) - val newTopic = tokenSampling(rand, t, tSum, w, wSum, d, dSum) + val newTopic = tokenSampling(rand, t, tSum, w, wSum, d) if (newTopic != currentTopic) { docTopicCounter(newTopic) += 1D docTopicCounter(currentTopic) -= 1D @@ -147,12 +147,17 @@ class LDAModel private[mllib]( tSum: Double, w: Table, wSum: Double, - d: Table, - dSum: Double): Int = { + d: BSV[Double]): Int = { + val index = d.index + val data = d.data + val used = d.used + val dSum = data(d.used - 1) val distSum = tSum + wSum + dSum val genSum = gen.nextDouble() * distSum if (genSum < dSum) { - sampleAlias(gen, d) + val dGenSum = gen.nextDouble() * dSum + val pos = binarySearchInterval(data, dGenSum, 0, used, true) + index(pos) } else if (genSum < (dSum + wSum)) { sampleAlias(gen, w) } else { @@ -209,18 +214,23 @@ class LDAModel private[mllib]( numTerms: Double, alpha: Double, alphaAS: Double, - beta: Double): (Double, BSV[Double]) = { + beta: Double): BSV[Double] = { + val numTopics = totalTopicCounter.length + // val termSum = numTokens - 1D + alphaAS * numTopics + val betaSum = numTerms * beta val d = 
BSV.zeros[Double](numTopics) var sum = 0.0 docTopicCounter.activeIterator.foreach { t => val topic = t._1 val count = if (currentTopic == topic && t._2 != 1) t._2 - 1 else t._2 + // val last = count * termSum * (termTopicCounter(topic) + beta) / + // ((totalTopicCounter(topic) + betaSum) * termSum) val last = count * (termTopicCounter(topic) + beta) / (totalTopicCounter(topic) + betaSum) - d(topic) = last sum += last + d(topic) = sum } - (sum, d) + d } private def wordTable( @@ -244,21 +254,6 @@ class LDAModel private[mllib]( w.get() } - private def docTable( - totalTopicCounter: BDV[Double], - termTopicCounter: BSV[Double], - docTopicCounter: BSV[Double], - currentTopic: Int, - numTokens: Double, - numTerms: Double, - alpha: Double, - alphaAS: Double, - beta: Double): (Double, Table) = { - val d = dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, - currentTopic, numTokens, numTerms, alpha, alphaAS, beta) - (d._1, generateAlias(d._2, d._1)) - } - private[mllib] def mergeOne(term: Int, topic: Int, inc: Int) = { gtc(topic) += inc ttc(term)(topic) += inc @@ -390,4 +385,51 @@ private[mllib] object LDAUtils { def uniformSampler(rand: Random, dimension: Int): Int = { rand.nextInt(dimension) } + + def binarySearchInterval( + index: Array[Double], + key: Double, + begin: Int, + end: Int, + greater: Boolean): Int = { + if (begin == end) { + return if (greater) end else begin - 1 + } + var b = begin + var e = end - 1 + + var mid: Int = (e + b) >> 1 + while (b <= e) { + mid = (e + b) >> 1 + val v = index(mid) + if (v < key) { + b = mid + 1 + } + else if (v > key) { + e = mid - 1 + } + else { + return mid + } + } + val v = index(mid) + mid = if ((greater && v >= key) || (!greater && v <= key)) { + mid + } + else if (greater) { + mid + 1 + } + else { + mid - 1 + } + + if (greater) { + if (mid < end) assert(index(mid) >= key) + if (mid > 0) assert(index(mid - 1) <= key) + } else { + if (mid > 0) assert(index(mid) <= key) + if (mid < end - 1) assert(index(mid + 1) >= key) + } + mid + } } From 5225ce5aca5b36f690fbe6e23803201484613a45 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Wed, 11 Feb 2015 16:24:43 +0800 Subject: [PATCH 03/21] =?UTF-8?q?=E6=8C=89=E8=AF=8D=E5=88=87=E5=89=B2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../main/scala/org/apache/spark/mllib/clustering/LDA.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 3a7e65976349b..659a5e57e362d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -390,15 +390,14 @@ object LDA { if (computedModel != null) model = computedModel.value iter.flatMap { case (docId, doc) => - initializeEdges(gen, new BSV[Int](doc.indices, doc.values.map(_.toInt), doc.size), - docId, numTopics, model) + val bsv = new BSV[Int](doc.indices, doc.values.map(_.toInt), doc.size) + initializeEdges(gen, bsv, docId, numTopics, model) } }) var corpus: Graph[VD, ED] = Graph.fromEdges(edges, null, storageLevel, storageLevel) - corpus.partitionBy(PartitionStrategy.EdgePartition1D) corpus = updateCounter(corpus, numTopics).cache() corpus.vertices.count() - corpus + corpus.partitionBy(PartitionStrategy.EdgePartition1D) } private def initializeEdges( From de530cdc31ecb4a5ccad985320bcc667c994f0f9 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Wed, 11 Feb 2015 
18:27:17 +0800 Subject: [PATCH 04/21] =?UTF-8?q?=E7=A7=BB=E9=99=A4=E8=B0=83=E8=AF=95?= =?UTF-8?q?=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/main/scala/org/apache/spark/mllib/clustering/LDA.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 659a5e57e362d..50d00982f0206 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -145,7 +145,7 @@ class LDA private[mllib]( def runGibbsSampling(iterations: Int): Unit = { for (iter <- 1 to iterations) { - println(s"perplexity $iter: ${perplexity()}") + // println(s"perplexity $iter: ${perplexity()}") logInfo(s"Start Gibbs sampling (Iteration $iter/$iterations)") gibbsSampling() } From 6a9db52dfe8f03d445a01c44f6df173235871e44 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Wed, 11 Feb 2015 21:53:14 +0800 Subject: [PATCH 05/21] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=85=AC=E5=BC=8F?= =?UTF-8?q?=E6=8B=86=E8=A7=A3=E6=B3=A8=E8=A7=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../apache/spark/mllib/clustering/LDA.scala | 108 +++++++++--------- 1 file changed, 54 insertions(+), 54 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 50d00982f0206..6647db8d5b3bf 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -301,9 +301,9 @@ object LDA { (pid, iter) => { val gen = new XORShiftRandom(parts * innerIter + pid) val wordTableCache = new AppendOnlyMap[VertexId, SoftReference[(Double, Table)]]() - var t: Table = null - var tSum: Double = 0.0 - + val dv = tDense(totalTopicCounter, numTokens, numTerms, alpha, alphaAS, beta) + val t = generateAlias(dv._2, dv._1) + val tSum = dv._1 iter.map { triplet => val termId = triplet.srcId @@ -313,44 +313,15 @@ object LDA { val topics = triplet.attr for (i <- 0 until topics.length) { val currentTopic = topics(i) - if (t == null || gen.nextDouble() < 1e-6) { - val dv = tDense(totalTopicCounter, numTokens, numTerms, alpha, alphaAS, beta) - t = generateAlias(dv._2, dv._1) - tSum = dv._1 - } - - val d = docTopicCounter.synchronized { - termTopicCounter.synchronized { - dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, - currentTopic, numTokens, numTerms, alpha, alphaAS, beta) - } - } - val (wSum, w) = termTopicCounter.synchronized { - wordTable(gen, wordTableCache, totalTopicCounter, - termTopicCounter, termId, numTokens, numTerms, alpha, alphaAS, beta) - } - val newTopic = docTopicCounter.synchronized { - termTopicCounter.synchronized { - tokenSampling(gen, t, tSum, w, wSum, d) - } - } + val d = dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, + currentTopic, numTokens, numTerms, alpha, alphaAS, beta) + val (wSum, w) = wordTable(wordTableCache, totalTopicCounter, + termTopicCounter, termId, numTokens, numTerms, alpha, alphaAS, beta) + val newTopic = tokenSampling(gen, t, tSum, w, wSum, d) if (newTopic != currentTopic) { - docTopicCounter.synchronized { - docTopicCounter(currentTopic) -= 1 - docTopicCounter(newTopic) += 1 - } - termTopicCounter.synchronized { - termTopicCounter(currentTopic) -= 1 - 
termTopicCounter(newTopic) += 1 - } - - totalTopicCounter(currentTopic) -= 1 - totalTopicCounter(newTopic) += 1 - topics(i) = newTopic } - } topics @@ -397,7 +368,8 @@ object LDA { var corpus: Graph[VD, ED] = Graph.fromEdges(edges, null, storageLevel, storageLevel) corpus = updateCounter(corpus, numTopics).cache() corpus.vertices.count() - corpus.partitionBy(PartitionStrategy.EdgePartition1D) + // corpus.partitionBy(PartitionStrategy.EdgePartition1D) + corpus } private def initializeEdges( @@ -440,13 +412,22 @@ object LDA { * 每次采样的复杂度为: O(1) * 使用 Gibbs sampler 采样论文 Rethinking LDA: Why Priors Matter 公式(3) * \frac{{n}_{kw}^{-di}+{\beta }_{w}}{{n}_{k}^{-di}+\bar{\beta}} \frac{{n}_{kd} ^{-di}+ \bar{\alpha} \frac{{n}_{k}^{-di} + \acute{\alpha}}{\sum{n}_{k} +\bar{\acute{\alpha}}}}{\sum{n}_{kd}^{-di} +\bar{\alpha}} + * = t + w + d + * t 全局相关部分 + * t = \frac{{\beta }_{w} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} ) } {({n}_{k}^{-di}+\bar{\beta}) ({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})} + * w 词相关部分 + * w = \frac{ {n}_{kw}^{-di} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} )}{({n}_{k}^{-di}+\bar{\beta})({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})} + * d 文档和词的乘积 + * d = \frac{{n}_{kd}^{-di}({\sum{n}_{k}^{-di} + \bar{\acute{\alpha}}})({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta})({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})} + * = \frac{{n}_{kd ^{-di}({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta}) } * 其中 * \bar{\beta}=\sum_{w}{\beta}_{w} * \bar{\alpha}=\sum_{k}{\alpha}_{k} * \bar{\acute{\alpha}}=\bar{\acute{\alpha}}=\sum_{k}\acute{\alpha} - * {n}_{kd} 是文档d中主题为k的tokens数 + * {n}_{kd} 文档d中主题为k的tokens数 * {n}_{kw} 词中主题为k的tokens数 - * {n}_{k} 是语料库中主题为k的tokens数 + * {n}_{k} 语料库中主题为k的tokens数 + * -di 减去当前token的主题 */ // scalastyle:on def tokenSampling( @@ -461,6 +442,9 @@ object LDA { val used = d.used val dSum = data(d.used - 1) val distSum = tSum + wSum + dSum + if (gen.nextDouble() < 1e-32) { + println(s"dSum: ${dSum / distSum}") + } val genSum = gen.nextDouble() * distSum if (genSum < dSum) { val dGenSum = gen.nextDouble() * dSum @@ -474,6 +458,10 @@ object LDA { } + /** + * 分解后的公式为 + * t = \frac{{\beta }_{w} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} ) } {({n}_{k}^{-di}+\bar{\beta}) ({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})} + */ private def tDense( totalTopicCounter: BDV[Count], numTokens: Double, @@ -496,6 +484,10 @@ object LDA { (sum, t) } + /** + * 分解后的公式为 + * w = \frac{ {n}_{kw}^{-di} \bar{\alpha} ( {n}_{k}^{-di} + \acute{\alpha} )}{({n}_{k}^{-di}+\bar{\beta}) ({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})} + */ private def wSparse( totalTopicCounter: BDV[Count], termTopicCounter: VD, @@ -521,6 +513,11 @@ object LDA { (sum, w) } + /** + * 分解后的公式为 + * d = \frac{{n}_{kd} ^{-di}({\sum{n}_{k}^{-di} + \bar{\acute{\alpha}}})({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta})({\sum{n}_{k}^{-di} +\bar{\acute{\alpha}}})} + * = \frac{{n}_{kd} ^{-di}({n}_{kw}^{-di}+{\beta}_{w})}{({n}_{k}^{-di}+\bar{\beta}) } + */ private def dSparse( totalTopicCounter: BDV[Count], termTopicCounter: VD, @@ -532,25 +529,30 @@ object LDA { alphaAS: Double, beta: Double): BSV[Double] = { val numTopics = totalTopicCounter.length + val index = docTopicCounter.index + val data = docTopicCounter.data + val used = docTopicCounter.used + // val termSum = numTokens - 1D + alphaAS * numTopics val betaSum = numTerms * beta - val d = BSV.zeros[Double](numTopics) + val d = new Array[Double](used) var sum = 0.0 - docTopicCounter.activeIterator.foreach { t => - val topic = t._1 - val count = if (currentTopic == 
topic && t._2 != 1) t._2 - 1 else t._2 + for (i <- 0 until used) { + val topic = index(i) + var count: Double = data(i) + if (currentTopic == topic) count = count - 1.0 // val last = count * termSum * (termTopicCounter(topic) + beta) / // ((totalTopicCounter(topic) + betaSum) * termSum) val last = count * (termTopicCounter(topic) + beta) / (totalTopicCounter(topic) + betaSum) + sum += last - d(topic) = sum + d(i) = sum } - d + new BSV[Double](index, d, used, numTopics) } private def wordTable( - gen: Random, cacheMap: AppendOnlyMap[VertexId, SoftReference[(Double, Table)]], totalTopicCounter: BDV[Count], termTopicCounter: VD, @@ -561,13 +563,11 @@ object LDA { alphaAS: Double, beta: Double): (Double, Table) = { var w = cacheMap(termId) - if (w == null || w.get() == null || gen.nextDouble() < 1e-5) { - termTopicCounter.synchronized { - val t = wSparse(totalTopicCounter, termTopicCounter, - numTokens, numTerms, alpha, alphaAS, beta) - w = new SoftReference((t._1, generateAlias(t._2, t._1))) - cacheMap.update(termId, w) - } + if (w == null || w.get() == null) { + val t = wSparse(totalTopicCounter, termTopicCounter, + numTokens, numTerms, alpha, alphaAS, beta) + w = new SoftReference((t._1, generateAlias(t._2, t._1))) + cacheMap.update(termId, w) } w.get() } From e93f0edbeac919e9c25ea7c477d899bcc627a73d Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Wed, 11 Feb 2015 22:37:20 +0800 Subject: [PATCH 06/21] w -1 --- .../apache/spark/mllib/clustering/LDA.scala | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 6647db8d5b3bf..545b0a15ab83a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -317,7 +317,7 @@ object LDA { currentTopic, numTokens, numTerms, alpha, alphaAS, beta) val (wSum, w) = wordTable(wordTableCache, totalTopicCounter, termTopicCounter, termId, numTokens, numTerms, alpha, alphaAS, beta) - val newTopic = tokenSampling(gen, t, tSum, w, wSum, d) + val newTopic = tokenSampling(gen, t, tSum, w, termTopicCounter, wSum, d, currentTopic) if (newTopic != currentTopic) { topics(i) = newTopic @@ -435,8 +435,10 @@ object LDA { t: Table, tSum: Double, w: Table, + termTopicCounter: VD, wSum: Double, - d: BSV[Double]): Int = { + d: BSV[Double], + currentTopic: Int): Int = { val index = d.index val data = d.data val used = d.used @@ -451,7 +453,7 @@ object LDA { val pos = binarySearchInterval(data, dGenSum, 0, used, true) index(pos) } else if (genSum < (dSum + wSum)) { - sampleAlias(gen, w) + sampleSV(gen, w, termTopicCounter, currentTopic) } else { sampleAlias(gen, t) } @@ -572,6 +574,22 @@ object LDA { w.get() } + private def sampleSV(gen: Random, table: Table, sv: VD, currentTopic: Int): Int = { + val docTopic = sampleAlias(gen, table) + if (docTopic == currentTopic) { + val svCounter = sv(currentTopic) + // 这里的处理方法不太对. 
+ // 如果采样到当前token的Topic这丢弃掉 + // svCounter == 1 && table.length > 1 采样到token的Topic 但包含其他token + // svCounter > 1 && gen.nextDouble() < 1.0 / svCounter 采样的Topic 有1/svCounter 概率属于当前token + if ((svCounter == 1 && table.length > 1) || + (svCounter > 1 && gen.nextDouble() < 1.0 / svCounter)) { + return sampleSV(gen, table, sv, currentTopic) + } + } + docTopic + } + } class LDAKryoRegistrator extends KryoRegistrator { From a4c0f5cfa956878c1bacdc2254cbb32f22ad9c47 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Thu, 12 Feb 2015 14:27:47 +0800 Subject: [PATCH 07/21] =?UTF-8?q?=E4=BB=A5term=20=E5=88=87=E5=88=86partiti?= =?UTF-8?q?on?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/main/scala/org/apache/spark/mllib/clustering/LDA.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 545b0a15ab83a..f457210bd3b3d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -366,9 +366,9 @@ object LDA { } }) var corpus: Graph[VD, ED] = Graph.fromEdges(edges, null, storageLevel, storageLevel) + corpus = corpus.partitionBy(PartitionStrategy.EdgePartition1D) corpus = updateCounter(corpus, numTopics).cache() corpus.vertices.count() - // corpus.partitionBy(PartitionStrategy.EdgePartition1D) corpus } From df5002e4582f5d92688945191963839bf099d7fc Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Fri, 13 Feb 2015 10:56:35 +0800 Subject: [PATCH 08/21] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E4=B8=80=E4=BA=9Bbug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../apache/spark/mllib/clustering/LDA.scala | 24 ++++++++++++------- .../spark/mllib/clustering/LDAModel.scala | 1 + 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index f457210bd3b3d..6b507090d5934 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -339,7 +339,9 @@ object LDA { } ctx.sendToDst(vector) ctx.sendToSrc(vector) - }, _ + _, TripletFields.EdgeOnly) + }, _ + _, TripletFields.EdgeOnly).mapValues(t => { + t.compact(); t + }) graph.joinVertices(newCounter)((_, _, nc) => nc) } @@ -365,10 +367,13 @@ object LDA { initializeEdges(gen, bsv, docId, numTopics, model) } }) + edges.persist(storageLevel) var corpus: Graph[VD, ED] = Graph.fromEdges(edges, null, storageLevel, storageLevel) corpus = corpus.partitionBy(PartitionStrategy.EdgePartition1D) corpus = updateCounter(corpus, numTopics).cache() corpus.vertices.count() + corpus.edges.count() + edges.unpersist() corpus } @@ -380,12 +385,13 @@ object LDA { computedModel: LDAModel = null): Array[Edge[ED]] = { assert(docId >= 0) val newDocId: DocId = -(docId + 1L) - if (computedModel == null) { - doc.activeIterator.map { case (termId, counter) => - val ev = (0 until counter).map { i => - gen.nextInt(numTopics) - }.toArray - Edge(termId, newDocId, ev) + val edges = if (computedModel == null) { + doc.activeIterator.filter(_._2 > 0).map { case (termId, counter) => + val topics = new Array[Int](counter) + for (i <- 0 until counter) { + topics(i) = gen.nextInt(numTopics) + } + Edge(termId, newDocId, topics) }.toArray } else { @@ -397,13 
+403,15 @@ object LDA { docTopicCounter = computedModel.sampleTokens(docTopicCounter, tokens, topics) } - doc.activeIterator.map { case (term, counter) => + doc.activeIterator.filter(_._2 > 0).map { case (term, counter) => val ev = topics.zipWithIndex.filter { case (topic, offset) => term == tokens(offset) }.map(_._1) Edge(term, newDocId, ev) }.toArray } + assert(edges.length > 0) + edges } // scalastyle:off diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 05fd49f5ae162..a9b7d363665ae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -243,6 +243,7 @@ class LDAModel private[mllib]( alpha: Double, alphaAS: Double, beta: Double): (Double, Table) = { + if (termTopicCounter.used == 0) return (0.0, null) var w = cacheMap(termId) if (w == null || w.get() == null) { val t = wSparse(totalTopicCounter, termTopicCounter, From 7e3af5eccd0664c92acfbb9820672271704d9205 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Fri, 27 Feb 2015 10:01:04 +0800 Subject: [PATCH 09/21] =?UTF-8?q?partition=20strategy=20=E8=AE=BE=E4=B8=BA?= =?UTF-8?q?=20EdgePartition2D(=E8=BF=99=E6=A0=B7=E4=BC=9A=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E5=BE=88=E5=A4=9A=E7=BD=91=E7=BB=9C=E6=B5=81=E9=87=8F?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../main/scala/org/apache/spark/mllib/clustering/LDA.scala | 5 +++-- .../scala/org/apache/spark/mllib/clustering/LDAModel.scala | 5 +---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 6b507090d5934..342bab4218a19 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -340,7 +340,8 @@ object LDA { ctx.sendToDst(vector) ctx.sendToSrc(vector) }, _ + _, TripletFields.EdgeOnly).mapValues(t => { - t.compact(); t + t.compact() + t }) graph.joinVertices(newCounter)((_, _, nc) => nc) } @@ -369,7 +370,7 @@ object LDA { }) edges.persist(storageLevel) var corpus: Graph[VD, ED] = Graph.fromEdges(edges, null, storageLevel, storageLevel) - corpus = corpus.partitionBy(PartitionStrategy.EdgePartition1D) + corpus = corpus.partitionBy(PartitionStrategy.EdgePartition2D) corpus = updateCounter(corpus, numTopics).cache() corpus.vertices.count() corpus.edges.count() diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index a9b7d363665ae..c5ad67939b73a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -21,15 +21,12 @@ import java.lang.ref.SoftReference import java.util.Random import java.util.{PriorityQueue => JPriorityQueue} -import org.apache.spark.util.random.XORShiftRandom - -import scala.reflect.ClassTag - import breeze.linalg.{Vector => BV, DenseVector => BDV, SparseVector => BSV, sum => brzSum, norm => brzNorm} import org.apache.spark.mllib.linalg.{Vectors, DenseVector => SDV, SparseVector => SSV} import org.apache.spark.util.collection.AppendOnlyMap +import org.apache.spark.util.random.XORShiftRandom import LDAUtils._ From 3530bab219ec12324cc9632692fc199e3718731a Mon Sep 17 
00:00:00 2001 From: GuoQiang Li Date: Fri, 27 Feb 2015 18:31:57 +0800 Subject: [PATCH 10/21] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20=20Degree-based=20ha?= =?UTF-8?q?shing=20partition=20strategy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../apache/spark/mllib/clustering/LDA.scala | 60 ++++++++++++++----- 1 file changed, 46 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 342bab4218a19..5a1459c4b98e0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -24,7 +24,7 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, sum => brzSum} import org.apache.spark.broadcast.Broadcast import org.apache.spark.graphx._ -import org.apache.spark.Logging +import org.apache.spark.{HashPartitioner, Logging, Partitioner} import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix} import org.apache.spark.mllib.linalg.{DenseVector => SDV, SparseVector => SSV, Vector => SV} import org.apache.spark.rdd.RDD @@ -302,6 +302,7 @@ object LDA { val gen = new XORShiftRandom(parts * innerIter + pid) val wordTableCache = new AppendOnlyMap[VertexId, SoftReference[(Double, Table)]]() val dv = tDense(totalTopicCounter, numTokens, numTerms, alpha, alphaAS, beta) + val dData = new Array[Double](numTopics.toInt) val t = generateAlias(dv._2, dv._1) val tSum = dv._1 iter.map { @@ -313,11 +314,12 @@ object LDA { val topics = triplet.attr for (i <- 0 until topics.length) { val currentTopic = topics(i) - val d = dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, + dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, dData, currentTopic, numTokens, numTerms, alpha, alphaAS, beta) val (wSum, w) = wordTable(wordTableCache, totalTopicCounter, termTopicCounter, termId, numTokens, numTerms, alpha, alphaAS, beta) - val newTopic = tokenSampling(gen, t, tSum, w, termTopicCounter, wSum, d, currentTopic) + val newTopic = tokenSampling(gen, t, tSum, w, termTopicCounter, wSum, + docTopicCounter, dData, currentTopic) if (newTopic != currentTopic) { topics(i) = newTopic @@ -370,7 +372,14 @@ object LDA { }) edges.persist(storageLevel) var corpus: Graph[VD, ED] = Graph.fromEdges(edges, null, storageLevel, storageLevel) - corpus = corpus.partitionBy(PartitionStrategy.EdgePartition2D) + val degrees = corpus.outerJoinVertices(corpus.degrees) { (vid, data, deg) => deg.getOrElse(0)} + val numPartitions = edges.partitions.size + val partitionStrategy = new DBHPartitioner(numPartitions) + val newEdges = degrees.triplets.map { e => + (partitionStrategy.getPartition(e), Edge(e.srcId, e.dstId, e.attr)) + }.partitionBy(new HashPartitioner(numPartitions)).map(_._2) + corpus = Graph.fromEdges(newEdges, null, storageLevel, storageLevel) + // corpus = corpus.partitionBy(PartitionStrategy.EdgePartition2D) corpus = updateCounter(corpus, numTopics).cache() corpus.vertices.count() corpus.edges.count() @@ -446,12 +455,12 @@ object LDA { w: Table, termTopicCounter: VD, wSum: Double, - d: BSV[Double], + docTopicCounter: VD, + dData: Array[Double], currentTopic: Int): Int = { - val index = d.index - val data = d.data - val used = d.used - val dSum = data(d.used - 1) + val index = docTopicCounter.index + val used = docTopicCounter.used + val dSum = dData(docTopicCounter.used - 1) val distSum = tSum + wSum + dSum if (gen.nextDouble() < 1e-32) { 
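As context for tokenSampling, a minimal sketch of drawing from the three-part mixture t + w + d: pick one of the three components with probability proportional to its mass, then sample within it. The component samplers are assumed to exist; names are illustrative.

  def sampleMixture(gen: java.util.Random,
      dSum: Double, wSum: Double, tSum: Double,
      sampleD: () => Int, sampleW: () => Int, sampleT: () => Int): Int = {
    val u = gen.nextDouble() * (dSum + wSum + tSum)
    if (u < dSum) sampleD()              // document-specific sparse part
    else if (u < dSum + wSum) sampleW()  // word-specific sparse part
    else sampleT()                       // dense corpus-wide part
  }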
println(s"dSum: ${dSum / distSum}") @@ -459,7 +468,7 @@ object LDA { val genSum = gen.nextDouble() * distSum if (genSum < dSum) { val dGenSum = gen.nextDouble() * dSum - val pos = binarySearchInterval(data, dGenSum, 0, used, true) + val pos = binarySearchInterval(dData, dGenSum, 0, used, true) index(pos) } else if (genSum < (dSum + wSum)) { sampleSV(gen, w, termTopicCounter, currentTopic) @@ -533,20 +542,19 @@ object LDA { totalTopicCounter: BDV[Count], termTopicCounter: VD, docTopicCounter: VD, + d: Array[Double], currentTopic: Int, numTokens: Double, numTerms: Double, alpha: Double, alphaAS: Double, - beta: Double): BSV[Double] = { - val numTopics = totalTopicCounter.length + beta: Double): Unit = { val index = docTopicCounter.index val data = docTopicCounter.data val used = docTopicCounter.used // val termSum = numTokens - 1D + alphaAS * numTopics val betaSum = numTerms * beta - val d = new Array[Double](used) var sum = 0.0 for (i <- 0 until used) { val topic = index(i) @@ -560,7 +568,6 @@ object LDA { sum += last d(i) = sum } - new BSV[Double](index, d, used, numTopics) } private def wordTable( @@ -601,6 +608,31 @@ object LDA { } +private class DBHPartitioner(partitions: Int) extends Partitioner { + val mixingPrime: Long = 1125899906842597L + + def numPartitions = partitions + + def getPartition(key: Any): Int = { + val edge = key.asInstanceOf[EdgeTriplet[Int, ED]] + val idx = math.min(edge.srcAttr, edge.dstAttr) + getPartition(idx) + } + + def getPartition(src: Int): PartitionID = { + (math.abs(src * mixingPrime) % partitions).toInt + } + + override def equals(other: Any): Boolean = other match { + case h: DBHPartitioner => + h.numPartitions == numPartitions + case _ => + false + } + + override def hashCode: Int = numPartitions +} + class LDAKryoRegistrator extends KryoRegistrator { def registerClasses(kryo: com.esotericsoftware.kryo.Kryo) { val gkr = new GraphKryoRegistrator From 60ae2d50add6069f99056f7d1cff0aaa8ff2d75d Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Sat, 28 Feb 2015 18:23:02 +0800 Subject: [PATCH 11/21] =?UTF-8?q?=E6=81=A2=E5=A4=8D=E4=BD=BF=E7=94=A8parti?= =?UTF-8?q?tion=20strategy=20EdgePartition2D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../apache/spark/mllib/clustering/LDA.scala | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 5a1459c4b98e0..4370935a47227 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -372,14 +372,16 @@ object LDA { }) edges.persist(storageLevel) var corpus: Graph[VD, ED] = Graph.fromEdges(edges, null, storageLevel, storageLevel) - val degrees = corpus.outerJoinVertices(corpus.degrees) { (vid, data, deg) => deg.getOrElse(0)} - val numPartitions = edges.partitions.size - val partitionStrategy = new DBHPartitioner(numPartitions) - val newEdges = degrees.triplets.map { e => - (partitionStrategy.getPartition(e), Edge(e.srcId, e.dstId, e.attr)) - }.partitionBy(new HashPartitioner(numPartitions)).map(_._2) - corpus = Graph.fromEdges(newEdges, null, storageLevel, storageLevel) - // corpus = corpus.partitionBy(PartitionStrategy.EdgePartition2D) + // degree-based hashing + // val degrees = corpus.outerJoinVertices(corpus.degrees) { (vid, data, deg) => deg.getOrElse(0)} + // val numPartitions = edges.partitions.size + // val 
partitionStrategy = new DBHPartitioner(numPartitions) + // val newEdges = degrees.triplets.map { e => + // (partitionStrategy.getPartition(e), Edge(e.srcId, e.dstId, e.attr)) + // }.partitionBy(new HashPartitioner(numPartitions)).map(_._2) + // corpus = Graph.fromEdges(newEdges, null, storageLevel, storageLevel) + // end degree-based hashing + corpus = corpus.partitionBy(PartitionStrategy.EdgePartition2D) corpus = updateCounter(corpus, numTopics).cache() corpus.vertices.count() corpus.edges.count() From ad169a44a5d4e21d262c3cd61473bd5578fac62c Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Thu, 5 Mar 2015 15:24:45 +0800 Subject: [PATCH 12/21] =?UTF-8?q?=E4=BC=98=E5=8C=96AliasTable=E5=86=85?= =?UTF-8?q?=E5=AD=98=E5=8D=A0=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../apache/spark/mllib/clustering/LDA.scala | 77 +++++++++++++------ .../spark/mllib/clustering/LDAModel.scala | 44 ++++------- 2 files changed, 68 insertions(+), 53 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 4370935a47227..fca3bcab7bd10 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -38,13 +38,13 @@ import LDA._ import LDAUtils._ class LDA private[mllib]( - @transient var corpus: Graph[VD, ED], - val numTopics: Int, - val numTerms: Int, - val alpha: Double, - val beta: Double, - val alphaAS: Double, - @transient val storageLevel: StorageLevel) + @transient private var corpus: Graph[VD, ED], + private val numTopics: Int, + private val numTerms: Int, + private var alpha: Double, + private var beta: Double, + private var alphaAS: Double, + private var storageLevel: StorageLevel) extends Serializable with Logging { def this(docs: RDD[(DocId, SSV)], @@ -68,10 +68,37 @@ class LDA private[mllib]( * 语料库总的词数(包含重复) */ val numTokens = corpus.edges.map(e => e.attr.size.toDouble).sum().toLong + + def setAlpha(alpha: Double): this.type = { + this.alpha = alpha + this + } + + def setBeta(beta: Double): this.type = { + this.beta = beta + this + } + + def setAlphaAS(alphaAS: Double): this.type = { + this.alphaAS = alphaAS + this + } + + def setStorageLevel(newStorageLevel: StorageLevel): this.type = { + this.storageLevel = newStorageLevel + this + } + + def setSeed(newSeed: Int): this.type = { + this.seed = newSeed + this + } + + def getCorpus = corpus + // scalastyle:on - @transient private val sc = corpus.vertices.context - @transient private val seed = new Random().nextInt() + @transient private var seed = new Random().nextInt() @transient private var innerIter = 1 @transient private var totalTopicCounter: BDV[Count] = collectTotalTopicCounter(corpus) @@ -80,7 +107,7 @@ class LDA private[mllib]( private def docVertices = corpus.vertices.filter(t => t._1 < 0) private def checkpoint(): Unit = { - if (innerIter % 10 == 0 && sc.getCheckpointDir.isDefined) { + if (innerIter % 10 == 0 && corpus.edges.sparkContext.getCheckpointDir.isDefined) { val edges = corpus.edges.map(t => t) edges.checkpoint() val newCorpus: Graph[VD, ED] = Graph.fromEdges(edges, null, storageLevel, storageLevel) @@ -373,15 +400,15 @@ object LDA { edges.persist(storageLevel) var corpus: Graph[VD, ED] = Graph.fromEdges(edges, null, storageLevel, storageLevel) // degree-based hashing - // val degrees = corpus.outerJoinVertices(corpus.degrees) { (vid, data, deg) => deg.getOrElse(0)} - // val 
numPartitions = edges.partitions.size - // val partitionStrategy = new DBHPartitioner(numPartitions) - // val newEdges = degrees.triplets.map { e => - // (partitionStrategy.getPartition(e), Edge(e.srcId, e.dstId, e.attr)) - // }.partitionBy(new HashPartitioner(numPartitions)).map(_._2) - // corpus = Graph.fromEdges(newEdges, null, storageLevel, storageLevel) + val degrees = corpus.outerJoinVertices(corpus.degrees) { (vid, data, deg) => deg.getOrElse(0)} + val numPartitions = edges.partitions.size + val partitionStrategy = new DBHPartitioner(numPartitions) + val newEdges = degrees.triplets.map { e => + (partitionStrategy.getPartition(e), Edge(e.srcId, e.dstId, e.attr)) + }.partitionBy(new HashPartitioner(numPartitions)).map(_._2) + corpus = Graph.fromEdges(newEdges, null, storageLevel, storageLevel) // end degree-based hashing - corpus = corpus.partitionBy(PartitionStrategy.EdgePartition2D) + // corpus = corpus.partitionBy(PartitionStrategy.EdgePartition2D) corpus = updateCounter(corpus, numTopics).cache() corpus.vertices.count() corpus.edges.count() @@ -464,9 +491,6 @@ object LDA { val used = docTopicCounter.used val dSum = dData(docTopicCounter.used - 1) val distSum = tSum + wSum + dSum - if (gen.nextDouble() < 1e-32) { - println(s"dSum: ${dSum / distSum}") - } val genSum = gen.nextDouble() * distSum if (genSum < dSum) { val dGenSum = gen.nextDouble() * dSum @@ -600,7 +624,7 @@ object LDA { // 如果采样到当前token的Topic这丢弃掉 // svCounter == 1 && table.length > 1 采样到token的Topic 但包含其他token // svCounter > 1 && gen.nextDouble() < 1.0 / svCounter 采样的Topic 有1/svCounter 概率属于当前token - if ((svCounter == 1 && table.length > 1) || + if ((svCounter == 1 && table._1.length > 1) || (svCounter > 1 && gen.nextDouble() < 1.0 / svCounter)) { return sampleSV(gen, table, sv, currentTopic) } @@ -610,6 +634,11 @@ object LDA { } +/** + * Degree-Based Hashing, the paper: + * http://nips.cc/Conferences/2014/Program/event.php?ID=4569 + * @param partitions + */ private class DBHPartitioner(partitions: Int) extends Partitioner { val mixingPrime: Long = 1125899906842597L @@ -621,8 +650,8 @@ private class DBHPartitioner(partitions: Int) extends Partitioner { getPartition(idx) } - def getPartition(src: Int): PartitionID = { - (math.abs(src * mixingPrime) % partitions).toInt + def getPartition(idx: Int): PartitionID = { + (math.abs(idx * mixingPrime) % partitions).toInt } override def equals(other: Any): Boolean = other match { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index c5ad67939b73a..6b477fcc1e43d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -284,7 +284,7 @@ object LDAModel { private[mllib] object LDAUtils { - type Table = Array[(Int, Int, Double)] + type Table = (Array[Int], Array[Int], Array[Double]) @transient private lazy val tableOrdering = new scala.math.Ordering[(Int, Double)] { override def compare(x: (Int, Double), y: (Int, Double)): Int = { @@ -305,7 +305,7 @@ private[mllib] object LDAUtils { used: Int, sum: Double): Table = { val pMean = 1.0 / used - val table = new Table(used) + val table = (new Array[Int](used), new Array[Int](used), new Array[Double](used)) val lq = new JPriorityQueue[(Int, Double)](used, tableOrdering) val hq = new JPriorityQueue[(Int, Double)](used, tableReverseOrdering) @@ -324,7 +324,9 @@ private[mllib] object LDAUtils { while (!lq.isEmpty & !hq.isEmpty) { val 
(i, pi) = lq.remove() val (h, ph) = hq.remove() - table(offset) = (i, h, pi) + table._1(offset) = i + table._2(offset) = h + table._3(offset) = pi val pd = ph - (pMean - pi) if (pd >= pMean) { hq.add((h, pd)) @@ -336,47 +338,31 @@ private[mllib] object LDAUtils { while (!hq.isEmpty) { val (h, ph) = hq.remove() assert(ph - pMean < 1e-8) - table(offset) = (h, h, ph) + table._1(offset) = h + table._2(offset) = h + table._3(offset) = ph offset += 1 } while (!lq.isEmpty) { val (i, pi) = lq.remove() assert(pMean - pi < 1e-8) - table(offset) = (i, i, pi) + table._1(offset) = i + table._2(offset) = i + table._3(offset) = pi offset += 1 } - - // 测试代码 随即抽样一个样本验证其概率 - // val (di, dp) = probs(Utils.random.nextInt(used)) - // val ds = table.map { t => - // if (t._1 == di) { - // if (t._2 == t._1) { - // pMean - // } else { - // t._3 - // } - // } else if (t._2 == di) { - // pMean - t._3 - // } else { - // 0.0 - // } - // }.sum - // assert((ds - dp).abs < 1e-4) - table } def sampleAlias(gen: Random, table: Table): Int = { - val l = table.length + val l = table._1.length val bin = gen.nextInt(l) - val i = table(bin)._1 - val h = table(bin)._2 - val p = table(bin)._3 + val p = table._3(bin) if (l * p > gen.nextDouble()) { - i + table._1(bin) } else { - h + table._2(bin) } } From b001a333ef724e6b4d457db2ac3c64016743855d Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Thu, 12 Mar 2015 23:04:20 +0800 Subject: [PATCH 13/21] =?UTF-8?q?=E7=A7=BB=E9=99=A4burnIn=E5=8F=82?= =?UTF-8?q?=E6=95=B0=20=E4=BC=98=E5=8C=96=E6=94=B6=E6=95=9B=E9=80=9F?= =?UTF-8?q?=E5=BA=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../apache/spark/mllib/clustering/LDA.scala | 67 +++++++++++-------- 1 file changed, 39 insertions(+), 28 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index fca3bcab7bd10..2c4ed6b9b4f5b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -142,10 +142,10 @@ class LDA private[mllib]( innerIter += 1 } - def saveModel(burnInIter: Int): LDAModel = { + def saveModel(iter: Int = 1): LDAModel = { var termTopicCounter: RDD[(VertexId, VD)] = null - for (iter <- 1 to burnInIter) { - logInfo(s"Save TopicModel (Iteration $iter/$burnInIter)") + for (iter <- 1 to iter) { + logInfo(s"Save TopicModel (Iteration $iter/$iter)") var previousTermTopicCounter = termTopicCounter gibbsSampling() val newTermTopicCounter = termVertices @@ -154,7 +154,7 @@ class LDA private[mllib]( (term, a :+ b) }).getOrElse(newTermTopicCounter) - termTopicCounter.cache().count() + termTopicCounter.persist(storageLevel).count() Option(previousTermTopicCounter).foreach(_.unpersist()) previousTermTopicCounter = termTopicCounter } @@ -162,9 +162,9 @@ class LDA private[mllib]( termTopicCounter.collect().foreach { case (term, counter) => model.merge(term.toInt, counter) } - model.gtc :/= burnInIter.toDouble + model.gtc :/= iter.toDouble model.ttc.foreach { ttc => - ttc :/= burnInIter.toDouble + ttc :/= iter.toDouble ttc.compact() } model @@ -281,26 +281,20 @@ object LDA { def train(docs: RDD[(DocId, SSV)], numTopics: Int = 2048, totalIter: Int = 150, - burnIn: Int = 5, - alpha: Double = 0.1, + alpha: Double = 0.01, beta: Double = 0.01, alphaAS: Double = 0.1): LDAModel = { - require(totalIter > burnIn, "totalIter is less than burnIn") require(totalIter > 0, "totalIter is less than 0") - 
require(burnIn > 0, "burnIn is less than 0") val topicModeling = new LDA(docs, numTopics, alpha, beta, alphaAS) - topicModeling.runGibbsSampling(totalIter - burnIn) - topicModeling.saveModel(burnIn) + topicModeling.runGibbsSampling(totalIter - 1) + topicModeling.saveModel(1) } def incrementalTrain(docs: RDD[(DocId, SSV)], computedModel: LDAModel, alphaAS: Double = 1, - totalIter: Int = 150, - burnIn: Int = 5): LDAModel = { - require(totalIter > burnIn, "totalIter is less than burnIn") + totalIter: Int = 150): LDAModel = { require(totalIter > 0, "totalIter is less than 0") - require(burnIn > 0, "burnIn is less than 0") val numTopics = computedModel.ttc.size val alpha = computedModel.alpha val beta = computedModel.beta @@ -309,8 +303,8 @@ object LDA { val topicModeling = new LDA(docs, numTopics, alpha, beta, alphaAS, computedModel = broadcastModel) broadcastModel.unpersist() - topicModeling.runGibbsSampling(totalIter - burnIn) - topicModeling.saveModel(burnIn) + topicModeling.runGibbsSampling(totalIter - 1) + topicModeling.saveModel(1) } private[mllib] def sampleTokens( @@ -341,15 +335,32 @@ object LDA { val topics = triplet.attr for (i <- 0 until topics.length) { val currentTopic = topics(i) - dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, dData, - currentTopic, numTokens, numTerms, alpha, alphaAS, beta) - val (wSum, w) = wordTable(wordTableCache, totalTopicCounter, - termTopicCounter, termId, numTokens, numTerms, alpha, alphaAS, beta) - val newTopic = tokenSampling(gen, t, tSum, w, termTopicCounter, wSum, - docTopicCounter, dData, currentTopic) - - if (newTopic != currentTopic) { - topics(i) = newTopic + docTopicCounter.synchronized { + termTopicCounter.synchronized { + dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, dData, + currentTopic, numTokens, numTerms, alpha, alphaAS, beta) + val (wSum, w) = wordTable(wordTableCache, totalTopicCounter, + termTopicCounter, termId, numTokens, numTerms, alpha, alphaAS, beta) + val newTopic = tokenSampling(gen, t, tSum, w, termTopicCounter, wSum, + docTopicCounter, dData, currentTopic) + + if (newTopic != currentTopic) { + topics(i) = newTopic + docTopicCounter(currentTopic) -= 1 + docTopicCounter(newTopic) += 1 + if (docTopicCounter(currentTopic) == 0) docTopicCounter.compact() + + termTopicCounter(currentTopic) -= 1 + termTopicCounter(newTopic) += 1 + if (termTopicCounter(currentTopic) == 0) termTopicCounter.compact() + + totalTopicCounter(currentTopic) -= 1 + totalTopicCounter(newTopic) += 1 + if (gen.nextDouble() < 1e-5) { + wordTableCache.update(termId, null) + } + } + } } } @@ -664,7 +675,7 @@ private class DBHPartitioner(partitions: Int) extends Partitioner { override def hashCode: Int = numPartitions } -class LDAKryoRegistrator extends KryoRegistrator { +private[mllib] class LDAKryoRegistrator extends KryoRegistrator { def registerClasses(kryo: com.esotericsoftware.kryo.Kryo) { val gkr = new GraphKryoRegistrator gkr.registerClasses(kryo) From 17fe0ab6d9a5c296b9bc5a95ed49450294efdd7c Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Fri, 13 Mar 2015 16:49:04 +0800 Subject: [PATCH 14/21] =?UTF-8?q?=E7=A7=BB=E9=99=A4counterCorpus.vertices.?= =?UTF-8?q?count()?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../main/scala/org/apache/spark/mllib/clustering/LDA.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 
2c4ed6b9b4f5b..c7e5a6dd03002 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -128,7 +128,7 @@ class LDA private[mllib]( val counterCorpus = updateCounter(sampledCorpus, numTopics) counterCorpus.persist(storageLevel) - counterCorpus.vertices.count() + // counterCorpus.vertices.count() counterCorpus.edges.count() totalTopicCounter = collectTotalTopicCounter(counterCorpus) @@ -172,7 +172,7 @@ class LDA private[mllib]( def runGibbsSampling(iterations: Int): Unit = { for (iter <- 1 to iterations) { - // println(s"perplexity $iter: ${perplexity()}") + // println(s"perplexity $iter: ${perplexity}") logInfo(s"Start Gibbs sampling (Iteration $iter/$iterations)") gibbsSampling() } From f3f4da6369a4bc9b41ff0d6854e8a7121de5e268 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Fri, 13 Mar 2015 20:40:32 +0800 Subject: [PATCH 15/21] =?UTF-8?q?fix=20wordTable=E6=96=B9=E6=B3=95?= =?UTF-8?q?=E8=BF=94=E5=9B=9Enull?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../apache/spark/mllib/clustering/LDA.scala | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index c7e5a6dd03002..ee7b56e302aa5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -339,8 +339,9 @@ object LDA { termTopicCounter.synchronized { dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, dData, currentTopic, numTokens, numTerms, alpha, alphaAS, beta) - val (wSum, w) = wordTable(wordTableCache, totalTopicCounter, - termTopicCounter, termId, numTokens, numTerms, alpha, alphaAS, beta) + val (wSum, w) = wordTable(x => x == null || x.get() == null || gen.nextDouble() < 1e-4, + wordTableCache, totalTopicCounter, termTopicCounter, + termId, numTokens, numTerms, alpha, alphaAS, beta) val newTopic = tokenSampling(gen, t, tSum, w, termTopicCounter, wSum, docTopicCounter, dData, currentTopic) @@ -356,9 +357,6 @@ object LDA { totalTopicCounter(currentTopic) -= 1 totalTopicCounter(newTopic) += 1 - if (gen.nextDouble() < 1e-5) { - wordTableCache.update(termId, null) - } } } } @@ -608,6 +606,7 @@ object LDA { } private def wordTable( + updateFunc: SoftReference[(Double, Table)] => Boolean, cacheMap: AppendOnlyMap[VertexId, SoftReference[(Double, Table)]], totalTopicCounter: BDV[Count], termTopicCounter: VD, @@ -617,14 +616,16 @@ object LDA { alpha: Double, alphaAS: Double, beta: Double): (Double, Table) = { - var w = cacheMap(termId) - if (w == null || w.get() == null) { - val t = wSparse(totalTopicCounter, termTopicCounter, + val cacheW = cacheMap(termId) + if (!updateFunc(cacheW)) { + cacheW.get + } else { + val sv = wSparse(totalTopicCounter, termTopicCounter, numTokens, numTerms, alpha, alphaAS, beta) - w = new SoftReference((t._1, generateAlias(t._2, t._1))) - cacheMap.update(termId, w) + val w = (sv._1, generateAlias(sv._2, sv._1)) + cacheMap.update(termId, new SoftReference(w)) + w } - w.get() } private def sampleSV(gen: Random, table: Table, sv: VD, currentTopic: Int): Int = { From 7664fae98f343163ce959053a7ad8408a8442613 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Sat, 14 Mar 2015 00:45:05 +0800 Subject: [PATCH 16/21] =?UTF-8?q?=E6=B7=BB=E5=8A=A0trainAndSaveModel?= =?UTF-8?q?=E6=96=B9=E6=B3=95?= MIME-Version: 1.0 
Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../apache/spark/mllib/clustering/LDA.scala | 31 ++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index ee7b56e302aa5..6c9d0829fedab 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -290,6 +290,35 @@ object LDA { topicModeling.saveModel(1) } + /** + * topicID termID+1:counter termID+1:counter .. + */ + def trainAndSaveModel( + docs: RDD[(DocId, SSV)], + dir: String, + numTopics: Int = 2048, + totalIter: Int = 150, + alpha: Double = 0.01, + beta: Double = 0.01, + alphaAS: Double = 0.1): Unit = { + import org.apache.spark.mllib.regression.LabeledPoint + import org.apache.spark.mllib.util.MLUtils + import org.apache.spark.mllib.linalg.Vectors + val lda = new LDA(docs, numTopics, alpha, beta, alphaAS) + val numTerms = lda.numTerms + lda.runGibbsSampling(totalIter) + val rdd = lda.termVertices.flatMap { case (termId, counter) => + counter.activeIterator.map { case (topic, cn) => + val sv = BSV.zeros[Double](numTerms) + sv(termId.toInt) = cn.toDouble + (topic, sv) + } + }.reduceByKey { (a, b) => a + b}.map { case (topic, sv) => + LabeledPoint(topic.toDouble, Vectors.fromBreeze(sv)) + } + MLUtils.saveAsLibSVMFile(rdd, dir) + } + def incrementalTrain(docs: RDD[(DocId, SSV)], computedModel: LDAModel, alphaAS: Double = 1, @@ -486,7 +515,7 @@ object LDA { * -di 减去当前token的主题 */ // scalastyle:on - def tokenSampling( + private def tokenSampling( gen: Random, t: Table, tSum: Double, From 8f44caf2ff806ffebbb90a0fb5a55668e689264b Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Sat, 14 Mar 2015 17:14:11 +0800 Subject: [PATCH 17/21] =?UTF-8?q?=E4=BC=98=E5=8C=96updateCounter,=20?= =?UTF-8?q?=E4=BC=98=E5=8C=96=20adjustment?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../apache/spark/mllib/clustering/LDA.scala | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 6c9d0829fedab..53434b83591fb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -378,11 +378,11 @@ object LDA { topics(i) = newTopic docTopicCounter(currentTopic) -= 1 docTopicCounter(newTopic) += 1 - if (docTopicCounter(currentTopic) == 0) docTopicCounter.compact() + // if (docTopicCounter(currentTopic) == 0) docTopicCounter.compact() termTopicCounter(currentTopic) -= 1 termTopicCounter(newTopic) += 1 - if (termTopicCounter(currentTopic) == 0) termTopicCounter.compact() + // if (termTopicCounter(currentTopic) == 0) termTopicCounter.compact() totalTopicCounter(currentTopic) -= 1 totalTopicCounter(newTopic) += 1 @@ -406,9 +406,17 @@ object LDA { } ctx.sendToDst(vector) ctx.sendToSrc(vector) - }, _ + _, TripletFields.EdgeOnly).mapValues(t => { - t.compact() - t + }, _ + _, TripletFields.EdgeOnly).mapValues(v => { + val used = v.used + if (v.index.length == used) { + v + } else { + val index = new Array[Int](used) + val data = new Array[Count](used) + Array.copy(v.index, 0, index, 0, used) + Array.copy(v.data, 0, data, 0, used) + new BSV[Count](index, data, numTopics) + } }) 
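A short usage sketch for the trainAndSaveModel entry point added above. The toy corpus and output directory are illustrative; any RDD of (document id, sparse term-count vector) pairs should work, assuming DocId is the Long alias used elsewhere in the file. The output is LibSVM-style text, one line per topic, with feature index termID+1 and the accumulated count as the value.

  import org.apache.spark.SparkContext
  import org.apache.spark.mllib.linalg.{SparseVector => SSV}
  import org.apache.spark.mllib.clustering.LDA

  def saveToyModel(sc: SparkContext): Unit = {
    // docId -> sparse term-count vector over a 4-term vocabulary (toy data).
    val docs = sc.parallelize(Seq(
      (0L, new SSV(4, Array(0, 1), Array(2.0, 1.0))),
      (1L, new SSV(4, Array(2, 3), Array(1.0, 3.0)))
    ))
    LDA.trainAndSaveModel(docs, "/tmp/lda-model", numTopics = 8, totalIter = 20)
  }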
graph.joinVertices(newCounter)((_, _, nc) => nc) } @@ -586,7 +594,7 @@ object LDA { val betaSum = numTerms * beta val w = BSV.zeros[Double](numTopics) var sum = 0.0 - termTopicCounter.activeIterator.foreach { t => + termTopicCounter.activeIterator.filter(_._2 > 0).foreach { t => val topic = t._1 val count = t._2 val last = count * alphaSum * (totalTopicCounter(topic) + alphaAS) / @@ -622,12 +630,12 @@ object LDA { var sum = 0.0 for (i <- 0 until used) { val topic = index(i) - var count: Double = data(i) - if (currentTopic == topic) count = count - 1.0 - // val last = count * termSum * (termTopicCounter(topic) + beta) / - // ((totalTopicCounter(topic) + betaSum) * termSum) - val last = count * (termTopicCounter(topic) + beta) / - (totalTopicCounter(topic) + betaSum) + val count = data(i) + val adjustment = if (currentTopic == topic) -1D else 0 + val last = (count + adjustment) * (termTopicCounter(topic) + adjustment + beta) / + (totalTopicCounter(topic) + adjustment + betaSum) + // val lastD = (count + adjustment) * termSum * (termTopicCounter(topic) + adjustment + beta) / + // ((totalTopicCounter(topic) + adjustment + betaSum) * termSum) sum += last d(i) = sum From 8e3561427a2b12abbe1b174e5f0a93afcbec8d1b Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Thu, 19 Mar 2015 12:56:10 +0800 Subject: [PATCH 18/21] =?UTF-8?q?=E4=BC=98=E5=8C=96updateCounter,checkpoin?= =?UTF-8?q?t=E6=80=A7=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../apache/spark/mllib/clustering/LDA.scala | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 53434b83591fb..70564fa6250b1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -24,6 +24,7 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, sum => brzSum} import org.apache.spark.broadcast.Broadcast import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.GraphImpl import org.apache.spark.{HashPartitioner, Logging, Partitioner} import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix} import org.apache.spark.mllib.linalg.{DenseVector => SDV, SparseVector => SSV, Vector => SV} @@ -106,12 +107,9 @@ class LDA private[mllib]( private def docVertices = corpus.vertices.filter(t => t._1 < 0) - private def checkpoint(): Unit = { + private def checkpoint(corpus: Graph[VD, ED]): Unit = { if (innerIter % 10 == 0 && corpus.edges.sparkContext.getCheckpointDir.isDefined) { - val edges = corpus.edges.map(t => t) - edges.checkpoint() - val newCorpus: Graph[VD, ED] = Graph.fromEdges(edges, null, storageLevel, storageLevel) - corpus = updateCounter(newCorpus, numTopics).persist(storageLevel) + corpus.checkpoint() } } @@ -127,6 +125,7 @@ class LDA private[mllib]( sampledCorpus.persist(storageLevel) val counterCorpus = updateCounter(sampledCorpus, numTopics) + checkpoint(counterCorpus) counterCorpus.persist(storageLevel) // counterCorpus.vertices.count() counterCorpus.edges.count() @@ -137,8 +136,6 @@ class LDA private[mllib]( sampledCorpus.edges.unpersist(false) sampledCorpus.vertices.unpersist(false) corpus = counterCorpus - - checkpoint() innerIter += 1 } @@ -172,9 +169,13 @@ class LDA private[mllib]( def runGibbsSampling(iterations: Int): Unit = { for (iter <- 1 to iterations) { - // println(s"perplexity $iter: 
${perplexity}") - logInfo(s"Start Gibbs sampling (Iteration $iter/$iterations)") + // logInfo(s"Gibbs samplin perplexity $iter: ${perplexity}") + // logInfo(s"Gibbs sampling (Iteration $iter/$iterations)") + // val startedAt = System.nanoTime() gibbsSampling() + // val endAt = System.nanoTime() + // val useTime = (endAt - startedAt) / 1e9 + // logInfo(s"Gibbs sampling use time $iter: $useTime") } } @@ -313,7 +314,7 @@ object LDA { sv(termId.toInt) = cn.toDouble (topic, sv) } - }.reduceByKey { (a, b) => a + b}.map { case (topic, sv) => + }.reduceByKey { (a, b) => a + b }.map { case (topic, sv) => LabeledPoint(topic.toDouble, Vectors.fromBreeze(sv)) } MLUtils.saveAsLibSVMFile(rdd, dir) @@ -418,7 +419,8 @@ object LDA { new BSV[Count](index, data, numTopics) } }) - graph.joinVertices(newCounter)((_, _, nc) => nc) + // GraphImpl.fromExistingRDDs(newCounter, graph.edges) + GraphImpl(newCounter, graph.edges) } private def collectGlobalCounter(graph: Graph[VD, ED], numTopics: Int): BDV[Count] = { @@ -446,7 +448,7 @@ object LDA { edges.persist(storageLevel) var corpus: Graph[VD, ED] = Graph.fromEdges(edges, null, storageLevel, storageLevel) // degree-based hashing - val degrees = corpus.outerJoinVertices(corpus.degrees) { (vid, data, deg) => deg.getOrElse(0)} + val degrees = corpus.outerJoinVertices(corpus.degrees) { (vid, data, deg) => deg.getOrElse(0) } val numPartitions = edges.partitions.size val partitionStrategy = new DBHPartitioner(numPartitions) val newEdges = degrees.triplets.map { e => From fabd52047fdd2fcd796ec2fca461e9f40a26e83e Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Thu, 2 Apr 2015 13:23:20 +0800 Subject: [PATCH 19/21] Fix DBHPartitioner bug --- .../scala/org/apache/spark/mllib/clustering/LDA.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 70564fa6250b1..0adb5536aed64 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -697,10 +697,13 @@ private class DBHPartitioner(partitions: Int) extends Partitioner { def getPartition(key: Any): Int = { val edge = key.asInstanceOf[EdgeTriplet[Int, ED]] - val idx = math.min(edge.srcAttr, edge.dstAttr) - getPartition(idx) + val srcDeg = edge.srcAttr + val dstDeg = edge.dstAttr + val srcId = edge.srcId + val dstId = edge.dstId + val minId = if (srcDeg < dstDeg) srcId else dstId + getPartition(minId) } - def getPartition(idx: Int): PartitionID = { (math.abs(idx * mixingPrime) % partitions).toInt } From bca468a34b25e2a4e782d02e789f360976d0e177 Mon Sep 17 00:00:00 2001 From: GuoQiang Li Date: Thu, 9 Apr 2015 22:28:22 +0800 Subject: [PATCH 20/21] Fix DBH bug --- .../main/scala/org/apache/spark/mllib/clustering/LDA.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 0adb5536aed64..804cef4c1bf57 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -704,10 +704,15 @@ private class DBHPartitioner(partitions: Int) extends Partitioner { val minId = if (srcDeg < dstDeg) srcId else dstId getPartition(minId) } + def getPartition(idx: Int): PartitionID = { (math.abs(idx * mixingPrime) % partitions).toInt } + def getPartition(idx: Long): PartitionID 
= { + (math.abs(idx * mixingPrime) % partitions).toInt + } + override def equals(other: Any): Boolean = other match { case h: DBHPartitioner => h.numPartitions == numPartitions From e7bce416f16d878694fd8338a865971dcbab4d72 Mon Sep 17 00:00:00 2001 From: Hao Wang Date: Tue, 28 Apr 2015 12:08:20 +0800 Subject: [PATCH 21/21] remove word table cache --- .../apache/spark/mllib/clustering/LDA.scala | 38 ++++++++++--------- .../spark/mllib/clustering/LDAModel.scala | 30 ++++++++++++--- 2 files changed, 45 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 804cef4c1bf57..b74dc5f67d291 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -17,7 +17,6 @@ package org.apache.spark.mllib.clustering -import java.lang.ref.SoftReference import java.util.Random import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, sum => brzSum} @@ -32,12 +31,13 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.storage.StorageLevel import org.apache.spark.SparkContext._ -import org.apache.spark.util.collection.AppendOnlyMap import org.apache.spark.util.random.XORShiftRandom import LDA._ import LDAUtils._ +import scala.collection.mutable.ArrayBuffer + class LDA private[mllib]( @transient private var corpus: Graph[VD, ED], private val numTopics: Int, @@ -351,7 +351,14 @@ object LDA { val nweGraph = graph.mapTriplets( (pid, iter) => { val gen = new XORShiftRandom(parts * innerIter + pid) - val wordTableCache = new AppendOnlyMap[VertexId, SoftReference[(Double, Table)]]() + // table is a per term data structure + // in GraphX, edges in a partition are clustered by source IDs (term id in this case) + // so, use below simple cache to avoid calculating table each time + val lastTable = (new ArrayBuffer[Int](numTopics.toInt), + new ArrayBuffer[Int](numTopics.toInt), + new ArrayBuffer[Double](numTopics.toInt)) + var lastVid = None.asInstanceOf[Option[VertexId]] + var lastWsum = 0.0 val dv = tDense(totalTopicCounter, numTokens, numTerms, alpha, alphaAS, beta) val dData = new Array[Double](numTopics.toInt) val t = generateAlias(dv._2, dv._1) @@ -369,10 +376,12 @@ object LDA { termTopicCounter.synchronized { dSparse(totalTopicCounter, termTopicCounter, docTopicCounter, dData, currentTopic, numTokens, numTerms, alpha, alphaAS, beta) - val (wSum, w) = wordTable(x => x == null || x.get() == null || gen.nextDouble() < 1e-4, - wordTableCache, totalTopicCounter, termTopicCounter, - termId, numTokens, numTerms, alpha, alphaAS, beta) - val newTopic = tokenSampling(gen, t, tSum, w, termTopicCounter, wSum, + if (lastVid != Some(termId) || gen.nextDouble() < 1e-4) { + lastWsum = wordTable(lastTable, totalTopicCounter, termTopicCounter, + termId, numTokens, numTerms, alpha, alphaAS, beta) + lastVid = Some(termId) + } + val newTopic = tokenSampling(gen, t, tSum, lastTable, termTopicCounter, lastWsum, docTopicCounter, dData, currentTopic) if (newTopic != currentTopic) { @@ -645,8 +654,7 @@ object LDA { } private def wordTable( - updateFunc: SoftReference[(Double, Table)] => Boolean, - cacheMap: AppendOnlyMap[VertexId, SoftReference[(Double, Table)]], + table:Table, totalTopicCounter: BDV[Count], termTopicCounter: VD, termId: VertexId, @@ -654,17 +662,11 @@ object LDA { numTerms: Double, alpha: Double, alphaAS: Double, - beta: Double): (Double, Table) = { - 
val cacheW = cacheMap(termId) - if (!updateFunc(cacheW)) { - cacheW.get - } else { + beta: Double): Double = { val sv = wSparse(totalTopicCounter, termTopicCounter, numTokens, numTerms, alpha, alphaAS, beta) - val w = (sv._1, generateAlias(sv._2, sv._1)) - cacheMap.update(termId, new SoftReference(w)) - w - } + generateAlias(sv._2, sv._1, Some(table)) + sv._1 } private def sampleSV(gen: Random, table: Table, sv: VD, currentTopic: Int): Int = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 6b477fcc1e43d..8be67cc479f38 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -30,6 +30,8 @@ import org.apache.spark.util.random.XORShiftRandom import LDAUtils._ +import scala.collection.mutable.ArrayBuffer + class LDAModel private[mllib]( private[mllib] val gtc: BDV[Double], private[mllib] val ttc: Array[BSV[Double]], @@ -284,7 +286,7 @@ object LDAModel { private[mllib] object LDAUtils { - type Table = (Array[Int], Array[Int], Array[Double]) + type Table = (ArrayBuffer[Int], ArrayBuffer[Int], ArrayBuffer[Double]) @transient private lazy val tableOrdering = new scala.math.Ordering[(Int, Double)] { override def compare(x: (Int, Double), y: (Int, Double)): Int = { @@ -294,18 +296,36 @@ private[mllib] object LDAUtils { @transient private lazy val tableReverseOrdering = tableOrdering.reverse - def generateAlias(sv: BV[Double], sum: Double): Table = { + def generateAlias(sv: BV[Double], sum: Double, tableCache:Option[Table] = None): Table = { val used = sv.activeSize val probs = sv.activeIterator.slice(0, used) - generateAlias(probs, used, sum) + generateAlias(probs, used, sum, tableCache) } def generateAlias( probs: Iterator[(Int, Double)], used: Int, - sum: Double): Table = { + sum: Double, tableCache:Option[Table]): Table = { val pMean = 1.0 / used - val table = (new Array[Int](used), new Array[Int](used), new Array[Double](used)) + val table = tableCache.getOrElse(new ArrayBuffer[Int](used), + new ArrayBuffer[Int](used), + new ArrayBuffer[Double](used)) + // reset and resize table cache + for (i <- 0 until Math.min(used, table._1.length)) { + table._1(i) = 0 + table._2(i) = 0 + table._3(i) = 0.0 + } + if (used > table._1.length) { + val pad = 0 until (used - table._1.length) + table._1 ++= pad.map(_=>0) + table._2 ++= pad.map(_=>0) + table._3 ++= pad.map(_=>0.0) + } else { + table._1.reduceToSize(used) + table._2.reduceToSize(used) + table._3.reduceToSize(used) + } val lq = new JPriorityQueue[(Int, Double)](used, tableOrdering) val hq = new JPriorityQueue[(Int, Double)](used, tableReverseOrdering)
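For reference, a self-contained sketch of the alias-table construction and O(1) sampling that generateAlias/sampleAlias implement throughout this series. This is the textbook two-array (Vose) variant rather than a line-for-line copy of the patch, and it omits the buffer reuse introduced in the last commit; names are illustrative.

  import java.util.Random
  import scala.collection.mutable

  // Build an alias table from unnormalized weights; sample(gen) then returns index i
  // with probability weights(i) / weights.sum, in O(1) per draw.
  final class AliasTable(weights: Array[Double]) {
    private val n = weights.length
    private val prob = new Array[Double](n)
    private val alias = new Array[Int](n)
    locally {
      val sum = weights.sum
      val scaled = weights.map(_ * n / sum)
      val small = mutable.Queue(scaled.zipWithIndex.collect { case (p, i) if p < 1.0 => i }: _*)
      val large = mutable.Queue(scaled.zipWithIndex.collect { case (p, i) if p >= 1.0 => i }: _*)
      while (small.nonEmpty && large.nonEmpty) {
        val s = small.dequeue()
        val l = large.dequeue()
        prob(s) = scaled(s)
        alias(s) = l
        // Donate the unused part of bin l to fill bin s up to 1.
        scaled(l) -= (1.0 - scaled(s))
        if (scaled(l) < 1.0) small.enqueue(l) else large.enqueue(l)
      }
      (small ++ large).foreach { i => prob(i) = 1.0; alias(i) = i }
    }
    def sample(gen: Random): Int = {
      val bin = gen.nextInt(n)
      if (gen.nextDouble() < prob(bin)) bin else alias(bin)
    }
  }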