From a5c03435a206ea8239b176fc72bc9bbb6ff8e318 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 31 Mar 2016 22:55:09 +0800 Subject: [PATCH 1/2] LDA should support disable checkpoint --- .../spark/mllib/clustering/LDAOptimizer.scala | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 7491ab0d51ca..b397fe391633 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -142,9 +142,11 @@ final class EMLDAOptimizer extends LDAOptimizer { this.k = k this.vocabSize = docs.take(1).head._2.size this.checkpointInterval = lda.getCheckpointInterval - this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( - checkpointInterval, graph.vertices.sparkContext) - this.graphCheckpointer.update(this.graph) + if (this.checkpointInterval != -1) { + this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( + checkpointInterval, graph.vertices.sparkContext) + this.graphCheckpointer.update(this.graph) + } this.globalTopicTotals = computeGlobalTopicTotals() this } @@ -189,7 +191,9 @@ final class EMLDAOptimizer extends LDAOptimizer { // Update the vertex descriptors with the new counts. 
val newGraph = Graph(docTopicDistributions, graph.edges) graph = newGraph - graphCheckpointer.update(newGraph) + if (this.checkpointInterval != -1) { + graphCheckpointer.update(newGraph) + } globalTopicTotals = computeGlobalTopicTotals() this } @@ -208,7 +212,9 @@ final class EMLDAOptimizer extends LDAOptimizer { override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { require(graph != null, "graph is null, EMLDAOptimizer not initialized.") - this.graphCheckpointer.deleteAllCheckpoints() + if (this.checkpointInterval != -1) { + this.graphCheckpointer.deleteAllCheckpoints() + } // The constructor's default arguments assume gammaShape = 100 to ensure equivalence in // LDAModel.toLocal conversion new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize, From 4af96d8fadbf1a9290756cd60e422d26a7fdc915 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 7 Apr 2016 17:54:04 +0800 Subject: [PATCH 2/2] Disable checkpoint in PeriodicCheckpointer --- .../spark/mllib/clustering/LDAOptimizer.scala | 16 +++++----------- .../spark/mllib/impl/PeriodicCheckpointer.scala | 6 ++++-- .../mllib/impl/PeriodicGraphCheckpointer.scala | 3 ++- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index b397fe391633..7491ab0d51ca 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -142,11 +142,9 @@ final class EMLDAOptimizer extends LDAOptimizer { this.k = k this.vocabSize = docs.take(1).head._2.size this.checkpointInterval = lda.getCheckpointInterval - if (this.checkpointInterval != -1) { - this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( - checkpointInterval, graph.vertices.sparkContext) - 
this.graphCheckpointer.update(this.graph) - } + this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( + checkpointInterval, graph.vertices.sparkContext) + this.graphCheckpointer.update(this.graph) this.globalTopicTotals = computeGlobalTopicTotals() this } @@ -191,9 +189,7 @@ final class EMLDAOptimizer extends LDAOptimizer { // Update the vertex descriptors with the new counts. val newGraph = Graph(docTopicDistributions, graph.edges) graph = newGraph - if (this.checkpointInterval != -1) { - graphCheckpointer.update(newGraph) - } + graphCheckpointer.update(newGraph) globalTopicTotals = computeGlobalTopicTotals() this } @@ -212,9 +208,7 @@ final class EMLDAOptimizer extends LDAOptimizer { override private[clustering] def getLDAModel(iterationTimes: Array[Double]): LDAModel = { require(graph != null, "graph is null, EMLDAOptimizer not initialized.") - if (this.checkpointInterval != -1) { - this.graphCheckpointer.deleteAllCheckpoints() - } + this.graphCheckpointer.deleteAllCheckpoints() // The constructor's default arguments assume gammaShape = 100 to ensure equivalence in // LDAModel.toLocal conversion new DistributedLDAModel(this.graph, this.globalTopicTotals, this.k, this.vocabSize, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala index 391f89aa1489..f6858c14c58c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala @@ -52,7 +52,8 @@ import org.apache.spark.storage.StorageLevel * - This class removes checkpoint files once later Datasets have been checkpointed. * However, references to the older Datasets will still return isCheckpointed = true. * - * @param checkpointInterval Datasets will be checkpointed at this interval + * @param checkpointInterval Datasets will be checkpointed at this interval. 
+ * If this interval is set to -1, checkpointing is disabled. * @param sc SparkContext for the Datasets given to this checkpointer * @tparam T Dataset type, such as RDD[Double] */ @@ -89,7 +90,8 @@ private[mllib] abstract class PeriodicCheckpointer[T]( updateCount += 1 // Handle checkpointing (after persisting) - if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) { + if (checkpointInterval != -1 && (updateCount % checkpointInterval) == 0 + && sc.getCheckpointDir.nonEmpty) { // Add new checkpoint before removing old checkpoints. checkpoint(newData) checkpointQueue.enqueue(newData) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala index 11a059536c50..20db6084d0e0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala @@ -69,7 +69,8 @@ import org.apache.spark.storage.StorageLevel * // checkpointed: graph4 * }}} * - * @param checkpointInterval Graphs will be checkpointed at this interval + * @param checkpointInterval Graphs will be checkpointed at this interval. + * If this interval is set to -1, checkpointing is disabled. * @tparam VD Vertex descriptor type * @tparam ED Edge descriptor type *