diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala index 5e17c8da61134..328acea56962a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -19,17 +19,18 @@ package org.apache.spark.mllib.clustering import java.util.Random -import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy} +import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, normalize, sum => brzSum} import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.mllib.clustering.LDA.LearningAlgorithms import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{Vectors, Matrices, Matrix, Vector} import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils +import org.apache.spark.util.{BoundedPriorityQueue, Utils} /** @@ -63,10 +64,11 @@ class LDA private ( private var docConcentration: Double, private var topicConcentration: Double, private var seed: Long, - private var checkpointInterval: Int) extends Logging { + private var checkpointInterval: Int, + private var algorithm: LDA.LearningAlgorithms.Algorithm) extends Logging { def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1, - seed = Utils.random.nextLong(), checkpointInterval = 10) + seed = Utils.random.nextLong(), checkpointInterval = 10, algorithm = LDA.LearningAlgorithms.EM) /** * Number of topics to infer. I.e., the number of soft cluster centers. 
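Note on the constructor change above: the private primary constructor gains an `algorithm` parameter that defaults to `LDA.LearningAlgorithms.EM`, so existing callers are unaffected. The excerpt shows no public setter for the new field, and `LearningAlgorithms` sits inside the `private[clustering]` `LDA` object, so algorithm selection is not yet reachable from user code. A minimal usage sketch under those assumptions (the corpus and the active SparkContext `sc` are made up for illustration; the commented-out setter is hypothetical):

```scala
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD

// Assumes an active SparkContext `sc`. Each document is (docId, term-count vector)
// over a tiny 3-term vocabulary, purely for illustration.
val corpus: RDD[(Long, Vector)] = sc.parallelize(Seq(
  (0L, Vectors.dense(1.0, 2.0, 0.0)),
  (1L, Vectors.dense(0.0, 3.0, 1.0))))

val lda = new LDA()            // algorithm defaults to LearningAlgorithms.EM
  .setK(3)
  .setMaxIterations(10)
// .setAlgorithm(...)          // hypothetical: no public setter for `algorithm` appears in this patch
val model = lda.run(corpus)    // run() now dispatches on the configured algorithm (see the next hunk)
```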
@@ -230,8 +232,14 @@ class LDA private ( * @return Inferred LDA model */ def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = { - val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed, - checkpointInterval) + val state = algorithm match { + case LearningAlgorithms.EM => LDA.EMLearningStateInitializer.initialState( + documents, k, getDocConcentration, getTopicConcentration, seed, checkpointInterval + ) + case LearningAlgorithms.Gibbs => LDA.EMLearningStateInitializer.initialState( + documents, k, getDocConcentration, getTopicConcentration, seed, checkpointInterval + ) + } var iter = 0 val iterationTimes = Array.fill[Double](maxIterations)(0) while (iter < maxIterations) { @@ -241,7 +249,7 @@ class LDA private ( iterationTimes(iter) = elapsedSeconds iter += 1 } - state.graphCheckpointer.deleteAllCheckpoints() + state.deleteAllCheckpoints() new DistributedLDAModel(state, iterationTimes) } @@ -311,165 +319,319 @@ private[clustering] object LDA { private[clustering] type TokenCount = Double - /** Term vertex IDs are {-1, -2, ..., -vocabSize} */ - private[clustering] def term2index(term: Int): Long = -(1 + term.toLong) - private[clustering] def index2term(termIndex: Long): Int = -(1 + termIndex).toInt - private[clustering] def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0 + object LearningAlgorithms extends Enumeration { + type Algorithm = Value + val Gibbs, EM = Value + } - private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 + private[clustering] trait LearningState { + def next(): LearningState + def topicsMatrix: Matrix + def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] + def logLikelihood: Double + def logPrior: Double + def topicDistributions: RDD[(Long, Vector)] + def globalTopicTotals: LDA.TopicCounts + def k: Int + def vocabSize: Int + def docConcentration: Double + def topicConcentration: Double + def deleteAllCheckpoints(): Unit + } - /** - * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. - * - * @param graph EM graph, storing current parameter estimates in vertex descriptors and - * data (token counts) in edge descriptors. 
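Stepping outside the patch for a moment: the `LearningState` / `LearningStateInitializer` pair introduced above is the extension point this refactor is building toward. In this revision both the `EM` and `Gibbs` branches of `run()` delegate to `EMLearningStateInitializer`, so the `Gibbs` case is effectively a placeholder. A hypothetical sketch of what a Gibbs plug-in would have to provide, using only the trait signature from the patch (`GibbsLearningStateInitializer` does not exist in this change):

```scala
package org.apache.spark.mllib.clustering

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

// Hypothetical sketch only: illustrates the contract a Gibbs sampler would satisfy
// so that run() could dispatch to it instead of the EM initializer.
private[clustering] object GibbsLearningStateInitializer extends LDA.LearningStateInitializer {
  override def initialState(
      docs: RDD[(Long, Vector)],
      k: Int,
      docConcentration: Double,
      topicConcentration: Double,
      randomSeed: Long,
      checkpointInterval: Int): LDA.LearningState = {
    // A real implementation would assign an initial topic to every token and return a
    // LearningState whose next() performs one collapsed Gibbs sweep over those assignments.
    ???
  }
}
```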
- * @param k Number of topics - * @param vocabSize Number of unique terms - * @param docConcentration "alpha" - * @param topicConcentration "beta" or "eta" - */ - private[clustering] class EMOptimizer( - var graph: Graph[TopicCounts, TokenCount], - val k: Int, - val vocabSize: Int, - val docConcentration: Double, - val topicConcentration: Double, - checkpointInterval: Int) { - - private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( - graph, checkpointInterval) - - def next(): EMOptimizer = { - val eta = topicConcentration - val W = vocabSize - val alpha = docConcentration - - val N_k = globalTopicTotals - val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = - (edgeContext) => { - // Compute N_{wj} gamma_{wjk} + private[clustering] trait LearningStateInitializer { + def initialState( + docs: RDD[(Long, Vector)], + k: Int, + docConcentration: Double, + topicConcentration: Double, + randomSeed: Long, + checkpointInterval: Int): LearningState + } + + private[clustering] object EMLearningStateInitializer extends LearningStateInitializer { + + /** Term vertex IDs are {-1, -2, ..., -vocabSize} */ + private[clustering] def term2index(term: Int): Long = -(1 + term.toLong) + + private[clustering] def index2term(termIndex: Long): Int = -(1 + termIndex).toInt + + private[clustering] def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0 + + private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 + + private[clustering] class EMLearningState(optimizer: EMOptimizer) extends LearningState { + val k: Int = optimizer.k + val vocabSize = optimizer.vocabSize + val docConcentration = optimizer.docConcentration + val topicConcentration = optimizer.topicConcentration + + def next(): LearningState = { + optimizer.next() + this + } + + def deleteAllCheckpoints() = { + optimizer.graphCheckpointer.deleteAllCheckpoints() + } + + lazy val topicsMatrix: Matrix = { + // Collect row-major topics + val termTopicCounts: Array[(Int, TopicCounts)] = + optimizer.graph.vertices.filter(_._1 < 0).map { case (termIndex, cnts) => + (index2term(termIndex), cnts) + }.collect() + // Convert to Matrix + val brzTopics = BDM.zeros[Double](vocabSize, k) + termTopicCounts.foreach { case (term, cnts) => + var j = 0 + while (j < k) { + brzTopics(term, j) = cnts(j) + j += 1 + } + } + Matrices.fromBreeze(brzTopics) + } + + def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { + val numTopics = k + // Note: N_k is not needed to find the top terms, but it is needed to normalize weights + // to a distribution over terms. + val N_k: TopicCounts = globalTopicTotals + val topicsInQueues: Array[BoundedPriorityQueue[(Double, Int)]] = + optimizer.graph.vertices.filter(isTermVertex) + .mapPartitions { termVertices => + // For this partition, collect the most common terms for each topic in queues: + // queues(topic) = queue of (term weight, term index). + // Term weights are N_{wk} / N_k. 
+ val queues = + Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Int)](maxTermsPerTopic)) + for ((termId, n_wk) <- termVertices) { + var topic = 0 + while (topic < numTopics) { + queues(topic) += (n_wk(topic) / N_k(topic) -> index2term(termId.toInt)) + topic += 1 + } + } + Iterator(queues) + }.reduce { (q1, q2) => + q1.zip(q2).foreach { case (a, b) => a ++= b} + q1 + } + topicsInQueues.map { q => + val (termWeights, terms) = q.toArray.sortBy(-_._1).unzip + (terms.toArray, termWeights.toArray) + } + } + + lazy val logLikelihood: Double = { + val eta = topicConcentration + val alpha = docConcentration + assert(eta > 1.0) + assert(alpha > 1.0) + val N_k = globalTopicTotals + val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0)) + // Edges: Compute token log probability from phi_{wk}, theta_{kj}. + val sendMsg: EdgeContext[TopicCounts, TokenCount, Double] => Unit = (edgeContext) => { val N_wj = edgeContext.attr - // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count - // N_{wj}. - val scaledTopicDistribution: TopicCounts = - computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj - edgeContext.sendToDst((false, scaledTopicDistribution)) - edgeContext.sendToSrc((false, scaledTopicDistribution)) + val smoothed_N_wk: TopicCounts = edgeContext.dstAttr + (eta - 1.0) + val smoothed_N_kj: TopicCounts = edgeContext.srcAttr + (alpha - 1.0) + val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) + val tokenLogLikelihood = N_wj * math.log(phi_wk.dot(theta_kj)) + edgeContext.sendToDst(tokenLogLikelihood) } - // This is a hack to detect whether we could modify the values in-place. - // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) - val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = - (m0, m1) => { - val sum = - if (m0._1) { - m0._2 += m1._2 - } else if (m1._1) { - m1._2 += m0._2 + optimizer.graph.aggregateMessages[Double](sendMsg, _ + _) + .map(_._2).fold(0.0)(_ + _) + } + + lazy val logPrior: Double = { + val eta = topicConcentration + val alpha = docConcentration + // Term vertices: Compute phi_{wk}. Use to compute prior log probability. + // Doc vertex: Compute theta_{kj}. Use to compute prior log probability. + val N_k = globalTopicTotals + val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0)) + val seqOp: (Double, (VertexId, TopicCounts)) => Double = { + case (sumPrior: Double, vertex: (VertexId, TopicCounts)) => + if (isTermVertex(vertex)) { + val N_wk = vertex._2 + val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0) + val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + (eta - 1.0) * brzSum(phi_wk.map(math.log)) } else { - m0._2 + m1._2 + val N_kj = vertex._2 + val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0) + val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) + (alpha - 1.0) * brzSum(theta_kj.map(math.log)) } - (true, sum) } - // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. - val docTopicDistributions: VertexRDD[TopicCounts] = - graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) - .mapValues(_._2) - // Update the vertex descriptors with the new counts. 
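The `describeTopics` implementation above keeps one `BoundedPriorityQueue` of size `maxTermsPerTopic` per topic on each partition, so only the top candidates travel through the `reduce`. The result is equivalent to the following local computation, shown here as an illustrative sketch on plain arrays (not code from the patch):

```scala
// termTopicCounts(term)(topic) = N_{wk}; topicTotals(topic) = N_k.
// Returns, per topic, (term indices, weights) sorted by descending weight.
def topTermsPerTopic(
    termTopicCounts: Array[Array[Double]],
    topicTotals: Array[Double],
    maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = {
  val numTopics = topicTotals.length
  Array.tabulate(numTopics) { topic =>
    val weighted = termTopicCounts.zipWithIndex
      .map { case (counts, term) => (term, counts(topic) / topicTotals(topic)) }
    val top = weighted.sortBy(-_._2).take(maxTermsPerTopic)
    (top.map(_._1), top.map(_._2))
  }
}
```

The weights are N_{wk} / N_k, i.e. each topic's term counts normalized by that topic's total, matching the comment in the patch; the per-partition queues exist only to bound memory before the collect.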
- val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) - graph = newGraph - graphCheckpointer.updateGraph(newGraph) - globalTopicTotals = computeGlobalTopicTotals() - this + optimizer.graph.vertices.aggregate(0.0)(seqOp, _ + _) + } + + def topicDistributions: RDD[(Long, Vector)] = { + optimizer.graph.vertices.filter(isDocumentVertex).map { case (docID, topicCounts) => + (docID.toLong, Vectors.fromBreeze(normalize(topicCounts, 1.0))) + } + } + + def globalTopicTotals: LDA.TopicCounts = { + optimizer.globalTopicTotals + } } /** - * Aggregate distributions over topics from all term vertices. - * - * Note: This executes an action on the graph RDDs. + * Compute bipartite term/doc graph. */ - var globalTopicTotals: TopicCounts = computeGlobalTopicTotals() + def initialState(docs: RDD[(Long, Vector)], + k: Int, + docConcentration: Double, + topicConcentration: Double, + randomSeed: Long, + checkpointInterval: Int): EMLearningState = { + // For each document, create an edge (Document -> Term) for each unique term in the document. + val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => + // Add edges for terms with non-zero counts. + termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => + Edge(docID, term2index(term), cnt) + } + } - private def computeGlobalTopicTotals(): TopicCounts = { - val numTopics = k - graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) - } + val vocabSize = docs.take(1).head._2.size + + // Create vertices. + // Initially, we use random soft assignments of tokens to topics (random gamma). + def createVertices(): RDD[(VertexId, TopicCounts)] = { + val verticesTMP: RDD[(VertexId, TopicCounts)] = + edges.mapPartitionsWithIndex { case (partIndex, partEdges) => + val random = new Random(partIndex + randomSeed) + partEdges.flatMap { edge => + val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) + val sum = gamma * edge.attr + Seq((edge.srcId, sum), (edge.dstId, sum)) + } + } + verticesTMP.reduceByKey(_ + _) + } - } + val docTermVertices = createVertices() - /** - * Compute gamma_{wjk}, a distribution over topics k. - */ - private def computePTopic( - docTopicCounts: TopicCounts, - termTopicCounts: TopicCounts, - totalTopicCounts: TopicCounts, - vocabSize: Int, - eta: Double, - alpha: Double): TopicCounts = { - val K = docTopicCounts.length - val N_j = docTopicCounts.data - val N_w = termTopicCounts.data - val N = totalTopicCounts.data - val eta1 = eta - 1.0 - val alpha1 = alpha - 1.0 - val Weta1 = vocabSize * eta1 - var sum = 0.0 - val gamma_wj = new Array[Double](K) - var k = 0 - while (k < K) { - val gamma_wjk = (N_w(k) + eta1) * (N_j(k) + alpha1) / (N(k) + Weta1) - gamma_wj(k) = gamma_wjk - sum += gamma_wjk - k += 1 - } - // normalize - BDV(gamma_wj) /= sum - } + // Partition such that edges are grouped by document + val graph = Graph(docTermVertices, edges) + .partitionBy(PartitionStrategy.EdgePartition1D) - /** - * Compute bipartite term/doc graph. - */ - private def initialState( - docs: RDD[(Long, Vector)], - k: Int, - docConcentration: Double, - topicConcentration: Double, - randomSeed: Long, - checkpointInterval: Int): EMOptimizer = { - // For each document, create an edge (Document -> Term) for each unique term in the document. - val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => - // Add edges for terms with non-zero counts. 
- termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => - Edge(docID, term2index(term), cnt) - } + val optimizer = new EMOptimizer( + graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval + ) + new EMLearningState(optimizer) } - - val vocabSize = docs.take(1).head._2.size - - // Create vertices. - // Initially, we use random soft assignments of tokens to topics (random gamma). - def createVertices(): RDD[(VertexId, TopicCounts)] = { - val verticesTMP: RDD[(VertexId, TopicCounts)] = - edges.mapPartitionsWithIndex { case (partIndex, partEdges) => - val random = new Random(partIndex + randomSeed) - partEdges.flatMap { edge => - val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0) - val sum = gamma * edge.attr - Seq((edge.srcId, sum), (edge.dstId, sum)) + /** + * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. + * + * @param graph EM graph, storing current parameter estimates in vertex descriptors and + * data (token counts) in edge descriptors. + * @param k Number of topics + * @param vocabSize Number of unique terms + * @param docConcentration "alpha" + * @param topicConcentration "beta" or "eta" + */ + private[clustering] class EMOptimizer( + var graph: Graph[TopicCounts, TokenCount], + val k: Int, + val vocabSize: Int, + val docConcentration: Double, + val topicConcentration: Double, + checkpointInterval: Int) { + + private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( + graph, checkpointInterval) + + def next(): EMOptimizer = { + val eta = topicConcentration + val W = vocabSize + val alpha = docConcentration + + val N_k = globalTopicTotals + val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = + (edgeContext) => { + // Compute N_{wj} gamma_{wjk} + val N_wj = edgeContext.attr + // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count + // N_{wj}. + val scaledTopicDistribution: TopicCounts = + computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj + edgeContext.sendToDst((false, scaledTopicDistribution)) + edgeContext.sendToSrc((false, scaledTopicDistribution)) } - } - verticesTMP.reduceByKey(_ + _) - } + // This is a hack to detect whether we could modify the values in-place. + // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) + val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = + (m0, m1) => { + val sum = + if (m0._1) { + m0._2 += m1._2 + } else if (m1._1) { + m1._2 += m0._2 + } else { + m0._2 + m1._2 + } + (true, sum) + } + // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. + val docTopicDistributions: VertexRDD[TopicCounts] = + graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) + .mapValues(_._2) + // Update the vertex descriptors with the new counts. + val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) + graph = newGraph + graphCheckpointer.updateGraph(newGraph) + globalTopicTotals = computeGlobalTopicTotals() + this + } - val docTermVertices = createVertices() + /** + * Aggregate distributions over topics from all term vertices. + * + * Note: This executes an action on the graph RDDs. 
+ */ + var globalTopicTotals: TopicCounts = computeGlobalTopicTotals() - // Partition such that edges are grouped by document - val graph = Graph(docTermVertices, edges) - .partitionBy(PartitionStrategy.EdgePartition1D) + private def computeGlobalTopicTotals(): TopicCounts = { + val numTopics = k + graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) + } - new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval) - } + } + /** + * Compute gamma_{wjk}, a distribution over topics k. + */ + private def computePTopic(docTopicCounts: TopicCounts, + termTopicCounts: TopicCounts, + totalTopicCounts: TopicCounts, + vocabSize: Int, + eta: Double, + alpha: Double): TopicCounts = { + val K = docTopicCounts.length + val N_j = docTopicCounts.data + val N_w = termTopicCounts.data + val N = totalTopicCounts.data + val eta1 = eta - 1.0 + val alpha1 = alpha - 1.0 + val Weta1 = vocabSize * eta1 + var sum = 0.0 + val gamma_wj = new Array[Double](K) + var k = 0 + while (k < K) { + val gamma_wjk = (N_w(k) + eta1) * (N_j(k) + alpha1) / (N(k) + Weta1) + gamma_wj(k) = gamma_wjk + sum += gamma_wjk + k += 1 + } + // normalize + BDV(gamma_wj) /= sum + } + } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index b0e991d2f2344..05d468809c1b1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,13 +17,12 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum} +import breeze.linalg.normalize import org.apache.spark.annotation.Experimental -import org.apache.spark.graphx.{VertexId, EdgeContext, Graph} -import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} +import org.apache.spark.mllib.clustering.LDA.LearningState +import org.apache.spark.mllib.linalg.{Vector, Matrix} import org.apache.spark.rdd.RDD -import org.apache.spark.util.BoundedPriorityQueue /** * :: Experimental :: @@ -193,7 +192,7 @@ class LocalLDAModel private[clustering] ( */ @Experimental class DistributedLDAModel private ( - private val graph: Graph[LDA.TopicCounts, LDA.TokenCount], + private val state: LearningState, private val globalTopicTotals: LDA.TopicCounts, val k: Int, val vocabSize: Int, @@ -201,10 +200,8 @@ class DistributedLDAModel private ( private val topicConcentration: Double, private[spark] val iterationTimes: Array[Double]) extends LDAModel { - import LDA._ - - private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = { - this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration, + private[clustering] def this(state: LDA.LearningState, iterationTimes: Array[Double]) = { + this(state, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration, state.topicConcentration, iterationTimes) } @@ -223,52 +220,11 @@ class DistributedLDAModel private ( * WARNING: This matrix is collected from an RDD. Beware memory usage when vocabSize, k are large. 
*/ override lazy val topicsMatrix: Matrix = { - // Collect row-major topics - val termTopicCounts: Array[(Int, TopicCounts)] = - graph.vertices.filter(_._1 < 0).map { case (termIndex, cnts) => - (index2term(termIndex), cnts) - }.collect() - // Convert to Matrix - val brzTopics = BDM.zeros[Double](vocabSize, k) - termTopicCounts.foreach { case (term, cnts) => - var j = 0 - while (j < k) { - brzTopics(term, j) = cnts(j) - j += 1 - } - } - Matrices.fromBreeze(brzTopics) + state.topicsMatrix } override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { - val numTopics = k - // Note: N_k is not needed to find the top terms, but it is needed to normalize weights - // to a distribution over terms. - val N_k: TopicCounts = globalTopicTotals - val topicsInQueues: Array[BoundedPriorityQueue[(Double, Int)]] = - graph.vertices.filter(isTermVertex) - .mapPartitions { termVertices => - // For this partition, collect the most common terms for each topic in queues: - // queues(topic) = queue of (term weight, term index). - // Term weights are N_{wk} / N_k. - val queues = - Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Int)](maxTermsPerTopic)) - for ((termId, n_wk) <- termVertices) { - var topic = 0 - while (topic < numTopics) { - queues(topic) += (n_wk(topic) / N_k(topic) -> index2term(termId.toInt)) - topic += 1 - } - } - Iterator(queues) - }.reduce { (q1, q2) => - q1.zip(q2).foreach { case (a, b) => a ++= b} - q1 - } - topicsInQueues.map { q => - val (termWeights, terms) = q.toArray.sortBy(-_._1).unzip - (terms.toArray, termWeights.toArray) - } + state.describeTopics(maxTermsPerTopic) } // TODO @@ -285,24 +241,7 @@ class DistributedLDAModel private ( * hyperparameters. */ lazy val logLikelihood: Double = { - val eta = topicConcentration - val alpha = docConcentration - assert(eta > 1.0) - assert(alpha > 1.0) - val N_k = globalTopicTotals - val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0)) - // Edges: Compute token log probability from phi_{wk}, theta_{kj}. - val sendMsg: EdgeContext[TopicCounts, TokenCount, Double] => Unit = (edgeContext) => { - val N_wj = edgeContext.attr - val smoothed_N_wk: TopicCounts = edgeContext.dstAttr + (eta - 1.0) - val smoothed_N_kj: TopicCounts = edgeContext.srcAttr + (alpha - 1.0) - val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k - val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) - val tokenLogLikelihood = N_wj * math.log(phi_wk.dot(theta_kj)) - edgeContext.sendToDst(tokenLogLikelihood) - } - graph.aggregateMessages[Double](sendMsg, _ + _) - .map(_._2).fold(0.0)(_ + _) + state.logLikelihood } /** @@ -310,27 +249,7 @@ class DistributedLDAModel private ( * log P(topics, topic distributions for docs | alpha, eta) */ lazy val logPrior: Double = { - val eta = topicConcentration - val alpha = docConcentration - // Term vertices: Compute phi_{wk}. Use to compute prior log probability. - // Doc vertex: Compute theta_{kj}. Use to compute prior log probability. 
- val N_k = globalTopicTotals - val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0)) - val seqOp: (Double, (VertexId, TopicCounts)) => Double = { - case (sumPrior: Double, vertex: (VertexId, TopicCounts)) => - if (isTermVertex(vertex)) { - val N_wk = vertex._2 - val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0) - val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k - (eta - 1.0) * brzSum(phi_wk.map(math.log)) - } else { - val N_kj = vertex._2 - val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0) - val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) - (alpha - 1.0) * brzSum(theta_kj.map(math.log)) - } - } - graph.vertices.aggregate(0.0)(seqOp, _ + _) + state.logPrior } /** @@ -340,9 +259,7 @@ class DistributedLDAModel private ( * @return RDD of (document ID, topic distribution) pairs */ def topicDistributions: RDD[(Long, Vector)] = { - graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => - (docID.toLong, Vectors.fromBreeze(normalize(topicCounts, 1.0))) - } + state.topicDistributions } // TODO: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 302d751eb8a94..fae3e8da03c80 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -118,10 +118,10 @@ class LDASuite extends FunSuite with MLlibTestSparkContext { val docVertexIds = docIds val termIds = Array(0, 1, 2) val termVertexIds = Array(-1, -2, -3) - assert(docVertexIds.forall(i => !LDA.isTermVertex((i.toLong, 0)))) - assert(termIds.map(LDA.term2index) === termVertexIds) - assert(termVertexIds.map(i => LDA.index2term(i.toLong)) === termIds) - assert(termVertexIds.forall(i => LDA.isTermVertex((i.toLong, 0)))) + assert(docVertexIds.forall(i => !LDA.EMLearningStateInitializer.isTermVertex((i.toLong, 0)))) + assert(termIds.map(LDA.EMLearningStateInitializer.term2index) === termVertexIds) + assert(termVertexIds.map(i => LDA.EMLearningStateInitializer.index2term(i.toLong)) === termIds) + assert(termVertexIds.forall(i => LDA.EMLearningStateInitializer.isTermVertex((i.toLong, 0)))) } }
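For reference, the E-step at the heart of the EM optimizer (`computePTopic`, moved but not changed by this patch) computes, for each (document j, term w) edge, the smoothed topic posterior gamma_{wjk} proportional to (N_{wk} + eta - 1)(N_{kj} + alpha - 1) / (N_k + W(eta - 1)), normalized over topics. Restated on plain arrays as a self-contained sketch (not code from the patch):

```scala
// Illustrative restatement of computePTopic on plain arrays.
def topicPosterior(
    docTopicCounts: Array[Double],    // N_{kj} for document j
    termTopicCounts: Array[Double],   // N_{wk} for term w
    totalTopicCounts: Array[Double],  // N_k
    vocabSize: Int,
    eta: Double,
    alpha: Double): Array[Double] = {
  val numTopics = docTopicCounts.length
  val unnormalized = Array.tabulate(numTopics) { k =>
    (termTopicCounts(k) + eta - 1.0) * (docTopicCounts(k) + alpha - 1.0) /
      (totalTopicCounts(k) + vocabSize * (eta - 1.0))
  }
  val sum = unnormalized.sum
  unnormalized.map(_ / sum)
}
```

In `EMOptimizer.next()` this distribution is scaled by the token count N_{wj} and summed onto both endpoints of the edge via `aggregateMessages`, which is exactly the M-step count update the refactored `EMLearningState` preserves.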