diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala new file mode 100644 index 000000000000..f4c545ad70e9 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.mllib + +import java.text.BreakIterator + +import scala.collection.mutable + +import scopt.OptionParser + +import org.apache.log4j.{Level, Logger} + +import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.rdd.RDD + + +/** + * An example Latent Dirichlet Allocation (LDA) app. Run with + * {{{ + * ./bin/run-example mllib.LDAExample [options] + * }}} + * If you use it as a template to create your own app, please use `spark-submit` to submit your app. + */ +object LDAExample { + + private case class Params( + input: Seq[String] = Seq.empty, + k: Int = 20, + maxIterations: Int = 10, + docConcentration: Double = -1, + topicConcentration: Double = -1, + vocabSize: Int = 10000, + stopwordFile: String = "", + checkpointDir: Option[String] = None, + checkpointInterval: Int = 10) extends AbstractParams[Params] + + def main(args: Array[String]) { + val defaultParams = Params() + + val parser = new OptionParser[Params]("LDAExample") { + head("LDAExample: an example LDA app for plain text data.") + opt[Int]("k") + .text(s"number of topics. default: ${defaultParams.k}") + .action((x, c) => c.copy(k = x)) + opt[Int]("maxIterations") + .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}") + .action((x, c) => c.copy(maxIterations = x)) + opt[Double]("docConcentration") + .text(s"amount of topic smoothing to use (> 1.0) (-1=auto)." + + s" default: ${defaultParams.docConcentration}") + .action((x, c) => c.copy(docConcentration = x)) + opt[Double]("topicConcentration") + .text(s"amount of term (word) smoothing to use (> 1.0) (-1=auto)." + + s" default: ${defaultParams.topicConcentration}") + .action((x, c) => c.copy(topicConcentration = x)) + opt[Int]("vocabSize") + .text(s"number of distinct word types to use, chosen by frequency. (-1=all)" + + s" default: ${defaultParams.vocabSize}") + .action((x, c) => c.copy(vocabSize = x)) + opt[String]("stopwordFile") + .text(s"filepath for a list of stopwords. Note: This must fit on a single machine." + + s" default: ${defaultParams.stopwordFile}") + .action((x, c) => c.copy(stopwordFile = x)) + opt[String]("checkpointDir") + .text(s"Directory for checkpointing intermediate results." + + s" Checkpointing helps with recovery and eliminates temporary shuffle files on disk." + + s" default: ${defaultParams.checkpointDir}") + .action((x, c) => c.copy(checkpointDir = Some(x))) + opt[Int]("checkpointInterval") + .text(s"Iterations between each checkpoint. Only used if checkpointDir is set." + + s" default: ${defaultParams.checkpointInterval}") + .action((x, c) => c.copy(checkpointInterval = x)) + arg[String]("...") + .text("input paths (directories) to plain text corpora." + + " Each text file line should hold 1 document.") + .unbounded() + .required() + .action((x, c) => c.copy(input = c.input :+ x)) + } + + parser.parse(args, defaultParams).map { params => + run(params) + }.getOrElse { + parser.showUsageAsError + sys.exit(1) + } + } + + private def run(params: Params) { + val conf = new SparkConf().setAppName(s"LDAExample with $params") + val sc = new SparkContext(conf) + + Logger.getRootLogger.setLevel(Level.WARN) + + // Load documents, and prepare them for LDA. + val preprocessStart = System.nanoTime() + val (corpus, vocabArray, actualNumTokens) = + preprocess(sc, params.input, params.vocabSize, params.stopwordFile) + corpus.cache() + val actualCorpusSize = corpus.count() + val actualVocabSize = vocabArray.size + val preprocessElapsed = (System.nanoTime() - preprocessStart) / 1e9 + + println() + println(s"Corpus summary:") + println(s"\t Training set size: $actualCorpusSize documents") + println(s"\t Vocabulary size: $actualVocabSize terms") + println(s"\t Training set size: $actualNumTokens tokens") + println(s"\t Preprocessing time: $preprocessElapsed sec") + println() + + // Run LDA. + val lda = new LDA() + lda.setK(params.k) + .setMaxIterations(params.maxIterations) + .setDocConcentration(params.docConcentration) + .setTopicConcentration(params.topicConcentration) + .setCheckpointInterval(params.checkpointInterval) + if (params.checkpointDir.nonEmpty) { + lda.setCheckpointDir(params.checkpointDir.get) + } + val startTime = System.nanoTime() + val ldaModel = lda.run(corpus) + val elapsed = (System.nanoTime() - startTime) / 1e9 + + println(s"Finished training LDA model. Summary:") + println(s"\t Training time: $elapsed sec") + val avgLogLikelihood = ldaModel.logLikelihood / actualCorpusSize.toDouble + println(s"\t Training data average log likelihood: $avgLogLikelihood") + println() + + // Print the topics, showing the top-weighted terms for each topic. + val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10) + val topics = topicIndices.map { case (terms, termWeights) => + terms.zip(termWeights).map { case (term, weight) => (vocabArray(term.toInt), weight) } + } + println(s"${params.k} topics:") + topics.zipWithIndex.foreach { case (topic, i) => + println(s"TOPIC $i") + topic.foreach { case (term, weight) => + println(s"$term\t$weight") + } + println() + } + + } + + /** + * Load documents, tokenize them, create vocabulary, and prepare documents as term count vectors. + * @return (corpus, vocabulary as array, total token count in corpus) + */ + private def preprocess( + sc: SparkContext, + paths: Seq[String], + vocabSize: Int, + stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = { + + // Get dataset of document texts + // One document per line in each text file. + val textRDD: RDD[String] = sc.textFile(paths.mkString(",")) + + // Split text into words + val tokenizer = new SimpleTokenizer(sc, stopwordFile) + val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) => + id -> tokenizer.getWords(text) + } + tokenized.cache() + + // Counts words: RDD[(word, wordCount)] + val wordCounts: RDD[(String, Long)] = tokenized + .flatMap { case (_, tokens) => tokens.map(_ -> 1L) } + .reduceByKey(_ + _) + wordCounts.cache() + val fullVocabSize = wordCounts.count() + // Select vocab + // (vocab: Map[word -> id], total tokens after selecting vocab) + val (vocab: Map[String, Int], selectedTokenCount: Long) = { + val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) { + // Use all terms + wordCounts.collect().sortBy(-_._2) + } else { + // Sort terms to select vocab + wordCounts.sortBy(_._2, ascending = false).take(vocabSize) + } + (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum) + } + + val documents = tokenized.map { case (id, tokens) => + // Filter tokens by vocabulary, and create word count vector representation of document. + val wc = new mutable.HashMap[Int, Int]() + tokens.foreach { term => + if (vocab.contains(term)) { + val termIndex = vocab(term) + wc(termIndex) = wc.getOrElse(termIndex, 0) + 1 + } + } + val indices = wc.keys.toArray.sorted + val values = indices.map(i => wc(i).toDouble) + + val sb = Vectors.sparse(vocab.size, indices, values) + (id, sb) + } + + val vocabArray = new Array[String](vocab.size) + vocab.foreach { case (term, i) => vocabArray(i) = term } + + (documents, vocabArray, selectedTokenCount) + } +} + +/** + * Simple Tokenizer. + * + * TODO: Formalize the interface, and make this a public class in mllib.feature + */ +private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable { + + private val stopwords: Set[String] = if (stopwordFile.isEmpty) { + Set.empty[String] + } else { + val stopwordText = sc.textFile(stopwordFile).collect() + stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet + } + + // Matches sequences of Unicode letters + private val allWordRegex = "^(\\p{L}*)$".r + + // Ignore words shorter than this length. + private val minWordLength = 3 + + def getWords(text: String): IndexedSeq[String] = { + + val words = new mutable.ArrayBuffer[String]() + + // Use Java BreakIterator to tokenize text into words. + val wb = BreakIterator.getWordInstance + wb.setText(text) + + // current,end index start,end of each word + var current = wb.first() + var end = wb.next() + while (end != BreakIterator.DONE) { + // Convert to lowercase + val word: String = text.substring(current, end).toLowerCase + // Remove short words and strings that aren't only letters + word match { + case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) => + words += w + case _ => + } + + current = end + try { + end = wb.next() + } catch { + case e: Exception => + // Ignore remaining text in line. + // This is a known bug in BreakIterator (for some Java versions), + // which fails when it sees certain characters. + end = BreakIterator.DONE + } + } + words + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala new file mode 100644 index 000000000000..d8f82867a09d --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala @@ -0,0 +1,519 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import java.util.Random + +import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy} + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaPairRDD +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer +import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils + + +/** + * :: Experimental :: + * + * Latent Dirichlet Allocation (LDA), a topic model designed for text documents. + * + * Terminology: + * - "word" = "term": an element of the vocabulary + * - "token": instance of a term appearing in a document + * - "topic": multinomial distribution over words representing some concept + * + * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented + * according to the Asuncion et al. (2009) paper referenced below. + * + * References: + * - Original LDA paper (journal version): + * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003. + * - This class implements their "smoothed" LDA model. + * - Paper which clearly explains several algorithms, including EM: + * Asuncion, Welling, Smyth, and Teh. + * "On Smoothing and Inference for Topic Models." UAI, 2009. + */ +@Experimental +class LDA private ( + private var k: Int, + private var maxIterations: Int, + private var docConcentration: Double, + private var topicConcentration: Double, + private var seed: Long, + private var checkpointDir: Option[String], + private var checkpointInterval: Int) extends Logging { + + def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1, + seed = Utils.random.nextLong(), checkpointDir = None, checkpointInterval = 10) + + /** + * Number of topics to infer. I.e., the number of soft cluster centers. + */ + def getK: Int = k + + /** + * Number of topics to infer. I.e., the number of soft cluster centers. + * (default = 10) + */ + def setK(k: Int): this.type = { + require(k > 0, s"LDA k (number of clusters) must be > 0, but was set to $k") + this.k = k + this + } + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This is the parameter to a symmetric Dirichlet distribution. + */ + def getDocConcentration: Double = { + if (this.docConcentration == -1) { + (50.0 / k) + 1.0 + } else { + this.docConcentration + } + } + + /** + * Concentration parameter (commonly named "alpha") for the prior placed on documents' + * distributions over topics ("theta"). + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * This value should be > 1.0, where larger values mean more smoothing (more regularization). + * If set to -1, then docConcentration is set automatically. + * (default = -1 = automatic) + * + * Automatic setting of parameter: + * - For EM: default = (50 / k) + 1. + * - The 50/k is common in LDA libraries. + * - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * + * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions), + * but values in (0,1) are not yet supported. + */ + def setDocConcentration(docConcentration: Double): this.type = { + require(docConcentration > 1.0 || docConcentration == -1.0, + s"LDA docConcentration must be > 1.0 (or -1 for auto), but was set to $docConcentration") + this.docConcentration = docConcentration + this + } + + /** Alias for [[getDocConcentration]] */ + def getAlpha: Double = getDocConcentration + + /** Alias for [[setDocConcentration()]] */ + def setAlpha(alpha: Double): this.type = setDocConcentration(alpha) + + /** + * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' + * distributions over terms. + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. + */ + def getTopicConcentration: Double = { + if (this.topicConcentration == -1) { + 1.1 + } else { + this.topicConcentration + } + } + + /** + * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics' + * distributions over terms. + * + * This is the parameter to a symmetric Dirichlet distribution. + * + * Note: The topics' distributions over terms are called "beta" in the original LDA paper + * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009. + * + * This value should be > 0.0. + * If set to -1, then topicConcentration is set automatically. + * (default = -1 = automatic) + * + * Automatic setting of parameter: + * - For EM: default = 0.1 + 1. + * - The 0.1 gives a small amount of smoothing. + * - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM. + * + * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions), + * but values in (0,1) are not yet supported. + */ + def setTopicConcentration(topicConcentration: Double): this.type = { + require(topicConcentration > 1.0 || topicConcentration == -1.0, + s"LDA topicConcentration must be > 1.0 (or -1 for auto), but was set to $topicConcentration") + this.topicConcentration = topicConcentration + this + } + + /** Alias for [[getTopicConcentration]] */ + def getBeta: Double = getTopicConcentration + + /** Alias for [[setTopicConcentration()]] */ + def setBeta(beta: Double): this.type = setBeta(beta) + + /** + * Maximum number of iterations for learning. + */ + def getMaxIterations: Int = maxIterations + + /** + * Maximum number of iterations for learning. + * (default = 20) + */ + def setMaxIterations(maxIterations: Int): this.type = { + this.maxIterations = maxIterations + this + } + + /** Random seed */ + def getSeed: Long = seed + + /** Random seed */ + def setSeed(seed: Long): this.type = { + this.seed = seed + this + } + + /** + * Directory for storing checkpoint files during learning. + * This is not necessary, but checkpointing helps with recovery (when nodes fail). + * It also helps with eliminating temporary shuffle files on disk, which can be important when + * LDA is run for many iterations. + */ + def getCheckpointDir: Option[String] = checkpointDir + + /** + * Directory for storing checkpoint files during learning. + * This is not necessary, but checkpointing helps with recovery (when nodes fail). + * It also helps with eliminating temporary shuffle files on disk, which can be important when + * LDA is run for many iterations. + * + * NOTE: If the [[org.apache.spark.SparkContext.checkpointDir]] is already set, then the value + * given to LDA is ignored, and the existing directory is kept. + * + * (default = None) + */ + def setCheckpointDir(checkpointDir: String): this.type = { + this.checkpointDir = Some(checkpointDir) + this + } + + /** + * Clear the directory for storing checkpoint files during learning. + * If one is already set in the [[org.apache.spark.SparkContext]], then checkpointing will still + * occur; otherwise, no checkpointing will be used. + */ + def clearCheckpointDir(): this.type = { + this.checkpointDir = None + this + } + + /** + * Period (in iterations) between checkpoints. + * @see [[getCheckpointDir]] + */ + def getCheckpointInterval: Int = checkpointInterval + + /** + * Period (in iterations) between checkpoints. + * (default = 10) + * @see [[getCheckpointDir]] + */ + def setCheckpointInterval(checkpointInterval: Int): this.type = { + this.checkpointInterval = checkpointInterval + this + } + + /** + * Learn an LDA model using the given dataset. + * + * @param documents RDD of documents, which are term (word) count vectors paired with IDs. + * The term count vectors are "bags of words" with a fixed-size vocabulary + * (where the vocabulary size is the length of the vector). + * Document IDs must be unique and >= 0. + * @return Inferred LDA model + */ + def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = { + val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed, + checkpointDir, checkpointInterval) + var iter = 0 + val iterationTimes = Array.fill[Double](maxIterations)(0) + while (iter < maxIterations) { + val start = System.nanoTime() + state.next() + val elapsedSeconds = (System.nanoTime() - start) / 1e9 + iterationTimes(iter) = elapsedSeconds + iter += 1 + } + state.graphCheckpointer.deleteAllCheckpoints() + new DistributedLDAModel(state, iterationTimes) + } + + /** Java-friendly version of [[run()]] */ + def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = { + run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]]) + } +} + + +private[clustering] object LDA { + + /* + DEVELOPERS NOTE: + + This implementation uses GraphX, where the graph is bipartite with 2 types of vertices: + - Document vertices + - indexed with unique indices >= 0 + - Store vectors of length k (# topics). + - Term vertices + - indexed {-1, -2, ..., -vocabSize} + - Store vectors of length k (# topics). + - Edges correspond to terms appearing in documents. + - Edges are directed Document -> Term. + - Edges are partitioned by documents. + + Info on EM implementation. + - We follow Section 2.2 from Asuncion et al., 2009. We use some of their notation. + - In this implementation, there is one edge for every unique term appearing in a document, + i.e., for every unique (document, term) pair. + - Notation: + - N_{wkj} = count of tokens of term w currently assigned to topic k in document j + - N_{*} where * is missing a subscript w/k/j is the count summed over missing subscript(s) + - gamma_{wjk} = P(z_i = k | x_i = w, d_i = j), + the probability of term x_i in document d_i having topic z_i. + - Data graph + - Document vertices store N_{kj} + - Term vertices store N_{wk} + - Edges store N_{wj}. + - Global data N_k + - Algorithm + - Initial state: + - Document and term vertices store random counts N_{wk}, N_{kj}. + - E-step: For each (document,term) pair i, compute P(z_i | x_i, d_i). + - Aggregate N_k from term vertices. + - Compute gamma_{wjk} for each possible topic k, from each triplet. + using inputs N_{wk}, N_{kj}, N_k. + - M-step: Compute sufficient statistics for hidden parameters phi and theta + (counts N_{wk}, N_{kj}, N_k). + - Document update: + - N_{kj} <- sum_w N_{wj} gamma_{wjk} + - N_j <- sum_k N_{kj} (only needed to output predictions) + - Term update: + - N_{wk} <- sum_j N_{wj} gamma_{wjk} + - N_k <- sum_w N_{wk} + + TODO: Add simplex constraints to allow alpha in (0,1). + See: Vorontsov and Potapenko. "Tutorial on Probabilistic Topic Modeling : Additive + Regularization for Stochastic Matrix Factorization." 2014. + */ + + /** + * Vector over topics (length k) of token counts. + * The meaning of these counts can vary, and it may or may not be normalized to be a distribution. + */ + type TopicCounts = BDV[Double] + + type TokenCount = Double + + /** Term vertex IDs are {-1, -2, ..., -vocabSize} */ + def term2index(term: Int): Long = -(1 + term.toLong) + + def index2term(termIndex: Long): Int = -(1 + termIndex).toInt + + def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0 + + def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0 + + /** + * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters. + * + * @param graph EM graph, storing current parameter estimates in vertex descriptors and + * data (token counts) in edge descriptors. + * @param k Number of topics + * @param vocabSize Number of unique terms + * @param docConcentration "alpha" + * @param topicConcentration "beta" or "eta" + */ + class EMOptimizer( + var graph: Graph[TopicCounts, TokenCount], + val k: Int, + val vocabSize: Int, + val docConcentration: Double, + val topicConcentration: Double, + checkpointDir: Option[String], + checkpointInterval: Int) { + + private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( + graph, checkpointDir, checkpointInterval) + + def next(): EMOptimizer = { + val eta = topicConcentration + val W = vocabSize + val alpha = docConcentration + + val N_k = globalTopicTotals + val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit = + (edgeContext) => { + // Compute N_{wj} gamma_{wjk} + val N_wj = edgeContext.attr + // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count + // N_{wj}. + val scaledTopicDistribution: TopicCounts = + computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj + edgeContext.sendToDst((false, scaledTopicDistribution)) + edgeContext.sendToSrc((false, scaledTopicDistribution)) + } + // This is a hack to detect whether we could modify the values in-place. + // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438) + val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) = + (m0, m1) => { + val sum = + if (m0._1) { + m0._2 += m1._2 + } else if (m1._1) { + m1._2 += m0._2 + } else { + m0._2 + m1._2 + } + (true, sum) + } + // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts. + val docTopicDistributions: VertexRDD[TopicCounts] = + graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg) + .mapValues(_._2) + // Update the vertex descriptors with the new counts. + val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges) + graph = newGraph + graphCheckpointer.updateGraph(newGraph) + globalTopicTotals = computeGlobalTopicTotals() + this + } + + /** + * Aggregate distributions over topics from all term vertices. + * + * Note: This executes an action on the graph RDDs. + */ + var globalTopicTotals: TopicCounts = computeGlobalTopicTotals() + + private def computeGlobalTopicTotals(): TopicCounts = { + val numTopics = k + graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _) + } + + } + + /** + * Compute gamma_{wjk}, a distribution over topics k. + */ + private def computePTopic( + docTopicCounts: TopicCounts, + termTopicCounts: TopicCounts, + totalTopicCounts: TopicCounts, + vocabSize: Int, + eta: Double, + alpha: Double): TopicCounts = { + val K = docTopicCounts.length + val N_j = docTopicCounts.data + val N_w = termTopicCounts.data + val N = totalTopicCounts.data + val eta1 = eta - 1.0 + val alpha1 = alpha - 1.0 + val Weta1 = vocabSize * eta1 + var sum = 0.0 + val gamma_wj = new Array[Double](K) + var k = 0 + while (k < K) { + val gamma_wjk = (N_w(k) + eta1) * (N_j(k) + alpha1) / (N(k) + Weta1) + gamma_wj(k) = gamma_wjk + sum += gamma_wjk + k += 1 + } + // normalize + BDV(gamma_wj) /= sum + } + + /** + * Compute bipartite term/doc graph. + */ + private def initialState( + docs: RDD[(Long, Vector)], + k: Int, + docConcentration: Double, + topicConcentration: Double, + randomSeed: Long, + checkpointDir: Option[String], + checkpointInterval: Int): EMOptimizer = { + // For each document, create an edge (Document -> Term) for each unique term in the document. + val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) => + // Add edges for terms with non-zero counts. + termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) => + Edge(docID, term2index(term), cnt) + } + } + + val vocabSize = docs.take(1).head._2.size + + // Create vertices. + // Initially, we use random soft assignments of tokens to topics (random gamma). + val edgesWithGamma: RDD[(Edge[TokenCount], TopicCounts)] = + edges.mapPartitionsWithIndex { case (partIndex, partEdges) => + val random = new Random(partIndex + randomSeed) + partEdges.map { edge => + // Create a random gamma_{wjk} + (edge, normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)) + } + } + def createVertices(sendToWhere: Edge[TokenCount] => VertexId): RDD[(VertexId, TopicCounts)] = { + val verticesTMP: RDD[(VertexId, (TokenCount, TopicCounts))] = + edgesWithGamma.map { case (edge, gamma: TopicCounts) => + (sendToWhere(edge), (edge.attr, gamma)) + } + verticesTMP.aggregateByKey(BDV.zeros[Double](k))( + (sum, t) => { + brzAxpy(t._1, t._2, sum) + sum + }, + (sum0, sum1) => { + sum0 += sum1 + } + ) + } + val docVertices = createVertices(_.srcId) + val termVertices = createVertices(_.dstId) + + // Partition such that edges are grouped by document + val graph = Graph(docVertices ++ termVertices, edges) + .partitionBy(PartitionStrategy.EdgePartition1D) + + new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointDir, + checkpointInterval) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala new file mode 100644 index 000000000000..19e8aab6eabd --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -0,0 +1,351 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.graphx.{VertexId, EdgeContext, Graph} +import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.BoundedPriorityQueue + +/** + * :: Experimental :: + * + * Latent Dirichlet Allocation (LDA) model. + * + * This abstraction permits for different underlying representations, + * including local and distributed data structures. + */ +@Experimental +abstract class LDAModel private[clustering] { + + /** Number of topics */ + def k: Int + + /** Vocabulary size (number of terms or terms in the vocabulary) */ + def vocabSize: Int + + /** + * Inferred topics, where each topic is represented by a distribution over terms. + * This is a matrix of size vocabSize x k, where each column is a topic. + * No guarantees are given about the ordering of the topics. + */ + def topicsMatrix: Matrix + + /** + * Return the topics described by weighted terms. + * + * This limits the number of terms per topic. + * This is approximate; it may not return exactly the top-weighted terms for each topic. + * To get a more precise set of top terms, increase maxTermsPerTopic. + * + * @param maxTermsPerTopic Maximum number of terms to collect for each topic. + * @return Array over topics. Each topic is represented as a pair of matching arrays: + * (term indices, term weights in topic). + * Each topic's terms are sorted in order of decreasing weight. + */ + def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] + + /** + * Return the topics described by weighted terms. + * + * WARNING: If vocabSize and k are large, this can return a large object! + * + * @return Array over topics. Each topic is represented as a pair of matching arrays: + * (term indices, term weights in topic). + * Each topic's terms are sorted in order of decreasing weight. + */ + def describeTopics(): Array[(Array[Int], Array[Double])] = describeTopics(vocabSize) + + /* TODO (once LDA can be trained with Strings or given a dictionary) + * Return the topics described by weighted terms. + * + * This is similar to [[describeTopics()]] but returns String values for terms. + * If this model was trained using Strings or was given a dictionary, then this method returns + * terms as text. Otherwise, this method returns terms as term indices. + * + * This limits the number of terms per topic. + * This is approximate; it may not return exactly the top-weighted terms for each topic. + * To get a more precise set of top terms, increase maxTermsPerTopic. + * + * @param maxTermsPerTopic Maximum number of terms to collect for each topic. + * @return Array over topics. Each topic is represented as a pair of matching arrays: + * (terms, term weights in topic) where terms are either the actual term text + * (if available) or the term indices. + * Each topic's terms are sorted in order of decreasing weight. + */ + // def describeTopicsAsStrings(maxTermsPerTopic: Int): Array[(Array[Double], Array[String])] + + /* TODO (once LDA can be trained with Strings or given a dictionary) + * Return the topics described by weighted terms. + * + * This is similar to [[describeTopics()]] but returns String values for terms. + * If this model was trained using Strings or was given a dictionary, then this method returns + * terms as text. Otherwise, this method returns terms as term indices. + * + * WARNING: If vocabSize and k are large, this can return a large object! + * + * @return Array over topics. Each topic is represented as a pair of matching arrays: + * (terms, term weights in topic) where terms are either the actual term text + * (if available) or the term indices. + * Each topic's terms are sorted in order of decreasing weight. + */ + // def describeTopicsAsStrings(): Array[(Array[Double], Array[String])] = + // describeTopicsAsStrings(vocabSize) + + /* TODO + * Compute the log likelihood of the observed tokens, given the current parameter estimates: + * log P(docs | topics, topic distributions for docs, alpha, eta) + * + * Note: + * - This excludes the prior. + * - Even with the prior, this is NOT the same as the data log likelihood given the + * hyperparameters. + * + * @param documents RDD of documents, which are term (word) count vectors paired with IDs. + * The term count vectors are "bags of words" with a fixed-size vocabulary + * (where the vocabulary size is the length of the vector). + * This must use the same vocabulary (ordering of term counts) as in training. + * Document IDs must be unique and >= 0. + * @return Estimated log likelihood of the data under this model + */ + // def logLikelihood(documents: RDD[(Long, Vector)]): Double + + /* TODO + * Compute the estimated topic distribution for each document. + * This is often called “theta” in the literature. + * + * @param documents RDD of documents, which are term (word) count vectors paired with IDs. + * The term count vectors are "bags of words" with a fixed-size vocabulary + * (where the vocabulary size is the length of the vector). + * This must use the same vocabulary (ordering of term counts) as in training. + * Document IDs must be unique and >= 0. + * @return Estimated topic distribution for each document. + * The returned RDD may be zipped with the given RDD, where each returned vector + * is a multinomial distribution over topics. + */ + // def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] + +} + +/** + * :: Experimental :: + * + * Local LDA model. + * This model stores only the inferred topics. + * It may be used for computing topics for new documents, but it may give less accurate answers + * than the [[DistributedLDAModel]]. + * + * @param topics Inferred topics (vocabSize x k matrix). + */ +@Experimental +class LocalLDAModel private[clustering] ( + private val topics: Matrix) extends LDAModel with Serializable { + + override def k: Int = topics.numCols + + override def vocabSize: Int = topics.numRows + + override def topicsMatrix: Matrix = topics + + override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { + val brzTopics = topics.toBreeze.toDenseMatrix + Range(0, k).map { topicIndex => + val topic = normalize(brzTopics(::, topicIndex), 1.0) + val (termWeights, terms) = + topic.toArray.zipWithIndex.sortBy(-_._1).take(maxTermsPerTopic).unzip + (terms.toArray, termWeights.toArray) + }.toArray + } + + // TODO + // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? + + // TODO: + // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? + +} + +/** + * :: Experimental :: + * + * Distributed LDA model. + * This model stores the inferred topics, the full training dataset, and the topic distributions. + * When computing topics for new documents, it may give more accurate answers + * than the [[LocalLDAModel]]. + */ +@Experimental +class DistributedLDAModel private ( + private val graph: Graph[LDA.TopicCounts, LDA.TokenCount], + private val globalTopicTotals: LDA.TopicCounts, + val k: Int, + val vocabSize: Int, + private val docConcentration: Double, + private val topicConcentration: Double, + private[spark] val iterationTimes: Array[Double]) extends LDAModel { + + import LDA._ + + private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = { + this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration, + state.topicConcentration, iterationTimes) + } + + /** + * Convert model to a local model. + * The local model stores the inferred topics but not the topic distributions for training + * documents. + */ + def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix) + + /** + * Inferred topics, where each topic is represented by a distribution over terms. + * This is a matrix of size vocabSize x k, where each column is a topic. + * No guarantees are given about the ordering of the topics. + * + * WARNING: This matrix is collected from an RDD. Beware memory usage when vocabSize, k are large. + */ + override lazy val topicsMatrix: Matrix = { + // Collect row-major topics + val termTopicCounts: Array[(Int, TopicCounts)] = + graph.vertices.filter(_._1 < 0).map { case (termIndex, cnts) => + (index2term(termIndex), cnts) + }.collect() + // Convert to Matrix + val brzTopics = BDM.zeros[Double](vocabSize, k) + termTopicCounts.foreach { case (term, cnts) => + var j = 0 + while (j < k) { + brzTopics(term, j) = cnts(j) + j += 1 + } + } + Matrices.fromBreeze(brzTopics) + } + + override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = { + val numTopics = k + // Note: N_k is not needed to find the top terms, but it is needed to normalize weights + // to a distribution over terms. + val N_k: TopicCounts = globalTopicTotals + val topicsInQueues: Array[BoundedPriorityQueue[(Double, Int)]] = + graph.vertices.filter(isTermVertex) + .mapPartitions { termVertices => + // For this partition, collect the most common terms for each topic in queues: + // queues(topic) = queue of (term weight, term index). + // Term weights are N_{wk} / N_k. + val queues = + Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Int)](maxTermsPerTopic)) + for ((termId, n_wk) <- termVertices) { + var topic = 0 + while (topic < numTopics) { + queues(topic) += (n_wk(topic) / N_k(topic) -> index2term(termId.toInt)) + topic += 1 + } + } + Iterator(queues) + }.reduce { (q1, q2) => + q1.zip(q2).foreach { case (a, b) => a ++= b} + q1 + } + topicsInQueues.map { q => + val (termWeights, terms) = q.toArray.sortBy(-_._1).unzip + (terms.toArray, termWeights.toArray) + } + } + + // TODO + // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? + + /** + * Log likelihood of the observed tokens in the training set, + * given the current parameter estimates: + * log P(docs | topics, topic distributions for docs, alpha, eta) + * + * Note: + * - This excludes the prior; for that, use [[logPrior]]. + * - Even with [[logPrior]], this is NOT the same as the data log likelihood given the + * hyperparameters. + */ + lazy val logLikelihood: Double = { + val eta = topicConcentration + val alpha = docConcentration + assert(eta > 1.0) + assert(alpha > 1.0) + val N_k = globalTopicTotals + val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0)) + // Edges: Compute token log probability from phi_{wk}, theta_{kj}. + val sendMsg: EdgeContext[TopicCounts, TokenCount, Double] => Unit = (edgeContext) => { + val N_wj = edgeContext.attr + val smoothed_N_wk: TopicCounts = edgeContext.dstAttr + (eta - 1.0) + val smoothed_N_kj: TopicCounts = edgeContext.srcAttr + (alpha - 1.0) + val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) + val tokenLogLikelihood = N_wj * math.log(phi_wk.dot(theta_kj)) + edgeContext.sendToDst(tokenLogLikelihood) + } + graph.aggregateMessages[Double](sendMsg, _ + _) + .map(_._2).fold(0.0)(_ + _) + } + + /** + * Log probability of the current parameter estimate: + * log P(topics, topic distributions for docs | alpha, eta) + */ + lazy val logPrior: Double = { + val eta = topicConcentration + val alpha = docConcentration + // Term vertices: Compute phi_{wk}. Use to compute prior log probability. + // Doc vertex: Compute theta_{kj}. Use to compute prior log probability. + val N_k = globalTopicTotals + val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0)) + val seqOp: (Double, (VertexId, TopicCounts)) => Double = { + case (sumPrior: Double, vertex: (VertexId, TopicCounts)) => + if (isTermVertex(vertex)) { + val N_wk = vertex._2 + val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0) + val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + (eta - 1.0) * brzSum(phi_wk.map(math.log)) + } else { + val N_kj = vertex._2 + val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0) + val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) + (alpha - 1.0) * brzSum(theta_kj.map(math.log)) + } + } + graph.vertices.aggregate(0.0)(seqOp, _ + _) + } + + /** + * For each document in the training set, return the distribution over topics for that document + * (i.e., "theta_doc"). + * + * @return RDD of (document ID, topic distribution) pairs + */ + def topicDistributions: RDD[(Long, Vector)] = { + graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => + (docID.toLong, Vectors.fromBreeze(normalize(topicCounts, 1.0))) + } + } + + // TODO: + // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? + +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala new file mode 100644 index 000000000000..76672fe51e83 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import scala.collection.mutable + +import org.apache.hadoop.fs.{Path, FileSystem} + +import org.apache.spark.Logging +import org.apache.spark.graphx.Graph +import org.apache.spark.storage.StorageLevel + + +/** + * This class helps with persisting and checkpointing Graphs. + * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as + * unpersisting and removing checkpoint files. + * + * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created, + * before the graph has been materialized. After updating [[PeriodicGraphCheckpointer]], users are + * responsible for materializing the graph to ensure that persisting and checkpointing actually + * occur. + * + * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following: + * - Persist new graph (if not yet persisted), and put in queue of persisted graphs. + * - Unpersist graphs from queue until there are at most 3 persisted graphs. + * - If using checkpointing and the checkpoint interval has been reached, + * - Checkpoint the new graph, and put in a queue of checkpointed graphs. + * - Remove older checkpoints. + * + * WARNINGS: + * - This class should NOT be copied (since copies may conflict on which Graphs should be + * checkpointed). + * - This class removes checkpoint files once later graphs have been checkpointed. + * However, references to the older graphs will still return isCheckpointed = true. + * + * Example usage: + * {{{ + * val (graph1, graph2, graph3, ...) = ... + * val cp = new PeriodicGraphCheckpointer(graph1, dir, 2) + * graph1.vertices.count(); graph1.edges.count() + * // persisted: graph1 + * cp.updateGraph(graph2) + * graph2.vertices.count(); graph2.edges.count() + * // persisted: graph1, graph2 + * // checkpointed: graph2 + * cp.updateGraph(graph3) + * graph3.vertices.count(); graph3.edges.count() + * // persisted: graph1, graph2, graph3 + * // checkpointed: graph2 + * cp.updateGraph(graph4) + * graph4.vertices.count(); graph4.edges.count() + * // persisted: graph2, graph3, graph4 + * // checkpointed: graph4 + * cp.updateGraph(graph5) + * graph5.vertices.count(); graph5.edges.count() + * // persisted: graph3, graph4, graph5 + * // checkpointed: graph4 + * }}} + * + * @param currentGraph Initial graph + * @param checkpointDir The directory for storing checkpoint files + * @param checkpointInterval Graphs will be checkpointed at this interval + * @tparam VD Vertex descriptor type + * @tparam ED Edge descriptor type + * + * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib. + */ +private[mllib] class PeriodicGraphCheckpointer[VD, ED]( + var currentGraph: Graph[VD, ED], + val checkpointDir: Option[String], + val checkpointInterval: Int) extends Logging { + + /** FIFO queue of past checkpointed RDDs */ + private val checkpointQueue = mutable.Queue[Graph[VD, ED]]() + + /** FIFO queue of past persisted RDDs */ + private val persistedQueue = mutable.Queue[Graph[VD, ED]]() + + /** Number of times [[updateGraph()]] has been called */ + private var updateCount = 0 + + /** + * Spark Context for the Graphs given to this checkpointer. + * NOTE: This code assumes that only one SparkContext is used for the given graphs. + */ + private val sc = currentGraph.vertices.sparkContext + + // If a checkpoint directory is given, and there's no prior checkpoint directory, + // then set the checkpoint directory with the given one. + if (checkpointDir.nonEmpty && sc.getCheckpointDir.isEmpty) { + sc.setCheckpointDir(checkpointDir.get) + } + + updateGraph(currentGraph) + + /** + * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed. + * Since this handles persistence and checkpointing, this should be called before the graph + * has been materialized. + * + * @param newGraph New graph created from previous graphs in the lineage. + */ + def updateGraph(newGraph: Graph[VD, ED]): Unit = { + if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) { + newGraph.persist() + } + persistedQueue.enqueue(newGraph) + // We try to maintain 2 Graphs in persistedQueue to support the semantics of this class: + // Users should call [[updateGraph()]] when a new graph has been created, + // before the graph has been materialized. + while (persistedQueue.size > 3) { + val graphToUnpersist = persistedQueue.dequeue() + graphToUnpersist.unpersist(blocking = false) + } + updateCount += 1 + + // Handle checkpointing (after persisting) + if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { + // Add new checkpoint before removing old checkpoints. + newGraph.checkpoint() + checkpointQueue.enqueue(newGraph) + // Remove checkpoints before the latest one. + var canDelete = true + while (checkpointQueue.size > 1 && canDelete) { + // Delete the oldest checkpoint only if the next checkpoint exists. + if (checkpointQueue.get(1).get.isCheckpointed) { + removeCheckpointFile() + } else { + canDelete = false + } + } + } + } + + /** + * Call this at the end to delete any remaining checkpoint files. + */ + def deleteAllCheckpoints(): Unit = { + while (checkpointQueue.size > 0) { + removeCheckpointFile() + } + } + + /** + * Dequeue the oldest checkpointed Graph, and remove its checkpoint files. + * This prints a warning but does not fail if the files cannot be removed. + */ + private def removeCheckpointFile(): Unit = { + val old = checkpointQueue.dequeue() + // Since the old checkpoint is not deleted by Spark, we manually delete it. + val fs = FileSystem.get(sc.hadoopConfiguration) + old.getCheckpointFiles.foreach { checkpointFile => + try { + fs.delete(new Path(checkpointFile), true) + } catch { + case e: Exception => + logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " + + checkpointFile) + } + } + } + +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java new file mode 100644 index 000000000000..dc10aa67c7c1 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering; + +import java.io.Serializable; +import java.util.ArrayList; + +import org.apache.spark.api.java.JavaRDD; +import scala.Tuple2; + +import org.junit.After; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertArrayEquals; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Matrix; +import org.apache.spark.mllib.linalg.Vector; + + +public class JavaLDASuite implements Serializable { + private transient JavaSparkContext sc; + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaLDA"); + ArrayList> tinyCorpus = new ArrayList>(); + for (int i = 0; i < LDASuite$.MODULE$.tinyCorpus().length; i++) { + tinyCorpus.add(new Tuple2((Long)LDASuite$.MODULE$.tinyCorpus()[i]._1(), + LDASuite$.MODULE$.tinyCorpus()[i]._2())); + } + JavaRDD> tmpCorpus = sc.parallelize(tinyCorpus, 2); + corpus = JavaPairRDD.fromJavaRDD(tmpCorpus); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void localLDAModel() { + LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics()); + + // Check: basic parameters + assertEquals(model.k(), tinyK); + assertEquals(model.vocabSize(), tinyVocabSize); + assertEquals(model.topicsMatrix(), tinyTopics); + + // Check: describeTopics() with all terms + Tuple2[] fullTopicSummary = model.describeTopics(); + assertEquals(fullTopicSummary.length, tinyK); + for (int i = 0; i < fullTopicSummary.length; i++) { + assertArrayEquals(fullTopicSummary[i]._1(), tinyTopicDescription[i]._1()); + assertArrayEquals(fullTopicSummary[i]._2(), tinyTopicDescription[i]._2(), 1e-5); + } + } + + @Test + public void distributedLDAModel() { + int k = 3; + double topicSmoothing = 1.2; + double termSmoothing = 1.2; + + // Train a model + LDA lda = new LDA(); + lda.setK(k) + .setDocConcentration(topicSmoothing) + .setTopicConcentration(termSmoothing) + .setMaxIterations(5) + .setSeed(12345); + + DistributedLDAModel model = lda.run(corpus); + + // Check: basic parameters + LocalLDAModel localModel = model.toLocal(); + assertEquals(model.k(), k); + assertEquals(localModel.k(), k); + assertEquals(model.vocabSize(), tinyVocabSize); + assertEquals(localModel.vocabSize(), tinyVocabSize); + assertEquals(model.topicsMatrix(), localModel.topicsMatrix()); + + // Check: topic summaries + Tuple2[] roundedTopicSummary = model.describeTopics(); + assertEquals(roundedTopicSummary.length, k); + Tuple2[] roundedLocalTopicSummary = localModel.describeTopics(); + assertEquals(roundedLocalTopicSummary.length, k); + + // Check: log probabilities + assert(model.logLikelihood() < 0.0); + assert(model.logPrior() < 0.0); + } + + private static int tinyK = LDASuite$.MODULE$.tinyK(); + private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize(); + private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics(); + private static Tuple2[] tinyTopicDescription = + LDASuite$.MODULE$.tinyTopicDescription(); + JavaPairRDD corpus; + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala new file mode 100644 index 000000000000..302d751eb8a9 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class LDASuite extends FunSuite with MLlibTestSparkContext { + + import LDASuite._ + + test("LocalLDAModel") { + val model = new LocalLDAModel(tinyTopics) + + // Check: basic parameters + assert(model.k === tinyK) + assert(model.vocabSize === tinyVocabSize) + assert(model.topicsMatrix === tinyTopics) + + // Check: describeTopics() with all terms + val fullTopicSummary = model.describeTopics() + assert(fullTopicSummary.size === tinyK) + fullTopicSummary.zip(tinyTopicDescription).foreach { + case ((algTerms, algTermWeights), (terms, termWeights)) => + assert(algTerms === terms) + assert(algTermWeights === termWeights) + } + + // Check: describeTopics() with some terms + val smallNumTerms = 3 + val smallTopicSummary = model.describeTopics(maxTermsPerTopic = smallNumTerms) + smallTopicSummary.zip(tinyTopicDescription).foreach { + case ((algTerms, algTermWeights), (terms, termWeights)) => + assert(algTerms === terms.slice(0, smallNumTerms)) + assert(algTermWeights === termWeights.slice(0, smallNumTerms)) + } + } + + test("running and DistributedLDAModel") { + val k = 3 + val topicSmoothing = 1.2 + val termSmoothing = 1.2 + + // Train a model + val lda = new LDA() + lda.setK(k) + .setDocConcentration(topicSmoothing) + .setTopicConcentration(termSmoothing) + .setMaxIterations(5) + .setSeed(12345) + val corpus = sc.parallelize(tinyCorpus, 2) + + val model: DistributedLDAModel = lda.run(corpus) + + // Check: basic parameters + val localModel = model.toLocal + assert(model.k === k) + assert(localModel.k === k) + assert(model.vocabSize === tinyVocabSize) + assert(localModel.vocabSize === tinyVocabSize) + assert(model.topicsMatrix === localModel.topicsMatrix) + + // Check: topic summaries + // The odd decimal formatting and sorting is a hack to do a robust comparison. + val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) => + // cut values to 3 digits after the decimal place + terms.zip(termWeights).map { case (term, weight) => + ("%.3f".format(weight).toDouble, term.toInt) + } + }.sortBy(_.mkString("")) + val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => + // cut values to 3 digits after the decimal place + terms.zip(termWeights).map { case (term, weight) => + ("%.3f".format(weight).toDouble, term.toInt) + } + }.sortBy(_.mkString("")) + roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) => + assert(t1 === t2) + } + + // Check: per-doc topic distributions + val topicDistributions = model.topicDistributions.collect() + // Ensure all documents are covered. + assert(topicDistributions.size === tinyCorpus.size) + assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet) + // Ensure we have proper distributions + topicDistributions.foreach { case (docId, topicDistribution) => + assert(topicDistribution.size === tinyK) + assert(topicDistribution.toArray.sum ~== 1.0 absTol 1e-5) + } + + // Check: log probabilities + assert(model.logLikelihood < 0.0) + assert(model.logPrior < 0.0) + } + + test("vertex indexing") { + // Check vertex ID indexing and conversions. + val docIds = Array(0, 1, 2) + val docVertexIds = docIds + val termIds = Array(0, 1, 2) + val termVertexIds = Array(-1, -2, -3) + assert(docVertexIds.forall(i => !LDA.isTermVertex((i.toLong, 0)))) + assert(termIds.map(LDA.term2index) === termVertexIds) + assert(termVertexIds.map(i => LDA.index2term(i.toLong)) === termIds) + assert(termVertexIds.forall(i => LDA.isTermVertex((i.toLong, 0)))) + } +} + +private[clustering] object LDASuite { + + def tinyK: Int = 3 + def tinyVocabSize: Int = 5 + def tinyTopicsAsArray: Array[Array[Double]] = Array( + Array[Double](0.1, 0.2, 0.3, 0.4, 0.0), // topic 0 + Array[Double](0.5, 0.05, 0.05, 0.1, 0.3), // topic 1 + Array[Double](0.2, 0.2, 0.05, 0.05, 0.5) // topic 2 + ) + def tinyTopics: Matrix = new DenseMatrix(numRows = tinyVocabSize, numCols = tinyK, + values = tinyTopicsAsArray.fold(Array.empty[Double])(_ ++ _)) + def tinyTopicDescription: Array[(Array[Int], Array[Double])] = tinyTopicsAsArray.map { topic => + val (termWeights, terms) = topic.zipWithIndex.sortBy(-_._1).unzip + (terms.toArray, termWeights.toArray) + } + + def tinyCorpus = Array( + Vectors.dense(1, 3, 0, 2, 8), + Vectors.dense(0, 2, 1, 0, 4), + Vectors.dense(2, 3, 12, 3, 1), + Vectors.dense(0, 3, 1, 9, 8), + Vectors.dense(1, 1, 4, 2, 6) + ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) } + assert(tinyCorpus.forall(_._2.size == tinyVocabSize)) // sanity check for test data + +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala new file mode 100644 index 000000000000..dac28a369b5b --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.impl + +import org.scalatest.FunSuite + +import org.apache.hadoop.fs.{FileSystem, Path} + +import org.apache.spark.SparkContext +import org.apache.spark.graphx.{Edge, Graph} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.Utils + + +class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext { + + import PeriodicGraphCheckpointerSuite._ + + // TODO: Do I need to call count() on the graphs' RDDs? + + test("Persisting") { + var graphsToCheck = Seq.empty[GraphToCheck] + + val graph1 = createGraph(sc) + val checkpointer = new PeriodicGraphCheckpointer(graph1, None, 10) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkPersistence(graphsToCheck, 1) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.updateGraph(graph) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkPersistence(graphsToCheck, iteration) + iteration += 1 + } + } + + test("Checkpointing") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val checkpointInterval = 2 + var graphsToCheck = Seq.empty[GraphToCheck] + + val graph1 = createGraph(sc) + val checkpointer = new PeriodicGraphCheckpointer(graph1, Some(path), checkpointInterval) + graph1.edges.count() + graph1.vertices.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkCheckpoint(graphsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.updateGraph(graph) + graph.vertices.count() + graph.edges.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkCheckpoint(graphsToCheck, iteration, checkpointInterval) + iteration += 1 + } + + checkpointer.deleteAllCheckpoints() + graphsToCheck.foreach { graph => + confirmCheckpointRemoved(graph.graph) + } + + Utils.deleteRecursively(tempDir) + } +} + +private object PeriodicGraphCheckpointerSuite { + + case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int) + + val edges = Seq( + Edge[Double](0, 1, 0), + Edge[Double](1, 2, 0), + Edge[Double](2, 3, 0), + Edge[Double](3, 4, 0)) + + def createGraph(sc: SparkContext): Graph[Double, Double] = { + Graph.fromEdges[Double, Double](sc.parallelize(edges), 0) + } + + def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = { + graphs.foreach { g => + checkPersistence(g.graph, g.gIndex, iteration) + } + } + + /** + * Check storage level of graph. + * @param gIndex Index of graph in order inserted into checkpointer (from 1). + * @param iteration Total number of graphs inserted into checkpointer. + */ + def checkPersistence(graph: Graph[_, _], gIndex: Int, iteration: Int): Unit = { + try { + if (gIndex + 2 < iteration) { + assert(graph.vertices.getStorageLevel == StorageLevel.NONE) + assert(graph.edges.getStorageLevel == StorageLevel.NONE) + } else { + assert(graph.vertices.getStorageLevel != StorageLevel.NONE) + assert(graph.edges.getStorageLevel != StorageLevel.NONE) + } + } catch { + case _: AssertionError => + throw new Exception(s"PeriodicGraphCheckpointerSuite.checkPersistence failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t graph.vertices.getStorageLevel = ${graph.vertices.getStorageLevel}\n" + + s"\t graph.edges.getStorageLevel = ${graph.edges.getStorageLevel}\n") + } + } + + def checkCheckpoint(graphs: Seq[GraphToCheck], iteration: Int, checkpointInterval: Int): Unit = { + graphs.reverse.foreach { g => + checkCheckpoint(g.graph, g.gIndex, iteration, checkpointInterval) + } + } + + def confirmCheckpointRemoved(graph: Graph[_, _]): Unit = { + // Note: We cannot check graph.isCheckpointed since that value is never updated. + // Instead, we check for the presence of the checkpoint files. + // This test should continue to work even after this graph.isCheckpointed issue + // is fixed (though it can then be simplified and not look for the files). + val fs = FileSystem.get(graph.vertices.sparkContext.hadoopConfiguration) + graph.getCheckpointFiles.foreach { checkpointFile => + assert(!fs.exists(new Path(checkpointFile)), + "Graph checkpoint file should have been removed") + } + } + + /** + * Check checkpointed status of graph. + * @param gIndex Index of graph in order inserted into checkpointer (from 1). + * @param iteration Total number of graphs inserted into checkpointer. + */ + def checkCheckpoint( + graph: Graph[_, _], + gIndex: Int, + iteration: Int, + checkpointInterval: Int): Unit = { + try { + if (gIndex % checkpointInterval == 0) { + // We allow 2 checkpoint intervals since we perform an action (checkpointing a second graph) + // only AFTER PeriodicGraphCheckpointer decides whether to remove the previous checkpoint. + if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) { + assert(graph.isCheckpointed, "Graph should be checkpointed") + assert(graph.getCheckpointFiles.length == 2, "Graph should have 2 checkpoint files") + } else { + confirmCheckpointRemoved(graph) + } + } else { + // Graph should never be checkpointed + assert(!graph.isCheckpointed, "Graph should never have been checkpointed") + assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files") + } + } catch { + case e: AssertionError => + throw new Exception(s"PeriodicGraphCheckpointerSuite.checkCheckpoint failed with:\n" + + s"\t gIndex = $gIndex\n" + + s"\t iteration = $iteration\n" + + s"\t checkpointInterval = $checkpointInterval\n" + + s"\t graph.isCheckpointed = ${graph.isCheckpointed}\n" + + s"\t graph.getCheckpointFiles = ${graph.getCheckpointFiles.mkString(", ")}\n" + + s" AssertionError message: ${e.getMessage}") + } + } + +}