diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
new file mode 100644
index 000000000000..f4c545ad70e9
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import java.text.BreakIterator
+
+import scala.collection.mutable
+
+import scopt.OptionParser
+
+import org.apache.log4j.{Level, Logger}
+
+import org.apache.spark.{SparkContext, SparkConf}
+import org.apache.spark.mllib.clustering.LDA
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * An example Latent Dirichlet Allocation (LDA) app. Run with
+ * {{{
+ * ./bin/run-example mllib.LDAExample [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object LDAExample {
+
+ private case class Params(
+ input: Seq[String] = Seq.empty,
+ k: Int = 20,
+ maxIterations: Int = 10,
+ docConcentration: Double = -1,
+ topicConcentration: Double = -1,
+ vocabSize: Int = 10000,
+ stopwordFile: String = "",
+ checkpointDir: Option[String] = None,
+ checkpointInterval: Int = 10) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("LDAExample") {
+ head("LDAExample: an example LDA app for plain text data.")
+ opt[Int]("k")
+ .text(s"number of topics. default: ${defaultParams.k}")
+ .action((x, c) => c.copy(k = x))
+ opt[Int]("maxIterations")
+ .text(s"number of iterations of learning. default: ${defaultParams.maxIterations}")
+ .action((x, c) => c.copy(maxIterations = x))
+ opt[Double]("docConcentration")
+ .text(s"amount of topic smoothing to use (> 1.0) (-1=auto)." +
+ s" default: ${defaultParams.docConcentration}")
+ .action((x, c) => c.copy(docConcentration = x))
+ opt[Double]("topicConcentration")
+ .text(s"amount of term (word) smoothing to use (> 1.0) (-1=auto)." +
+ s" default: ${defaultParams.topicConcentration}")
+ .action((x, c) => c.copy(topicConcentration = x))
+ opt[Int]("vocabSize")
+ .text(s"number of distinct word types to use, chosen by frequency. (-1=all)" +
+ s" default: ${defaultParams.vocabSize}")
+ .action((x, c) => c.copy(vocabSize = x))
+ opt[String]("stopwordFile")
+ .text(s"filepath for a list of stopwords. Note: This must fit on a single machine." +
+ s" default: ${defaultParams.stopwordFile}")
+ .action((x, c) => c.copy(stopwordFile = x))
+ opt[String]("checkpointDir")
+ .text(s"Directory for checkpointing intermediate results." +
+ s" Checkpointing helps with recovery and eliminates temporary shuffle files on disk." +
+ s" default: ${defaultParams.checkpointDir}")
+ .action((x, c) => c.copy(checkpointDir = Some(x)))
+ opt[Int]("checkpointInterval")
+ .text(s"Iterations between each checkpoint. Only used if checkpointDir is set." +
+ s" default: ${defaultParams.checkpointInterval}")
+ .action((x, c) => c.copy(checkpointInterval = x))
+ arg[String]("<input>...")
+ .text("input paths (directories) to plain text corpora." +
+ " Each text file line should hold 1 document.")
+ .unbounded()
+ .required()
+ .action((x, c) => c.copy(input = c.input :+ x))
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ parser.showUsageAsError
+ sys.exit(1)
+ }
+ }
+
+ private def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"LDAExample with $params")
+ val sc = new SparkContext(conf)
+
+ Logger.getRootLogger.setLevel(Level.WARN)
+
+ // Load documents, and prepare them for LDA.
+ val preprocessStart = System.nanoTime()
+ val (corpus, vocabArray, actualNumTokens) =
+ preprocess(sc, params.input, params.vocabSize, params.stopwordFile)
+ corpus.cache()
+ val actualCorpusSize = corpus.count()
+ val actualVocabSize = vocabArray.size
+ val preprocessElapsed = (System.nanoTime() - preprocessStart) / 1e9
+
+ println()
+ println(s"Corpus summary:")
+ println(s"\t Training set size: $actualCorpusSize documents")
+ println(s"\t Vocabulary size: $actualVocabSize terms")
+ println(s"\t Training set size: $actualNumTokens tokens")
+ println(s"\t Preprocessing time: $preprocessElapsed sec")
+ println()
+
+ // Run LDA.
+ val lda = new LDA()
+ lda.setK(params.k)
+ .setMaxIterations(params.maxIterations)
+ .setDocConcentration(params.docConcentration)
+ .setTopicConcentration(params.topicConcentration)
+ .setCheckpointInterval(params.checkpointInterval)
+ if (params.checkpointDir.nonEmpty) {
+ lda.setCheckpointDir(params.checkpointDir.get)
+ }
+ val startTime = System.nanoTime()
+ val ldaModel = lda.run(corpus)
+ val elapsed = (System.nanoTime() - startTime) / 1e9
+
+ println(s"Finished training LDA model. Summary:")
+ println(s"\t Training time: $elapsed sec")
+ val avgLogLikelihood = ldaModel.logLikelihood / actualCorpusSize.toDouble
+ println(s"\t Training data average log likelihood: $avgLogLikelihood")
+ println()
+
+ // Print the topics, showing the top-weighted terms for each topic.
+ val topicIndices = ldaModel.describeTopics(maxTermsPerTopic = 10)
+ val topics = topicIndices.map { case (terms, termWeights) =>
+ terms.zip(termWeights).map { case (term, weight) => (vocabArray(term.toInt), weight) }
+ }
+ println(s"${params.k} topics:")
+ topics.zipWithIndex.foreach { case (topic, i) =>
+ println(s"TOPIC $i")
+ topic.foreach { case (term, weight) =>
+ println(s"$term\t$weight")
+ }
+ println()
+ }
+
+ }
+
+ /**
+ * Load documents, tokenize them, create vocabulary, and prepare documents as term count vectors.
+ * @return (corpus, vocabulary as array, total token count in corpus)
+ */
+ private def preprocess(
+ sc: SparkContext,
+ paths: Seq[String],
+ vocabSize: Int,
+ stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = {
+
+ // Get dataset of document texts
+ // One document per line in each text file.
+ val textRDD: RDD[String] = sc.textFile(paths.mkString(","))
+
+ // Split text into words
+ val tokenizer = new SimpleTokenizer(sc, stopwordFile)
+ val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) =>
+ id -> tokenizer.getWords(text)
+ }
+ tokenized.cache()
+
+ // Counts words: RDD[(word, wordCount)]
+ val wordCounts: RDD[(String, Long)] = tokenized
+ .flatMap { case (_, tokens) => tokens.map(_ -> 1L) }
+ .reduceByKey(_ + _)
+ wordCounts.cache()
+ val fullVocabSize = wordCounts.count()
+ // Select vocab
+ // (vocab: Map[word -> id], total tokens after selecting vocab)
+ val (vocab: Map[String, Int], selectedTokenCount: Long) = {
+ val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) {
+ // Use all terms
+ wordCounts.collect().sortBy(-_._2)
+ } else {
+ // Sort terms to select vocab
+ wordCounts.sortBy(_._2, ascending = false).take(vocabSize)
+ }
+ (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum)
+ }
+
+ val documents = tokenized.map { case (id, tokens) =>
+ // Filter tokens by vocabulary, and create word count vector representation of document.
+ val wc = new mutable.HashMap[Int, Int]()
+ tokens.foreach { term =>
+ if (vocab.contains(term)) {
+ val termIndex = vocab(term)
+ wc(termIndex) = wc.getOrElse(termIndex, 0) + 1
+ }
+ }
+ val indices = wc.keys.toArray.sorted
+ val values = indices.map(i => wc(i).toDouble)
+
+ val sb = Vectors.sparse(vocab.size, indices, values)
+ (id, sb)
+ }
+
+ val vocabArray = new Array[String](vocab.size)
+ vocab.foreach { case (term, i) => vocabArray(i) = term }
+
+ (documents, vocabArray, selectedTokenCount)
+ }
+}
+
+/**
+ * Simple Tokenizer.
+ *
+ * TODO: Formalize the interface, and make this a public class in mllib.feature
+ */
+private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable {
+
+ private val stopwords: Set[String] = if (stopwordFile.isEmpty) {
+ Set.empty[String]
+ } else {
+ val stopwordText = sc.textFile(stopwordFile).collect()
+ stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet
+ }
+
+ // Matches sequences of Unicode letters
+ private val allWordRegex = "^(\\p{L}*)$".r
+
+ // Ignore words shorter than this length.
+ private val minWordLength = 3
+
+ def getWords(text: String): IndexedSeq[String] = {
+
+ val words = new mutable.ArrayBuffer[String]()
+
+ // Use Java BreakIterator to tokenize text into words.
+ val wb = BreakIterator.getWordInstance
+ wb.setText(text)
+
+ // current and end give the start and end index of the current word
+ var current = wb.first()
+ var end = wb.next()
+ while (end != BreakIterator.DONE) {
+ // Convert to lowercase
+ val word: String = text.substring(current, end).toLowerCase
+ // Remove short words and strings that aren't only letters
+ word match {
+ case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) =>
+ words += w
+ case _ =>
+ }
+
+ current = end
+ try {
+ end = wb.next()
+ } catch {
+ case e: Exception =>
+ // Ignore remaining text in line.
+ // This is a known bug in BreakIterator (for some Java versions),
+ // which fails when it sees certain characters.
+ end = BreakIterator.DONE
+ }
+ }
+ words
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
new file mode 100644
index 000000000000..d8f82867a09d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -0,0 +1,519 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import java.util.Random
+
+import breeze.linalg.{DenseVector => BDV, normalize, axpy => brzAxpy}
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaPairRDD
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.impl.GraphImpl
+import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
+
+
+/**
+ * :: Experimental ::
+ *
+ * Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
+ *
+ * Terminology:
+ * - "word" = "term": an element of the vocabulary
+ * - "token": instance of a term appearing in a document
+ * - "topic": multinomial distribution over words representing some concept
+ *
+ * Currently, the underlying implementation uses Expectation-Maximization (EM), implemented
+ * according to the Asuncion et al. (2009) paper referenced below.
+ *
+ * References:
+ * - Original LDA paper (journal version):
+ * Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
+ * - This class implements their "smoothed" LDA model.
+ * - Paper which clearly explains several algorithms, including EM:
+ * Asuncion, Welling, Smyth, and Teh.
+ * "On Smoothing and Inference for Topic Models." UAI, 2009.
+ */
+@Experimental
+class LDA private (
+ private var k: Int,
+ private var maxIterations: Int,
+ private var docConcentration: Double,
+ private var topicConcentration: Double,
+ private var seed: Long,
+ private var checkpointDir: Option[String],
+ private var checkpointInterval: Int) extends Logging {
+
+ def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1,
+ seed = Utils.random.nextLong(), checkpointDir = None, checkpointInterval = 10)
+
+ /**
+ * Number of topics to infer. I.e., the number of soft cluster centers.
+ */
+ def getK: Int = k
+
+ /**
+ * Number of topics to infer. I.e., the number of soft cluster centers.
+ * (default = 10)
+ */
+ def setK(k: Int): this.type = {
+ require(k > 0, s"LDA k (number of clusters) must be > 0, but was set to $k")
+ this.k = k
+ this
+ }
+
+ /**
+ * Concentration parameter (commonly named "alpha") for the prior placed on documents'
+ * distributions over topics ("theta").
+ *
+ * This is the parameter to a symmetric Dirichlet distribution.
+ */
+ def getDocConcentration: Double = {
+ if (this.docConcentration == -1) {
+ (50.0 / k) + 1.0
+ } else {
+ this.docConcentration
+ }
+ }
+
+ /**
+ * Concentration parameter (commonly named "alpha") for the prior placed on documents'
+ * distributions over topics ("theta").
+ *
+ * This is the parameter to a symmetric Dirichlet distribution.
+ *
+ * This value should be > 1.0, where larger values mean more smoothing (more regularization).
+ * If set to -1, then docConcentration is set automatically.
+ * (default = -1 = automatic)
+ *
+ * Automatic setting of parameter:
+ * - For EM: default = (50 / k) + 1.
+ * - The 50/k is common in LDA libraries.
+ * - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM.
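+ * (E.g., with the default k = 10, the automatic setting gives docConcentration = 50/10 + 1 = 6.)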
+ *
+ * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions),
+ * but values in (0,1) are not yet supported.
+ */
+ def setDocConcentration(docConcentration: Double): this.type = {
+ require(docConcentration > 1.0 || docConcentration == -1.0,
+ s"LDA docConcentration must be > 1.0 (or -1 for auto), but was set to $docConcentration")
+ this.docConcentration = docConcentration
+ this
+ }
+
+ /** Alias for [[getDocConcentration]] */
+ def getAlpha: Double = getDocConcentration
+
+ /** Alias for [[setDocConcentration()]] */
+ def setAlpha(alpha: Double): this.type = setDocConcentration(alpha)
+
+ /**
+ * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics'
+ * distributions over terms.
+ *
+ * This is the parameter to a symmetric Dirichlet distribution.
+ *
+ * Note: The topics' distributions over terms are called "beta" in the original LDA paper
+ * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
+ */
+ def getTopicConcentration: Double = {
+ if (this.topicConcentration == -1) {
+ 1.1
+ } else {
+ this.topicConcentration
+ }
+ }
+
+ /**
+ * Concentration parameter (commonly named "beta" or "eta") for the prior placed on topics'
+ * distributions over terms.
+ *
+ * This is the parameter to a symmetric Dirichlet distribution.
+ *
+ * Note: The topics' distributions over terms are called "beta" in the original LDA paper
+ * by Blei et al., but are called "phi" in many later papers such as Asuncion et al., 2009.
+ *
+ * This value should be > 1.0, where larger values mean more smoothing (more regularization).
+ * If set to -1, then topicConcentration is set automatically.
+ * (default = -1 = automatic)
+ *
+ * Automatic setting of parameter:
+ * - For EM: default = 0.1 + 1.
+ * - The 0.1 gives a small amount of smoothing.
+ * - The +1 follows Asuncion et al. (2009), who recommend a +1 adjustment for EM.
+ *
+ * Note: The restriction > 1.0 may be relaxed in the future (allowing sparse solutions),
+ * but values in (0,1) are not yet supported.
+ */
+ def setTopicConcentration(topicConcentration: Double): this.type = {
+ require(topicConcentration > 1.0 || topicConcentration == -1.0,
+ s"LDA topicConcentration must be > 1.0 (or -1 for auto), but was set to $topicConcentration")
+ this.topicConcentration = topicConcentration
+ this
+ }
+
+ /** Alias for [[getTopicConcentration]] */
+ def getBeta: Double = getTopicConcentration
+
+ /** Alias for [[setTopicConcentration()]] */
+ def setBeta(beta: Double): this.type = setTopicConcentration(beta)
+
+ /**
+ * Maximum number of iterations for learning.
+ */
+ def getMaxIterations: Int = maxIterations
+
+ /**
+ * Maximum number of iterations for learning.
+ * (default = 20)
+ */
+ def setMaxIterations(maxIterations: Int): this.type = {
+ this.maxIterations = maxIterations
+ this
+ }
+
+ /** Random seed */
+ def getSeed: Long = seed
+
+ /** Random seed */
+ def setSeed(seed: Long): this.type = {
+ this.seed = seed
+ this
+ }
+
+ /**
+ * Directory for storing checkpoint files during learning.
+ * This is not necessary, but checkpointing helps with recovery (when nodes fail).
+ * It also helps with eliminating temporary shuffle files on disk, which can be important when
+ * LDA is run for many iterations.
+ */
+ def getCheckpointDir: Option[String] = checkpointDir
+
+ /**
+ * Directory for storing checkpoint files during learning.
+ * This is not necessary, but checkpointing helps with recovery (when nodes fail).
+ * It also helps with eliminating temporary shuffle files on disk, which can be important when
+ * LDA is run for many iterations.
+ *
+ * NOTE: If the [[org.apache.spark.SparkContext.checkpointDir]] is already set, then the value
+ * given to LDA is ignored, and the existing directory is kept.
+ *
+ * (default = None)
+ */
+ def setCheckpointDir(checkpointDir: String): this.type = {
+ this.checkpointDir = Some(checkpointDir)
+ this
+ }
+
+ /**
+ * Clear the directory for storing checkpoint files during learning.
+ * If one is already set in the [[org.apache.spark.SparkContext]], then checkpointing will still
+ * occur; otherwise, no checkpointing will be used.
+ */
+ def clearCheckpointDir(): this.type = {
+ this.checkpointDir = None
+ this
+ }
+
+ /**
+ * Period (in iterations) between checkpoints.
+ * @see [[getCheckpointDir]]
+ */
+ def getCheckpointInterval: Int = checkpointInterval
+
+ /**
+ * Period (in iterations) between checkpoints.
+ * (default = 10)
+ * @see [[getCheckpointDir]]
+ */
+ def setCheckpointInterval(checkpointInterval: Int): this.type = {
+ this.checkpointInterval = checkpointInterval
+ this
+ }
+
+ /**
+ * Learn an LDA model using the given dataset.
+ *
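+ * Example of the expected input (a minimal sketch; `sc` is a SparkContext, the vocabulary
+ * here has 3 terms, and the vector values are term counts):
+ * {{{
+ *   val documents: RDD[(Long, Vector)] = sc.parallelize(Seq(
+ *     (0L, Vectors.dense(1.0, 0.0, 3.0)),
+ *     (1L, Vectors.sparse(3, Array(1), Array(2.0)))))
+ * }}}
+ *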
+ * @param documents RDD of documents, which are term (word) count vectors paired with IDs.
+ * The term count vectors are "bags of words" with a fixed-size vocabulary
+ * (where the vocabulary size is the length of the vector).
+ * Document IDs must be unique and >= 0.
+ * @return Inferred LDA model
+ */
+ def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
+ val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
+ checkpointDir, checkpointInterval)
+ var iter = 0
+ val iterationTimes = Array.fill[Double](maxIterations)(0)
+ while (iter < maxIterations) {
+ val start = System.nanoTime()
+ state.next()
+ val elapsedSeconds = (System.nanoTime() - start) / 1e9
+ iterationTimes(iter) = elapsedSeconds
+ iter += 1
+ }
+ state.graphCheckpointer.deleteAllCheckpoints()
+ new DistributedLDAModel(state, iterationTimes)
+ }
+
+ /** Java-friendly version of [[run()]] */
+ def run(documents: JavaPairRDD[java.lang.Long, Vector]): DistributedLDAModel = {
+ run(documents.rdd.asInstanceOf[RDD[(Long, Vector)]])
+ }
+}
+
+
+private[clustering] object LDA {
+
+ /*
+ DEVELOPERS NOTE:
+
+ This implementation uses GraphX, where the graph is bipartite with 2 types of vertices:
+ - Document vertices
+ - indexed with unique indices >= 0
+ - Store vectors of length k (# topics).
+ - Term vertices
+ - indexed {-1, -2, ..., -vocabSize}
+ - Store vectors of length k (# topics).
+ - Edges correspond to terms appearing in documents.
+ - Edges are directed Document -> Term.
+ - Edges are partitioned by documents.
+
+ Info on EM implementation.
+ - We follow Section 2.2 from Asuncion et al., 2009. We use some of their notation.
+ - In this implementation, there is one edge for every unique term appearing in a document,
+ i.e., for every unique (document, term) pair.
+ - Notation:
+ - N_{wkj} = count of tokens of term w currently assigned to topic k in document j
+ - N_{*} where * is missing a subscript w/k/j is the count summed over missing subscript(s)
+ - gamma_{wjk} = P(z_i = k | x_i = w, d_i = j),
+ the probability of term x_i in document d_i having topic z_i.
+ - Data graph
+ - Document vertices store N_{kj}
+ - Term vertices store N_{wk}
+ - Edges store N_{wj}.
+ - Global data N_k
+ - Algorithm
+ - Initial state:
+ - Document and term vertices store random counts N_{wk}, N_{kj}.
+ - E-step: For each (document,term) pair i, compute P(z_i | x_i, d_i).
+ - Aggregate N_k from term vertices.
+ - Compute gamma_{wjk} for each possible topic k, from each triplet.
+ using inputs N_{wk}, N_{kj}, N_k.
+ - M-step: Compute sufficient statistics for hidden parameters phi and theta
+ (counts N_{wk}, N_{kj}, N_k).
+ - Document update:
+ - N_{kj} <- sum_w N_{wj} gamma_{wjk}
+ - N_j <- sum_k N_{kj} (only needed to output predictions)
+ - Term update:
+ - N_{wk} <- sum_j N_{wj} gamma_{wjk}
+ - N_k <- sum_w N_{wk}
+
+ TODO: Add simplex constraints to allow alpha in (0,1).
+ See: Vorontsov and Potapenko. "Tutorial on Probabilistic Topic Modeling : Additive
+ Regularization for Stochastic Matrix Factorization." 2014.
+ */
+
+ /**
+ * Vector over topics (length k) of token counts.
+ * The meaning of these counts can vary, and the vector may or may not be normalized to a distribution.
+ */
+ type TopicCounts = BDV[Double]
+
+ type TokenCount = Double
+
+ /** Term vertex IDs are {-1, -2, ..., -vocabSize} */
+ def term2index(term: Int): Long = -(1 + term.toLong)
+
+ def index2term(termIndex: Long): Int = -(1 + termIndex).toInt
+
+ def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0
+
+ def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0
+
+ /**
+ * Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters.
+ *
+ * @param graph EM graph, storing current parameter estimates in vertex descriptors and
+ * data (token counts) in edge descriptors.
+ * @param k Number of topics
+ * @param vocabSize Number of unique terms
+ * @param docConcentration "alpha"
+ * @param topicConcentration "beta" or "eta"
+ */
+ class EMOptimizer(
+ var graph: Graph[TopicCounts, TokenCount],
+ val k: Int,
+ val vocabSize: Int,
+ val docConcentration: Double,
+ val topicConcentration: Double,
+ checkpointDir: Option[String],
+ checkpointInterval: Int) {
+
+ private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
+ graph, checkpointDir, checkpointInterval)
+
+ def next(): EMOptimizer = {
+ val eta = topicConcentration
+ val W = vocabSize
+ val alpha = docConcentration
+
+ val N_k = globalTopicTotals
+ val sendMsg: EdgeContext[TopicCounts, TokenCount, (Boolean, TopicCounts)] => Unit =
+ (edgeContext) => {
+ // Compute N_{wj} gamma_{wjk}
+ val N_wj = edgeContext.attr
+ // E-STEP: Compute gamma_{wjk} (smoothed topic distributions), scaled by token count
+ // N_{wj}.
+ val scaledTopicDistribution: TopicCounts =
+ computePTopic(edgeContext.srcAttr, edgeContext.dstAttr, N_k, W, eta, alpha) *= N_wj
+ edgeContext.sendToDst((false, scaledTopicDistribution))
+ edgeContext.sendToSrc((false, scaledTopicDistribution))
+ }
+ // This is a hack to detect whether we could modify the values in-place.
+ // TODO: Add zero/seqOp/combOp option to aggregateMessages. (SPARK-5438)
+ val mergeMsg: ((Boolean, TopicCounts), (Boolean, TopicCounts)) => (Boolean, TopicCounts) =
+ (m0, m1) => {
+ val sum =
+ if (m0._1) {
+ m0._2 += m1._2
+ } else if (m1._1) {
+ m1._2 += m0._2
+ } else {
+ m0._2 + m1._2
+ }
+ (true, sum)
+ }
+ // M-STEP: Aggregation computes new N_{kj}, N_{wk} counts.
+ val docTopicDistributions: VertexRDD[TopicCounts] =
+ graph.aggregateMessages[(Boolean, TopicCounts)](sendMsg, mergeMsg)
+ .mapValues(_._2)
+ // Update the vertex descriptors with the new counts.
+ val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
+ graph = newGraph
+ graphCheckpointer.updateGraph(newGraph)
+ globalTopicTotals = computeGlobalTopicTotals()
+ this
+ }
+
+ /**
+ * Aggregate distributions over topics from all term vertices.
+ *
+ * Note: This executes an action on the graph RDDs.
+ */
+ var globalTopicTotals: TopicCounts = computeGlobalTopicTotals()
+
+ private def computeGlobalTopicTotals(): TopicCounts = {
+ val numTopics = k
+ graph.vertices.filter(isTermVertex).values.fold(BDV.zeros[Double](numTopics))(_ += _)
+ }
+
+ }
+
+ /**
+ * Compute gamma_{wjk}, a distribution over topics k.
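+ *
+ * Following the E-step of Asuncion et al. (2009), the unnormalized weight for topic k is
+ *   (N_{wk} + eta - 1) * (N_{kj} + alpha - 1) / (N_k + vocabSize * (eta - 1)),
+ * and the result is normalized to sum to 1 over topics.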
+ */
+ private def computePTopic(
+ docTopicCounts: TopicCounts,
+ termTopicCounts: TopicCounts,
+ totalTopicCounts: TopicCounts,
+ vocabSize: Int,
+ eta: Double,
+ alpha: Double): TopicCounts = {
+ val K = docTopicCounts.length
+ val N_j = docTopicCounts.data
+ val N_w = termTopicCounts.data
+ val N = totalTopicCounts.data
+ val eta1 = eta - 1.0
+ val alpha1 = alpha - 1.0
+ val Weta1 = vocabSize * eta1
+ var sum = 0.0
+ val gamma_wj = new Array[Double](K)
+ var k = 0
+ while (k < K) {
+ val gamma_wjk = (N_w(k) + eta1) * (N_j(k) + alpha1) / (N(k) + Weta1)
+ gamma_wj(k) = gamma_wjk
+ sum += gamma_wjk
+ k += 1
+ }
+ // normalize
+ BDV(gamma_wj) /= sum
+ }
+
+ /**
+ * Compute bipartite term/doc graph.
+ */
+ private def initialState(
+ docs: RDD[(Long, Vector)],
+ k: Int,
+ docConcentration: Double,
+ topicConcentration: Double,
+ randomSeed: Long,
+ checkpointDir: Option[String],
+ checkpointInterval: Int): EMOptimizer = {
+ // For each document, create an edge (Document -> Term) for each unique term in the document.
+ val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) =>
+ // Add edges for terms with non-zero counts.
+ termCounts.toBreeze.activeIterator.filter(_._2 != 0.0).map { case (term, cnt) =>
+ Edge(docID, term2index(term), cnt)
+ }
+ }
+
+ val vocabSize = docs.take(1).head._2.size
+
+ // Create vertices.
+ // Initially, we use random soft assignments of tokens to topics (random gamma).
+ val edgesWithGamma: RDD[(Edge[TokenCount], TopicCounts)] =
+ edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
+ val random = new Random(partIndex + randomSeed)
+ partEdges.map { edge =>
+ // Create a random gamma_{wjk}
+ (edge, normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0))
+ }
+ }
+ def createVertices(sendToWhere: Edge[TokenCount] => VertexId): RDD[(VertexId, TopicCounts)] = {
+ val verticesTMP: RDD[(VertexId, (TokenCount, TopicCounts))] =
+ edgesWithGamma.map { case (edge, gamma: TopicCounts) =>
+ (sendToWhere(edge), (edge.attr, gamma))
+ }
+ verticesTMP.aggregateByKey(BDV.zeros[Double](k))(
+ (sum, t) => {
+ brzAxpy(t._1, t._2, sum)
+ sum
+ },
+ (sum0, sum1) => {
+ sum0 += sum1
+ }
+ )
+ }
+ val docVertices = createVertices(_.srcId)
+ val termVertices = createVertices(_.dstId)
+
+ // Partition such that edges are grouped by document
+ val graph = Graph(docVertices ++ termVertices, edges)
+ .partitionBy(PartitionStrategy.EdgePartition1D)
+
+ new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointDir,
+ checkpointInterval)
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
new file mode 100644
index 000000000000..19e8aab6eabd
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -0,0 +1,351 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
+import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.BoundedPriorityQueue
+
+/**
+ * :: Experimental ::
+ *
+ * Latent Dirichlet Allocation (LDA) model.
+ *
+ * This abstraction allows for different underlying representations,
+ * including local and distributed data structures.
+ */
+@Experimental
+abstract class LDAModel private[clustering] {
+
+ /** Number of topics */
+ def k: Int
+
+ /** Vocabulary size (number of terms in the vocabulary) */
+ def vocabSize: Int
+
+ /**
+ * Inferred topics, where each topic is represented by a distribution over terms.
+ * This is a matrix of size vocabSize x k, where each column is a topic.
+ * No guarantees are given about the ordering of the topics.
+ */
+ def topicsMatrix: Matrix
+
+ /**
+ * Return the topics described by weighted terms.
+ *
+ * This limits the number of terms per topic.
+ * This is approximate; it may not return exactly the top-weighted terms for each topic.
+ * To get a more precise set of top terms, increase maxTermsPerTopic.
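+ *
+ * For example (a minimal sketch, assuming `model` is a trained [[LDAModel]] and `vocabArray`
+ * is a hypothetical Array[String] mapping term indices back to terms):
+ * {{{
+ *   model.describeTopics(maxTermsPerTopic = 3).foreach { case (terms, weights) =>
+ *     println(terms.zip(weights).map { case (t, w) => s"${vocabArray(t)}: $w" }.mkString(", "))
+ *   }
+ * }}}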
+ *
+ * @param maxTermsPerTopic Maximum number of terms to collect for each topic.
+ * @return Array over topics. Each topic is represented as a pair of matching arrays:
+ * (term indices, term weights in topic).
+ * Each topic's terms are sorted in order of decreasing weight.
+ */
+ def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])]
+
+ /**
+ * Return the topics described by weighted terms.
+ *
+ * WARNING: If vocabSize and k are large, this can return a large object!
+ *
+ * @return Array over topics. Each topic is represented as a pair of matching arrays:
+ * (term indices, term weights in topic).
+ * Each topic's terms are sorted in order of decreasing weight.
+ */
+ def describeTopics(): Array[(Array[Int], Array[Double])] = describeTopics(vocabSize)
+
+ /* TODO (once LDA can be trained with Strings or given a dictionary)
+ * Return the topics described by weighted terms.
+ *
+ * This is similar to [[describeTopics()]] but returns String values for terms.
+ * If this model was trained using Strings or was given a dictionary, then this method returns
+ * terms as text. Otherwise, this method returns terms as term indices.
+ *
+ * This limits the number of terms per topic.
+ * This is approximate; it may not return exactly the top-weighted terms for each topic.
+ * To get a more precise set of top terms, increase maxTermsPerTopic.
+ *
+ * @param maxTermsPerTopic Maximum number of terms to collect for each topic.
+ * @return Array over topics. Each topic is represented as a pair of matching arrays:
+ * (terms, term weights in topic) where terms are either the actual term text
+ * (if available) or the term indices.
+ * Each topic's terms are sorted in order of decreasing weight.
+ */
+ // def describeTopicsAsStrings(maxTermsPerTopic: Int): Array[(Array[Double], Array[String])]
+
+ /* TODO (once LDA can be trained with Strings or given a dictionary)
+ * Return the topics described by weighted terms.
+ *
+ * This is similar to [[describeTopics()]] but returns String values for terms.
+ * If this model was trained using Strings or was given a dictionary, then this method returns
+ * terms as text. Otherwise, this method returns terms as term indices.
+ *
+ * WARNING: If vocabSize and k are large, this can return a large object!
+ *
+ * @return Array over topics. Each topic is represented as a pair of matching arrays:
+ * (terms, term weights in topic) where terms are either the actual term text
+ * (if available) or the term indices.
+ * Each topic's terms are sorted in order of decreasing weight.
+ */
+ // def describeTopicsAsStrings(): Array[(Array[Double], Array[String])] =
+ // describeTopicsAsStrings(vocabSize)
+
+ /* TODO
+ * Compute the log likelihood of the observed tokens, given the current parameter estimates:
+ * log P(docs | topics, topic distributions for docs, alpha, eta)
+ *
+ * Note:
+ * - This excludes the prior.
+ * - Even with the prior, this is NOT the same as the data log likelihood given the
+ * hyperparameters.
+ *
+ * @param documents RDD of documents, which are term (word) count vectors paired with IDs.
+ * The term count vectors are "bags of words" with a fixed-size vocabulary
+ * (where the vocabulary size is the length of the vector).
+ * This must use the same vocabulary (ordering of term counts) as in training.
+ * Document IDs must be unique and >= 0.
+ * @return Estimated log likelihood of the data under this model
+ */
+ // def logLikelihood(documents: RDD[(Long, Vector)]): Double
+
+ /* TODO
+ * Compute the estimated topic distribution for each document.
+ * This is often called "theta" in the literature.
+ *
+ * @param documents RDD of documents, which are term (word) count vectors paired with IDs.
+ * The term count vectors are "bags of words" with a fixed-size vocabulary
+ * (where the vocabulary size is the length of the vector).
+ * This must use the same vocabulary (ordering of term counts) as in training.
+ * Document IDs must be unique and >= 0.
+ * @return Estimated topic distribution for each document.
+ * The returned RDD may be zipped with the given RDD, where each returned vector
+ * is a multinomial distribution over topics.
+ */
+ // def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)]
+
+}
+
+/**
+ * :: Experimental ::
+ *
+ * Local LDA model.
+ * This model stores only the inferred topics.
+ * It may be used for computing topics for new documents, but it may give less accurate answers
+ * than the [[DistributedLDAModel]].
+ *
+ * @param topics Inferred topics (vocabSize x k matrix).
+ */
+@Experimental
+class LocalLDAModel private[clustering] (
+ private val topics: Matrix) extends LDAModel with Serializable {
+
+ override def k: Int = topics.numCols
+
+ override def vocabSize: Int = topics.numRows
+
+ override def topicsMatrix: Matrix = topics
+
+ override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = {
+ val brzTopics = topics.toBreeze.toDenseMatrix
+ Range(0, k).map { topicIndex =>
+ val topic = normalize(brzTopics(::, topicIndex), 1.0)
+ val (termWeights, terms) =
+ topic.toArray.zipWithIndex.sortBy(-_._1).take(maxTermsPerTopic).unzip
+ (terms.toArray, termWeights.toArray)
+ }.toArray
+ }
+
+ // TODO
+ // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
+
+ // TODO:
+ // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
+
+}
+
+/**
+ * :: Experimental ::
+ *
+ * Distributed LDA model.
+ * This model stores the inferred topics, the full training dataset, and the topic distributions.
+ * When computing topics for new documents, it may give more accurate answers
+ * than the [[LocalLDAModel]].
+ */
+@Experimental
+class DistributedLDAModel private (
+ private val graph: Graph[LDA.TopicCounts, LDA.TokenCount],
+ private val globalTopicTotals: LDA.TopicCounts,
+ val k: Int,
+ val vocabSize: Int,
+ private val docConcentration: Double,
+ private val topicConcentration: Double,
+ private[spark] val iterationTimes: Array[Double]) extends LDAModel {
+
+ import LDA._
+
+ private[clustering] def this(state: LDA.EMOptimizer, iterationTimes: Array[Double]) = {
+ this(state.graph, state.globalTopicTotals, state.k, state.vocabSize, state.docConcentration,
+ state.topicConcentration, iterationTimes)
+ }
+
+ /**
+ * Convert model to a local model.
+ * The local model stores the inferred topics but not the topic distributions for training
+ * documents.
+ */
+ def toLocal: LocalLDAModel = new LocalLDAModel(topicsMatrix)
+
+ /**
+ * Inferred topics, where each topic is represented by a distribution over terms.
+ * This is a matrix of size vocabSize x k, where each column is a topic.
+ * No guarantees are given about the ordering of the topics.
+ *
+ * WARNING: This matrix is collected from an RDD. Beware memory usage when vocabSize and k are large.
+ */
+ override lazy val topicsMatrix: Matrix = {
+ // Collect row-major topics
+ val termTopicCounts: Array[(Int, TopicCounts)] =
+ graph.vertices.filter(_._1 < 0).map { case (termIndex, cnts) =>
+ (index2term(termIndex), cnts)
+ }.collect()
+ // Convert to Matrix
+ val brzTopics = BDM.zeros[Double](vocabSize, k)
+ termTopicCounts.foreach { case (term, cnts) =>
+ var j = 0
+ while (j < k) {
+ brzTopics(term, j) = cnts(j)
+ j += 1
+ }
+ }
+ Matrices.fromBreeze(brzTopics)
+ }
+
+ override def describeTopics(maxTermsPerTopic: Int): Array[(Array[Int], Array[Double])] = {
+ val numTopics = k
+ // Note: N_k is not needed to find the top terms, but it is needed to normalize weights
+ // to a distribution over terms.
+ val N_k: TopicCounts = globalTopicTotals
+ val topicsInQueues: Array[BoundedPriorityQueue[(Double, Int)]] =
+ graph.vertices.filter(isTermVertex)
+ .mapPartitions { termVertices =>
+ // For this partition, collect the most common terms for each topic in queues:
+ // queues(topic) = queue of (term weight, term index).
+ // Term weights are N_{wk} / N_k.
+ val queues =
+ Array.fill(numTopics)(new BoundedPriorityQueue[(Double, Int)](maxTermsPerTopic))
+ for ((termId, n_wk) <- termVertices) {
+ var topic = 0
+ while (topic < numTopics) {
+ queues(topic) += (n_wk(topic) / N_k(topic) -> index2term(termId.toInt))
+ topic += 1
+ }
+ }
+ Iterator(queues)
+ }.reduce { (q1, q2) =>
+ q1.zip(q2).foreach { case (a, b) => a ++= b}
+ q1
+ }
+ topicsInQueues.map { q =>
+ val (termWeights, terms) = q.toArray.sortBy(-_._1).unzip
+ (terms.toArray, termWeights.toArray)
+ }
+ }
+
+ // TODO
+ // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
+
+ /**
+ * Log likelihood of the observed tokens in the training set,
+ * given the current parameter estimates:
+ * log P(docs | topics, topic distributions for docs, alpha, eta)
+ *
+ * Note:
+ * - This excludes the prior; for that, use [[logPrior]].
+ * - Even with [[logPrior]], this is NOT the same as the data log likelihood given the
+ * hyperparameters.
+ */
+ lazy val logLikelihood: Double = {
+ val eta = topicConcentration
+ val alpha = docConcentration
+ assert(eta > 1.0)
+ assert(alpha > 1.0)
+ val N_k = globalTopicTotals
+ val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0))
+ // Edges: Compute token log probability from phi_{wk}, theta_{kj}.
+ val sendMsg: EdgeContext[TopicCounts, TokenCount, Double] => Unit = (edgeContext) => {
+ val N_wj = edgeContext.attr
+ val smoothed_N_wk: TopicCounts = edgeContext.dstAttr + (eta - 1.0)
+ val smoothed_N_kj: TopicCounts = edgeContext.srcAttr + (alpha - 1.0)
+ val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k
+ val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0)
+ val tokenLogLikelihood = N_wj * math.log(phi_wk.dot(theta_kj))
+ edgeContext.sendToDst(tokenLogLikelihood)
+ }
+ graph.aggregateMessages[Double](sendMsg, _ + _)
+ .map(_._2).fold(0.0)(_ + _)
+ }
+
+ /**
+ * Log probability of the current parameter estimate:
+ * log P(topics, topic distributions for docs | alpha, eta)
+ */
+ lazy val logPrior: Double = {
+ val eta = topicConcentration
+ val alpha = docConcentration
+ // Term vertices: Compute phi_{wk}. Use to compute prior log probability.
+ // Doc vertex: Compute theta_{kj}. Use to compute prior log probability.
+ val N_k = globalTopicTotals
+ val smoothed_N_k: TopicCounts = N_k + (vocabSize * (eta - 1.0))
+ val seqOp: (Double, (VertexId, TopicCounts)) => Double = {
+ case (sumPrior: Double, vertex: (VertexId, TopicCounts)) =>
+ if (isTermVertex(vertex)) {
+ val N_wk = vertex._2
+ val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0)
+ val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k
+ (eta - 1.0) * brzSum(phi_wk.map(math.log))
+ } else {
+ val N_kj = vertex._2
+ val smoothed_N_kj: TopicCounts = N_kj + (alpha - 1.0)
+ val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0)
+ (alpha - 1.0) * brzSum(theta_kj.map(math.log))
+ }
+ }
+ graph.vertices.aggregate(0.0)(seqOp, _ + _)
+ }
+
+ /**
+ * For each document in the training set, return the distribution over topics for that document
+ * (i.e., "theta_doc").
+ *
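+ * A minimal usage sketch, assuming `model` is a trained [[DistributedLDAModel]]:
+ * {{{
+ *   model.topicDistributions.take(3).foreach { case (docID, theta) =>
+ *     println(s"doc $docID -> $theta")
+ *   }
+ * }}}
+ *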
+ * @return RDD of (document ID, topic distribution) pairs
+ */
+ def topicDistributions: RDD[(Long, Vector)] = {
+ graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) =>
+ (docID.toLong, Vectors.fromBreeze(normalize(topicCounts, 1.0)))
+ }
+ }
+
+ // TODO:
+ // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
new file mode 100644
index 000000000000..76672fe51e83
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.{Path, FileSystem}
+
+import org.apache.spark.Logging
+import org.apache.spark.graphx.Graph
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * This class helps with persisting and checkpointing Graphs.
+ * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
+ * unpersisting and removing checkpoint files.
+ *
+ * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created,
+ * before the graph has been materialized. After updating [[PeriodicGraphCheckpointer]], users are
+ * responsible for materializing the graph to ensure that persisting and checkpointing actually
+ * occur.
+ *
+ * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following:
+ * - Persist new graph (if not yet persisted), and put in queue of persisted graphs.
+ * - Unpersist graphs from queue until there are at most 3 persisted graphs.
+ * - If using checkpointing and the checkpoint interval has been reached,
+ * - Checkpoint the new graph, and put in a queue of checkpointed graphs.
+ * - Remove older checkpoints.
+ *
+ * WARNINGS:
+ * - This class should NOT be copied (since copies may conflict on which Graphs should be
+ * checkpointed).
+ * - This class removes checkpoint files once later graphs have been checkpointed.
+ * However, references to the older graphs will still return isCheckpointed = true.
+ *
+ * Example usage:
+ * {{{
+ * val (graph1, graph2, graph3, ...) = ...
+ * val cp = new PeriodicGraphCheckpointer(graph1, dir, 2)
+ * graph1.vertices.count(); graph1.edges.count()
+ * // persisted: graph1
+ * cp.updateGraph(graph2)
+ * graph2.vertices.count(); graph2.edges.count()
+ * // persisted: graph1, graph2
+ * // checkpointed: graph2
+ * cp.updateGraph(graph3)
+ * graph3.vertices.count(); graph3.edges.count()
+ * // persisted: graph1, graph2, graph3
+ * // checkpointed: graph2
+ * cp.updateGraph(graph4)
+ * graph4.vertices.count(); graph4.edges.count()
+ * // persisted: graph2, graph3, graph4
+ * // checkpointed: graph4
+ * cp.updateGraph(graph5)
+ * graph5.vertices.count(); graph5.edges.count()
+ * // persisted: graph3, graph4, graph5
+ * // checkpointed: graph4
+ * }}}
+ *
+ * @param currentGraph Initial graph
+ * @param checkpointDir The directory for storing checkpoint files
+ * @param checkpointInterval Graphs will be checkpointed at this interval
+ * @tparam VD Vertex descriptor type
+ * @tparam ED Edge descriptor type
+ *
+ * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib.
+ */
+private[mllib] class PeriodicGraphCheckpointer[VD, ED](
+ var currentGraph: Graph[VD, ED],
+ val checkpointDir: Option[String],
+ val checkpointInterval: Int) extends Logging {
+
+ /** FIFO queue of past checkpointed Graphs */
+ private val checkpointQueue = mutable.Queue[Graph[VD, ED]]()
+
+ /** FIFO queue of past persisted Graphs */
+ private val persistedQueue = mutable.Queue[Graph[VD, ED]]()
+
+ /** Number of times [[updateGraph()]] has been called */
+ private var updateCount = 0
+
+ /**
+ * Spark Context for the Graphs given to this checkpointer.
+ * NOTE: This code assumes that only one SparkContext is used for the given graphs.
+ */
+ private val sc = currentGraph.vertices.sparkContext
+
+ // If a checkpoint directory is given, and there's no prior checkpoint directory,
+ // then set the checkpoint directory with the given one.
+ if (checkpointDir.nonEmpty && sc.getCheckpointDir.isEmpty) {
+ sc.setCheckpointDir(checkpointDir.get)
+ }
+
+ updateGraph(currentGraph)
+
+ /**
+ * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed.
+ * Since this handles persistence and checkpointing, this should be called before the graph
+ * has been materialized.
+ *
+ * @param newGraph New graph created from previous graphs in the lineage.
+ */
+ def updateGraph(newGraph: Graph[VD, ED]): Unit = {
+ if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) {
+ newGraph.persist()
+ }
+ persistedQueue.enqueue(newGraph)
+ // We keep at most 3 Graphs in persistedQueue to support the semantics of this class:
+ // Users should call [[updateGraph()]] when a new graph has been created, before the graph
+ // has been materialized, so its recent predecessors must remain persisted until then.
+ while (persistedQueue.size > 3) {
+ val graphToUnpersist = persistedQueue.dequeue()
+ graphToUnpersist.unpersist(blocking = false)
+ }
+ updateCount += 1
+
+ // Handle checkpointing (after persisting)
+ if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
+ // Add new checkpoint before removing old checkpoints.
+ newGraph.checkpoint()
+ checkpointQueue.enqueue(newGraph)
+ // Remove checkpoints before the latest one.
+ var canDelete = true
+ while (checkpointQueue.size > 1 && canDelete) {
+ // Delete the oldest checkpoint only if the next checkpoint exists.
+ if (checkpointQueue.get(1).get.isCheckpointed) {
+ removeCheckpointFile()
+ } else {
+ canDelete = false
+ }
+ }
+ }
+ }
+
+ /**
+ * Call this at the end to delete any remaining checkpoint files.
+ */
+ def deleteAllCheckpoints(): Unit = {
+ while (checkpointQueue.size > 0) {
+ removeCheckpointFile()
+ }
+ }
+
+ /**
+ * Dequeue the oldest checkpointed Graph, and remove its checkpoint files.
+ * This prints a warning but does not fail if the files cannot be removed.
+ */
+ private def removeCheckpointFile(): Unit = {
+ val old = checkpointQueue.dequeue()
+ // Since the old checkpoint is not deleted by Spark, we manually delete it.
+ val fs = FileSystem.get(sc.hadoopConfiguration)
+ old.getCheckpointFiles.foreach { checkpointFile =>
+ try {
+ fs.delete(new Path(checkpointFile), true)
+ } catch {
+ case e: Exception =>
+ logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " +
+ checkpointFile)
+ }
+ }
+ }
+
+}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
new file mode 100644
index 000000000000..dc10aa67c7c1
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+
+import org.apache.spark.api.java.JavaRDD;
+import scala.Tuple2;
+
+import org.junit.After;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertArrayEquals;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.mllib.linalg.Vector;
+
+
+public class JavaLDASuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaLDA");
+ ArrayList<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<Tuple2<Long, Vector>>();
+ for (int i = 0; i < LDASuite$.MODULE$.tinyCorpus().length; i++) {
+ tinyCorpus.add(new Tuple2<Long, Vector>((Long)LDASuite$.MODULE$.tinyCorpus()[i]._1(),
+ LDASuite$.MODULE$.tinyCorpus()[i]._2()));
+ }
+ JavaRDD<Tuple2<Long, Vector>> tmpCorpus = sc.parallelize(tinyCorpus, 2);
+ corpus = JavaPairRDD.fromJavaRDD(tmpCorpus);
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void localLDAModel() {
+ LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics());
+
+ // Check: basic parameters
+ assertEquals(model.k(), tinyK);
+ assertEquals(model.vocabSize(), tinyVocabSize);
+ assertEquals(model.topicsMatrix(), tinyTopics);
+
+ // Check: describeTopics() with all terms
+ Tuple2<int[], double[]>[] fullTopicSummary = model.describeTopics();
+ assertEquals(fullTopicSummary.length, tinyK);
+ for (int i = 0; i < fullTopicSummary.length; i++) {
+ assertArrayEquals(fullTopicSummary[i]._1(), tinyTopicDescription[i]._1());
+ assertArrayEquals(fullTopicSummary[i]._2(), tinyTopicDescription[i]._2(), 1e-5);
+ }
+ }
+
+ @Test
+ public void distributedLDAModel() {
+ int k = 3;
+ double topicSmoothing = 1.2;
+ double termSmoothing = 1.2;
+
+ // Train a model
+ LDA lda = new LDA();
+ lda.setK(k)
+ .setDocConcentration(topicSmoothing)
+ .setTopicConcentration(termSmoothing)
+ .setMaxIterations(5)
+ .setSeed(12345);
+
+ DistributedLDAModel model = lda.run(corpus);
+
+ // Check: basic parameters
+ LocalLDAModel localModel = model.toLocal();
+ assertEquals(model.k(), k);
+ assertEquals(localModel.k(), k);
+ assertEquals(model.vocabSize(), tinyVocabSize);
+ assertEquals(localModel.vocabSize(), tinyVocabSize);
+ assertEquals(model.topicsMatrix(), localModel.topicsMatrix());
+
+ // Check: topic summaries
+ Tuple2<int[], double[]>[] roundedTopicSummary = model.describeTopics();
+ assertEquals(roundedTopicSummary.length, k);
+ Tuple2<int[], double[]>[] roundedLocalTopicSummary = localModel.describeTopics();
+ assertEquals(roundedLocalTopicSummary.length, k);
+
+ // Check: log probabilities
+ assert(model.logLikelihood() < 0.0);
+ assert(model.logPrior() < 0.0);
+ }
+
+ private static int tinyK = LDASuite$.MODULE$.tinyK();
+ private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize();
+ private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics();
+ private static Tuple2<int[], double[]>[] tinyTopicDescription =
+ LDASuite$.MODULE$.tinyTopicDescription();
+ JavaPairRDD<Long, Vector> corpus;
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
new file mode 100644
index 000000000000..302d751eb8a9
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class LDASuite extends FunSuite with MLlibTestSparkContext {
+
+ import LDASuite._
+
+ test("LocalLDAModel") {
+ val model = new LocalLDAModel(tinyTopics)
+
+ // Check: basic parameters
+ assert(model.k === tinyK)
+ assert(model.vocabSize === tinyVocabSize)
+ assert(model.topicsMatrix === tinyTopics)
+
+ // Check: describeTopics() with all terms
+ val fullTopicSummary = model.describeTopics()
+ assert(fullTopicSummary.size === tinyK)
+ fullTopicSummary.zip(tinyTopicDescription).foreach {
+ case ((algTerms, algTermWeights), (terms, termWeights)) =>
+ assert(algTerms === terms)
+ assert(algTermWeights === termWeights)
+ }
+
+ // Check: describeTopics() with some terms
+ val smallNumTerms = 3
+ val smallTopicSummary = model.describeTopics(maxTermsPerTopic = smallNumTerms)
+ smallTopicSummary.zip(tinyTopicDescription).foreach {
+ case ((algTerms, algTermWeights), (terms, termWeights)) =>
+ assert(algTerms === terms.slice(0, smallNumTerms))
+ assert(algTermWeights === termWeights.slice(0, smallNumTerms))
+ }
+ }
+
+ test("running and DistributedLDAModel") {
+ val k = 3
+ val topicSmoothing = 1.2
+ val termSmoothing = 1.2
+
+ // Train a model
+ val lda = new LDA()
+ lda.setK(k)
+ .setDocConcentration(topicSmoothing)
+ .setTopicConcentration(termSmoothing)
+ .setMaxIterations(5)
+ .setSeed(12345)
+ val corpus = sc.parallelize(tinyCorpus, 2)
+
+ val model: DistributedLDAModel = lda.run(corpus)
+
+ // Check: basic parameters
+ val localModel = model.toLocal
+ assert(model.k === k)
+ assert(localModel.k === k)
+ assert(model.vocabSize === tinyVocabSize)
+ assert(localModel.vocabSize === tinyVocabSize)
+ assert(model.topicsMatrix === localModel.topicsMatrix)
+
+ // Check: topic summaries
+ // The odd decimal formatting and sorting is a hack to do a robust comparison.
+ val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) =>
+ // cut values to 3 digits after the decimal place
+ terms.zip(termWeights).map { case (term, weight) =>
+ ("%.3f".format(weight).toDouble, term.toInt)
+ }
+ }.sortBy(_.mkString(""))
+ val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) =>
+ // cut values to 3 digits after the decimal place
+ terms.zip(termWeights).map { case (term, weight) =>
+ ("%.3f".format(weight).toDouble, term.toInt)
+ }
+ }.sortBy(_.mkString(""))
+ roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) =>
+ assert(t1 === t2)
+ }
+
+ // Check: per-doc topic distributions
+ val topicDistributions = model.topicDistributions.collect()
+ // Ensure all documents are covered.
+ assert(topicDistributions.size === tinyCorpus.size)
+ assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet)
+ // Ensure we have proper distributions
+ topicDistributions.foreach { case (docId, topicDistribution) =>
+ assert(topicDistribution.size === tinyK)
+ assert(topicDistribution.toArray.sum ~== 1.0 absTol 1e-5)
+ }
+
+ // Check: log probabilities
+ assert(model.logLikelihood < 0.0)
+ assert(model.logPrior < 0.0)
+ }
+
+ test("vertex indexing") {
+ // Check vertex ID indexing and conversions.
+ val docIds = Array(0, 1, 2)
+ val docVertexIds = docIds
+ val termIds = Array(0, 1, 2)
+ val termVertexIds = Array(-1, -2, -3)
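+ // Term vertices are distinguished from document vertices by negative IDs: term t is expected
+ // to map to vertex ID -(t + 1), i.e. 0 -> -1, 1 -> -2, 2 -> -3, which is what the
+ // term2index / index2term / isTermVertex assertions below verify.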
+ assert(docVertexIds.forall(i => !LDA.isTermVertex((i.toLong, 0))))
+ assert(termIds.map(LDA.term2index) === termVertexIds)
+ assert(termVertexIds.map(i => LDA.index2term(i.toLong)) === termIds)
+ assert(termVertexIds.forall(i => LDA.isTermVertex((i.toLong, 0))))
+ }
+}
+
+private[clustering] object LDASuite {
+
+ def tinyK: Int = 3
+ def tinyVocabSize: Int = 5
+ def tinyTopicsAsArray: Array[Array[Double]] = Array(
+ Array[Double](0.1, 0.2, 0.3, 0.4, 0.0), // topic 0
+ Array[Double](0.5, 0.05, 0.05, 0.1, 0.3), // topic 1
+ Array[Double](0.2, 0.2, 0.05, 0.05, 0.5) // topic 2
+ )
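+ // DenseMatrix stores its values in column-major order, so concatenating the per-topic arrays
+ // makes each topic one column of the vocabSize x k matrix.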
+ def tinyTopics: Matrix = new DenseMatrix(numRows = tinyVocabSize, numCols = tinyK,
+ values = tinyTopicsAsArray.fold(Array.empty[Double])(_ ++ _))
+ def tinyTopicDescription: Array[(Array[Int], Array[Double])] = tinyTopicsAsArray.map { topic =>
+ val (termWeights, terms) = topic.zipWithIndex.sortBy(-_._1).unzip
+ (terms.toArray, termWeights.toArray)
+ }
+
+ def tinyCorpus = Array(
+ Vectors.dense(1, 3, 0, 2, 8),
+ Vectors.dense(0, 2, 1, 0, 4),
+ Vectors.dense(2, 3, 12, 3, 1),
+ Vectors.dense(0, 3, 1, 9, 8),
+ Vectors.dense(1, 1, 4, 2, 6)
+ ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
+ assert(tinyCorpus.forall(_._2.size == tinyVocabSize)) // sanity check for test data
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
new file mode 100644
index 000000000000..dac28a369b5b
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import org.scalatest.FunSuite
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.graphx.{Edge, Graph}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+
+
+class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext {
+
+ import PeriodicGraphCheckpointerSuite._
+
+ // TODO: Do I need to call count() on the graphs' RDDs?
+
+ test("Persisting") {
+ var graphsToCheck = Seq.empty[GraphToCheck]
+
+ val graph1 = createGraph(sc)
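+ // No checkpoint directory is given (None), so this test exercises only the persist/unpersist
+ // behavior; the interval argument (10) presumably has no effect when checkpointing is disabled.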
+ val checkpointer = new PeriodicGraphCheckpointer(graph1, None, 10)
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
+ checkPersistence(graphsToCheck, 1)
+
+ var iteration = 2
+ while (iteration < 9) {
+ val graph = createGraph(sc)
+ checkpointer.updateGraph(graph)
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
+ checkPersistence(graphsToCheck, iteration)
+ iteration += 1
+ }
+ }
+
+ test("Checkpointing") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ val checkpointInterval = 2
+ var graphsToCheck = Seq.empty[GraphToCheck]
+
+ val graph1 = createGraph(sc)
+ val checkpointer = new PeriodicGraphCheckpointer(graph1, Some(path), checkpointInterval)
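+ // count() runs an action on the vertex and edge RDDs so that any pending checkpoint is
+ // actually materialized (RDD checkpointing only takes effect once an action is run).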
+ graph1.edges.count()
+ graph1.vertices.count()
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
+ checkCheckpoint(graphsToCheck, 1, checkpointInterval)
+
+ var iteration = 2
+ while (iteration < 9) {
+ val graph = createGraph(sc)
+ checkpointer.updateGraph(graph)
+ graph.vertices.count()
+ graph.edges.count()
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
+ checkCheckpoint(graphsToCheck, iteration, checkpointInterval)
+ iteration += 1
+ }
+
+ checkpointer.deleteAllCheckpoints()
+ graphsToCheck.foreach { graph =>
+ confirmCheckpointRemoved(graph.graph)
+ }
+
+ Utils.deleteRecursively(tempDir)
+ }
+}
+
+private object PeriodicGraphCheckpointerSuite {
+
+ case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int)
+
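+ // A small chain graph over vertices 0 through 4, with dummy Double attributes.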
+ val edges = Seq(
+ Edge[Double](0, 1, 0),
+ Edge[Double](1, 2, 0),
+ Edge[Double](2, 3, 0),
+ Edge[Double](3, 4, 0))
+
+ def createGraph(sc: SparkContext): Graph[Double, Double] = {
+ Graph.fromEdges[Double, Double](sc.parallelize(edges), 0)
+ }
+
+ def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = {
+ graphs.foreach { g =>
+ checkPersistence(g.graph, g.gIndex, iteration)
+ }
+ }
+
+ /**
+ * Check the storage level of a graph's vertices and edges.
+ * @param graph  Graph to check.
+ * @param gIndex  Index of the graph in the order it was inserted into the checkpointer (from 1).
+ * @param iteration  Total number of graphs inserted into the checkpointer so far.
+ */
+ def checkPersistence(graph: Graph[_, _], gIndex: Int, iteration: Int): Unit = {
+ try {
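+ // The checkpointer is expected to keep only the most recently inserted graphs persisted;
+ // graphs inserted more than 2 iterations before the current one should have been unpersisted.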
+ if (gIndex + 2 < iteration) {
+ assert(graph.vertices.getStorageLevel == StorageLevel.NONE)
+ assert(graph.edges.getStorageLevel == StorageLevel.NONE)
+ } else {
+ assert(graph.vertices.getStorageLevel != StorageLevel.NONE)
+ assert(graph.edges.getStorageLevel != StorageLevel.NONE)
+ }
+ } catch {
+ case _: AssertionError =>
+ throw new Exception(s"PeriodicGraphCheckpointerSuite.checkPersistence failed with:\n" +
+ s"\t gIndex = $gIndex\n" +
+ s"\t iteration = $iteration\n" +
+ s"\t graph.vertices.getStorageLevel = ${graph.vertices.getStorageLevel}\n" +
+ s"\t graph.edges.getStorageLevel = ${graph.edges.getStorageLevel}\n")
+ }
+ }
+
+ def checkCheckpoint(graphs: Seq[GraphToCheck], iteration: Int, checkpointInterval: Int): Unit = {
+ graphs.reverse.foreach { g =>
+ checkCheckpoint(g.graph, g.gIndex, iteration, checkpointInterval)
+ }
+ }
+
+ def confirmCheckpointRemoved(graph: Graph[_, _]): Unit = {
+ // Note: We cannot check graph.isCheckpointed since that value is never updated.
+ // Instead, we check for the presence of the checkpoint files.
+ // This test should continue to work even after this graph.isCheckpointed issue
+ // is fixed (though it can then be simplified and not look for the files).
+ val fs = FileSystem.get(graph.vertices.sparkContext.hadoopConfiguration)
+ graph.getCheckpointFiles.foreach { checkpointFile =>
+ assert(!fs.exists(new Path(checkpointFile)),
+ "Graph checkpoint file should have been removed")
+ }
+ }
+
+ /**
+ * Check the checkpointed status of a graph.
+ * @param graph  Graph to check.
+ * @param gIndex  Index of the graph in the order it was inserted into the checkpointer (from 1).
+ * @param iteration  Total number of graphs inserted into the checkpointer so far.
+ * @param checkpointInterval  Period (in number of insertions) at which graphs are checkpointed.
+ */
+ def checkCheckpoint(
+ graph: Graph[_, _],
+ gIndex: Int,
+ iteration: Int,
+ checkpointInterval: Int): Unit = {
+ try {
+ if (gIndex % checkpointInterval == 0) {
+ // We allow a window of 2 checkpoint intervals since the action (materializing a second
+ // checkpointed graph) is performed only AFTER PeriodicGraphCheckpointer decides whether to
+ // remove the previous checkpoint.
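+ // For example, with checkpointInterval = 2 and iteration = 8, graphs 6 and 8 should still be
+ // checkpointed, while the checkpoints for graphs 2 and 4 should have been removed.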
+ if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) {
+ assert(graph.isCheckpointed, "Graph should be checkpointed")
+ assert(graph.getCheckpointFiles.length == 2, "Graph should have 2 checkpoint files")
+ } else {
+ confirmCheckpointRemoved(graph)
+ }
+ } else {
+ // Graph should never be checkpointed
+ assert(!graph.isCheckpointed, "Graph should never have been checkpointed")
+ assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files")
+ }
+ } catch {
+ case e: AssertionError =>
+ throw new Exception(s"PeriodicGraphCheckpointerSuite.checkCheckpoint failed with:\n" +
+ s"\t gIndex = $gIndex\n" +
+ s"\t iteration = $iteration\n" +
+ s"\t checkpointInterval = $checkpointInterval\n" +
+ s"\t graph.isCheckpointed = ${graph.isCheckpointed}\n" +
+ s"\t graph.getCheckpointFiles = ${graph.getCheckpointFiles.mkString(", ")}\n" +
+ s" AssertionError message: ${e.getMessage}")
+ }
+ }
+
+}