From 568918c99fa9d42f94daecc8d9759794240db148 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Tue, 28 Jul 2015 09:48:14 -0700
Subject: [PATCH 1/4] Generalized PeriodicGraphCheckpointer to
 PeriodicCheckpointer, with subclasses for RDDs and Graphs.

---
 .../spark/mllib/clustering/LDAOptimizer.scala |   2 +-
 .../mllib/impl/PeriodicCheckpointer.scala     | 184 ++++++++++++++++++
 .../impl/PeriodicGraphCheckpointer.scala      | 104 ++--------
 .../mllib/impl/PeriodicRDDCheckpointer.scala  |  97 +++++++++
 .../impl/PeriodicGraphCheckpointerSuite.scala |   8 +-
 .../impl/PeriodicRDDCheckpointerSuite.scala   | 171 ++++++++++++++++
 6 files changed, 471 insertions(+), 95 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index f4170a3d98dd..a4c46c55f891 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -188,7 +188,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
     // Update the vertex descriptors with the new counts.
     val newGraph = GraphImpl.fromExistingRDDs(docTopicDistributions, graph.edges)
     graph = newGraph
-    graphCheckpointer.updateGraph(newGraph)
+    graphCheckpointer.update(newGraph)
     globalTopicTotals = computeGlobalTopicTotals()
     this
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
new file mode 100644
index 000000000000..4fe99139d6c7
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.{Path, FileSystem}
+
+import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * This abstraction helps with persisting and checkpointing RDDs and types derived from RDDs
+ * (such as Graphs and DataFrames).  In documentation, we use the phrase "Dataset" to refer to
+ * the distributed data type (RDD, Graph, etc.).
+ *
+ * Specifically, this abstraction automatically handles persisting and (optionally) checkpointing,
+ * as well as unpersisting and removing checkpoint files.
+ *
+ * Users should call update() when a new Dataset has been created,
+ * before the Dataset has been materialized.  After updating [[PeriodicCheckpointer]], users are
+ * responsible for materializing the Dataset to ensure that persisting and checkpointing actually
+ * occur.
+ *
+ * When update() is called, this does the following:
+ *  - Persist new Dataset (if not yet persisted), and put in queue of persisted Datasets.
+ *  - Unpersist Datasets from queue until there are at most 3 persisted Datasets.
+ *  - If using checkpointing and the checkpoint interval has been reached,
+ *     - Checkpoint the new Dataset, and put in a queue of checkpointed Datasets.
+ *     - Remove older checkpoints.
+ *
+ * WARNINGS:
+ *  - This class should NOT be copied (since copies may conflict on which Datasets should be
+ *    checkpointed).
+ *  - This class removes checkpoint files once later Datasets have been checkpointed.
+ *    However, references to the older Datasets will still return isCheckpointed = true.
+ *
+ * Example usage:
+ * {{{
+ *  val (data1, data2, data3, ...) = ...
+ *  val cp = new PeriodicCheckpointer(data1, dir, 2)
+ *  data1.count();
+ *  // persisted: data1
+ *  cp.update(data2)
+ *  data2.count();
+ *  // persisted: data1, data2
+ *  // checkpointed: data2
+ *  cp.update(data3)
+ *  data3.count();
+ *  // persisted: data1, data2, data3
+ *  // checkpointed: data2
+ *  cp.update(data4)
+ *  data4.count();
+ *  // persisted: data2, data3, data4
+ *  // checkpointed: data4
+ *  cp.update(data5)
+ *  data5.count();
+ *  // persisted: data3, data4, data5
+ *  // checkpointed: data4
+ * }}}
+ *
+ * @param currentData  Initial Dataset
+ * @param checkpointInterval  Datasets will be checkpointed at this interval
+ * @param sc  SparkContext for the Datasets given to this checkpointer
+ * @tparam T  Dataset type, such as RDD[Double]
+ */
+private[mllib] abstract class PeriodicCheckpointer[T](
+    var currentData: T,
+    val checkpointInterval: Int,
+    val sc: SparkContext) extends Logging {
+
+  /** FIFO queue of past checkpointed Datasets */
+  private val checkpointQueue = mutable.Queue[T]()
+
+  /** FIFO queue of past persisted Datasets */
+  private val persistedQueue = mutable.Queue[T]()
+
+  /** Number of times [[update()]] has been called */
+  private var updateCount = 0
+
+  update(currentData)
+
+  /**
+   * Update [[currentData]] with a new Dataset. Handle persistence and checkpointing as needed.
+   * Since this handles persistence and checkpointing, this should be called before the Dataset
+   * has been materialized.
+   *
+   * @param newData  New Dataset created from previous Datasets in the lineage.
+   */
+  def update(newData: T): Unit = {
+    persist(newData)
+    persistedQueue.enqueue(newData)
+    // We try to maintain 2 Datasets in persistedQueue to support the semantics of this class:
+    // Users should call [[update()]] when a new Dataset has been created,
+    // before the Dataset has been materialized.
+    while (persistedQueue.size > 3) {
+      val dataToUnpersist = persistedQueue.dequeue()
+      unpersist(dataToUnpersist)
+    }
+    updateCount += 1
+
+    // Handle checkpointing (after persisting)
+    if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
+      // Add new checkpoint before removing old checkpoints.
+      checkpoint(newData)
+      checkpointQueue.enqueue(newData)
+      // Remove checkpoints before the latest one.
+      var canDelete = true
+      while (checkpointQueue.size > 1 && canDelete) {
+        // Delete the oldest checkpoint only if the next checkpoint exists.
+        if (isCheckpointed(checkpointQueue.get(1).get)) {
+          removeCheckpointFile()
+        } else {
+          canDelete = false
+        }
+      }
+    }
+
+    currentData = newData
+  }
+
+  /** Checkpoint the Dataset */
+  def checkpoint(data: T): Unit
+
+  /** Return true iff the Dataset is checkpointed */
+  def isCheckpointed(data: T): Boolean
+
+  /**
+   * Persist the Dataset.
+   * Note: This should handle checking the current [[StorageLevel]] of the Dataset.
+   */
+  def persist(data: T): Unit
+
+  /** Unpersist the Dataset */
+  def unpersist(data: T): Unit
+
+  /** Get list of checkpoint files for this given Dataset */
+  def getCheckpointFiles(data: T): Iterable[String]
+
+  /**
+   * Call this at the end to delete any remaining checkpoint files.
+   */
+  def deleteAllCheckpoints(): Unit = {
+    while (checkpointQueue.nonEmpty) {
+      removeCheckpointFile()
+    }
+  }
+
+  /**
+   * Dequeue the oldest checkpointed Dataset, and remove its checkpoint files.
+   * This prints a warning but does not fail if the files cannot be removed.
+   */
+  private def removeCheckpointFile(): Unit = {
+    val old = checkpointQueue.dequeue()
+    // Since the old checkpoint is not deleted by Spark, we manually delete it.
+    val fs = FileSystem.get(sc.hadoopConfiguration)
+    getCheckpointFiles(old).foreach { checkpointFile =>
+      try {
+        fs.delete(new Path(checkpointFile), true)
+      } catch {
+        case e: Exception =>
+          logWarning("PeriodicCheckpointer could not remove old checkpoint file: " +
+            checkpointFile)
+      }
+    }
+  }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
index 6e5dd119dd65..a4317b182b36 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
@@ -17,11 +17,6 @@
 
 package org.apache.spark.mllib.impl
 
-import scala.collection.mutable
-
-import org.apache.hadoop.fs.{Path, FileSystem}
-
-import org.apache.spark.Logging
 import org.apache.spark.graphx.Graph
 import org.apache.spark.storage.StorageLevel
 
@@ -31,12 +26,12 @@ import org.apache.spark.storage.StorageLevel
  * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
  * unpersisting and removing checkpoint files.
  *
- * Users should call [[PeriodicGraphCheckpointer.updateGraph()]] when a new graph has been created,
+ * Users should call update() when a new graph has been created,
  * before the graph has been materialized.  After updating [[PeriodicGraphCheckpointer]], users are
  * responsible for materializing the graph to ensure that persisting and checkpointing actually
  * occur.
  *
- * When [[PeriodicGraphCheckpointer.updateGraph()]] is called, this does the following:
+ * When update() is called, this does the following:
  *  - Persist new graph (if not yet persisted), and put in queue of persisted graphs.
  *  - Unpersist graphs from queue until there are at most 3 persisted graphs.
  *  - If using checkpointing and the checkpoint interval has been reached,
@@ -73,99 +68,30 @@ import org.apache.spark.storage.StorageLevel
  *  // checkpointed: graph4
  * }}}
  *
- * @param currentGraph  Initial graph
+ * @param initGraph  Initial graph
  * @param checkpointInterval  Graphs will be checkpointed at this interval
  * @tparam VD  Vertex descriptor type
  * @tparam ED  Edge descriptor type
  *
- * TODO: Generalize this for Graphs and RDDs, and move it out of MLlib.
+ * TODO: Move this out of MLlib?
  */
 private[mllib] class PeriodicGraphCheckpointer[VD, ED](
-    var currentGraph: Graph[VD, ED],
-    val checkpointInterval: Int) extends Logging {
-
-  /** FIFO queue of past checkpointed RDDs */
-  private val checkpointQueue = mutable.Queue[Graph[VD, ED]]()
-
-  /** FIFO queue of past persisted RDDs */
-  private val persistedQueue = mutable.Queue[Graph[VD, ED]]()
+    initGraph: Graph[VD, ED],
+    checkpointInterval: Int)
+  extends PeriodicCheckpointer[Graph[VD, ED]](initGraph, checkpointInterval,
+    initGraph.vertices.sparkContext) {
 
-  /** Number of times [[updateGraph()]] has been called */
-  private var updateCount = 0
+  override def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint()
 
-  /**
-   * Spark Context for the Graphs given to this checkpointer.
-   * NOTE: This code assumes that only one SparkContext is used for the given graphs.
-   */
-  private val sc = currentGraph.vertices.sparkContext
+  override def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed
 
-  updateGraph(currentGraph)
-
-  /**
-   * Update [[currentGraph]] with a new graph. Handle persistence and checkpointing as needed.
-   * Since this handles persistence and checkpointing, this should be called before the graph
-   * has been materialized.
-   *
-   * @param newGraph  New graph created from previous graphs in the lineage.
-   */
-  def updateGraph(newGraph: Graph[VD, ED]): Unit = {
-    if (newGraph.vertices.getStorageLevel == StorageLevel.NONE) {
-      newGraph.persist()
-    }
-    persistedQueue.enqueue(newGraph)
-    // We try to maintain 2 Graphs in persistedQueue to support the semantics of this class:
-    // Users should call [[updateGraph()]] when a new graph has been created,
-    // before the graph has been materialized.
-    while (persistedQueue.size > 3) {
-      val graphToUnpersist = persistedQueue.dequeue()
-      graphToUnpersist.unpersist(blocking = false)
-    }
-    updateCount += 1
-
-    // Handle checkpointing (after persisting)
-    if ((updateCount % checkpointInterval) == 0 && sc.getCheckpointDir.nonEmpty) {
-      // Add new checkpoint before removing old checkpoints.
-      newGraph.checkpoint()
-      checkpointQueue.enqueue(newGraph)
-      // Remove checkpoints before the latest one.
-      var canDelete = true
-      while (checkpointQueue.size > 1 && canDelete) {
-        // Delete the oldest checkpoint only if the next checkpoint exists.
-        if (checkpointQueue.get(1).get.isCheckpointed) {
-          removeCheckpointFile()
-        } else {
-          canDelete = false
-        }
-      }
-    }
-  }
-
-  /**
-   * Call this at the end to delete any remaining checkpoint files.
-   */
-  def deleteAllCheckpoints(): Unit = {
-    while (checkpointQueue.size > 0) {
-      removeCheckpointFile()
+  override def persist(data: Graph[VD, ED]): Unit = {
+    if (data.vertices.getStorageLevel == StorageLevel.NONE) {
+      data.persist()
     }
   }
 
-  /**
-   * Dequeue the oldest checkpointed Graph, and remove its checkpoint files.
-   * This prints a warning but does not fail if the files cannot be removed.
-   */
-  private def removeCheckpointFile(): Unit = {
-    val old = checkpointQueue.dequeue()
-    // Since the old checkpoint is not deleted by Spark, we manually delete it.
-    val fs = FileSystem.get(sc.hadoopConfiguration)
-    old.getCheckpointFiles.foreach { checkpointFile =>
-      try {
-        fs.delete(new Path(checkpointFile), true)
-      } catch {
-        case e: Exception =>
-          logWarning("PeriodicGraphCheckpointer could not remove old checkpoint file: " +
-            checkpointFile)
-      }
-    }
-  }
+  override def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false)
 
+  override def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = data.getCheckpointFiles
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
new file mode 100644
index 000000000000..84df95b46097
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * This class helps with persisting and checkpointing RDDs.
+ * Specifically, it automatically handles persisting and (optionally) checkpointing, as well as
+ * unpersisting and removing checkpoint files.
+ *
+ * Users should call update() when a new RDD has been created,
+ * before the RDD has been materialized.  After updating [[PeriodicRDDCheckpointer]], users are
+ * responsible for materializing the RDD to ensure that persisting and checkpointing actually
+ * occur.
+ *
+ * When update() is called, this does the following:
+ *  - Persist new RDD (if not yet persisted), and put in queue of persisted RDDs.
+ *  - Unpersist RDDs from queue until there are at most 3 persisted RDDs.
+ *  - If using checkpointing and the checkpoint interval has been reached,
+ *     - Checkpoint the new RDD, and put in a queue of checkpointed RDDs.
+ *     - Remove older checkpoints.
+ *
+ * WARNINGS:
+ *  - This class should NOT be copied (since copies may conflict on which RDDs should be
+ *    checkpointed).
+ *  - This class removes checkpoint files once later RDDs have been checkpointed.
+ *    However, references to the older RDDs will still return isCheckpointed = true.
+ *
+ * Example usage:
+ * {{{
+ *  val (rdd1, rdd2, rdd3, ...) = ...
+ *  val cp = new PeriodicRDDCheckpointer(rdd1, dir, 2)
+ *  rdd1.count();
+ *  // persisted: rdd1
+ *  cp.update(rdd2)
+ *  rdd2.count();
+ *  // persisted: rdd1, rdd2
+ *  // checkpointed: rdd2
+ *  cp.update(rdd3)
+ *  rdd3.count();
+ *  // persisted: rdd1, rdd2, rdd3
+ *  // checkpointed: rdd2
+ *  cp.update(rdd4)
+ *  rdd4.count();
+ *  // persisted: rdd2, rdd3, rdd4
+ *  // checkpointed: rdd4
+ *  cp.update(rdd5)
+ *  rdd5.count();
+ *  // persisted: rdd3, rdd4, rdd5
+ *  // checkpointed: rdd4
+ * }}}
+ *
+ * @param initRDD  Initial RDD
+ * @param checkpointInterval  RDDs will be checkpointed at this interval
+ * @tparam T  RDD element type
+ *
+ * TODO: Move this out of MLlib?
+ */
+private[mllib] class PeriodicRDDCheckpointer[T](
+    initRDD: RDD[T],
+    checkpointInterval: Int)
+  extends PeriodicCheckpointer[RDD[T]](initRDD, checkpointInterval, initRDD.sparkContext) {
+
+  override def checkpoint(data: RDD[T]): Unit = data.checkpoint()
+
+  override def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed
+
+  override def persist(data: RDD[T]): Unit = {
+    if (data.getStorageLevel == StorageLevel.NONE) {
+      data.persist()
+    }
+  }
+
+  override def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)
+
+  override def getCheckpointFiles(data: RDD[T]): Iterable[String] = {
+    data.getCheckpointFile.map(x => x)
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
index d34888af2d73..d3da05ddcc0c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
@@ -30,8 +30,6 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
 
   import PeriodicGraphCheckpointerSuite._
 
-  // TODO: Do I need to call count() on the graphs' RDDs?
-
   test("Persisting") {
     var graphsToCheck = Seq.empty[GraphToCheck]
 
@@ -43,7 +41,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
     var iteration = 2
     while (iteration < 9) {
       val graph = createGraph(sc)
-      checkpointer.updateGraph(graph)
+      checkpointer.update(graph)
       graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
       checkPersistence(graphsToCheck, iteration)
       iteration += 1
@@ -66,7 +64,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
     var iteration = 2
     while (iteration < 9) {
      val graph = createGraph(sc)
-      checkpointer.updateGraph(graph)
+      checkpointer.update(graph)
       graph.vertices.count()
       graph.edges.count()
       graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
@@ -168,7 +166,7 @@ private object PeriodicGraphCheckpointerSuite {
       } else {
         // Graph should never be checkpointed
         assert(!graph.isCheckpointed, "Graph should never have been checkpointed")
-        assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files")
+        assert(graph.getCheckpointFiles.isEmpty, "Graph should not have any checkpoint files")
       }
     } catch {
       case e: AssertionError =>
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
new file mode 100644
index 000000000000..f85d6691d451
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
@@ -0,0 +1,171 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+
+import org.apache.spark.{SparkContext, SparkFunSuite}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+
+
+class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  import PeriodicRDDCheckpointerSuite._
+
+  test("Persisting") {
+    var rddsToCheck = Seq.empty[RDDToCheck]
+
+    val rdd1 = createRDD(sc)
+    val checkpointer = new PeriodicRDDCheckpointer(rdd1, 10)
+    rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
+    checkPersistence(rddsToCheck, 1)
+
+    var iteration = 2
+    while (iteration < 9) {
+      val rdd = createRDD(sc)
+      checkpointer.update(rdd)
+      rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
+      checkPersistence(rddsToCheck, iteration)
+      iteration += 1
+    }
+  }
+
+  test("Checkpointing") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    val checkpointInterval = 2
+    var rddsToCheck = Seq.empty[RDDToCheck]
+    sc.setCheckpointDir(path)
+    val rdd1 = createRDD(sc)
+    val checkpointer = new PeriodicRDDCheckpointer(rdd1, checkpointInterval)
+    rdd1.count()
+    rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
+    checkCheckpoint(rddsToCheck, 1, checkpointInterval)
+
+    var iteration = 2
+    while (iteration < 9) {
+      val rdd = createRDD(sc)
+      checkpointer.update(rdd)
+      rdd.count()
+      rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
+      checkCheckpoint(rddsToCheck, iteration, checkpointInterval)
+      iteration += 1
+    }
+
+    checkpointer.deleteAllCheckpoints()
+    rddsToCheck.foreach { rdd =>
+      confirmCheckpointRemoved(rdd.rdd)
+    }
+
+    Utils.deleteRecursively(tempDir)
+  }
+}
+
+private object PeriodicRDDCheckpointerSuite {
+
+  case class RDDToCheck(rdd: RDD[Double], gIndex: Int)
+
+  def createRDD(sc: SparkContext): RDD[Double] = {
+    sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0))
+  }
+
+  def checkPersistence(rdds: Seq[RDDToCheck], iteration: Int): Unit = {
+    rdds.foreach { g =>
+      checkPersistence(g.rdd, g.gIndex, iteration)
+    }
+  }
+
+  /**
+   * Check storage level of rdd.
+   * @param gIndex  Index of rdd in order inserted into checkpointer (from 1).
+   * @param iteration  Total number of rdds inserted into checkpointer.
+   */
+  def checkPersistence(rdd: RDD[_], gIndex: Int, iteration: Int): Unit = {
+    try {
+      if (gIndex + 2 < iteration) {
+        assert(rdd.getStorageLevel == StorageLevel.NONE)
+      } else {
+        assert(rdd.getStorageLevel != StorageLevel.NONE)
+      }
+    } catch {
+      case _: AssertionError =>
+        throw new Exception(s"PeriodicRDDCheckpointerSuite.checkPersistence failed with:\n" +
+          s"\t gIndex = $gIndex\n" +
+          s"\t iteration = $iteration\n" +
+          s"\t rdd.getStorageLevel = ${rdd.getStorageLevel}\n")
+    }
+  }
+
+  def checkCheckpoint(rdds: Seq[RDDToCheck], iteration: Int, checkpointInterval: Int): Unit = {
+    rdds.reverse.foreach { g =>
+      checkCheckpoint(g.rdd, g.gIndex, iteration, checkpointInterval)
+    }
+  }
+
+  def confirmCheckpointRemoved(rdd: RDD[_]): Unit = {
+    // Note: We cannot check rdd.isCheckpointed since that value is never updated.
+    // Instead, we check for the presence of the checkpoint files.
+    // This test should continue to work even after this rdd.isCheckpointed issue
+    // is fixed (though it can then be simplified and not look for the files).
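+    // (rdd.getCheckpointFile continues to report the original path even after the files have
+    //  been deleted, which is what lets us check the filesystem directly below.)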
+    val fs = FileSystem.get(rdd.sparkContext.hadoopConfiguration)
+    rdd.getCheckpointFile.foreach { checkpointFile =>
+      assert(!fs.exists(new Path(checkpointFile)), "RDD checkpoint file should have been removed")
+    }
+  }
+
+  /**
+   * Check checkpointed status of rdd.
+   * @param gIndex  Index of rdd in order inserted into checkpointer (from 1).
+   * @param iteration  Total number of rdds inserted into checkpointer.
+   */
+  def checkCheckpoint(
+      rdd: RDD[_],
+      gIndex: Int,
+      iteration: Int,
+      checkpointInterval: Int): Unit = {
+    try {
+      if (gIndex % checkpointInterval == 0) {
+        // We allow 2 checkpoint intervals since we perform an action (checkpointing a second rdd)
+        // only AFTER PeriodicRDDCheckpointer decides whether to remove the previous checkpoint.
+        if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) {
+          assert(rdd.isCheckpointed, "RDD should be checkpointed")
+          assert(rdd.getCheckpointFile.nonEmpty, "RDD should have a checkpoint file")
+        } else {
+          confirmCheckpointRemoved(rdd)
+        }
+      } else {
+        // RDD should never be checkpointed
+        assert(!rdd.isCheckpointed, "RDD should never have been checkpointed")
+        assert(rdd.getCheckpointFile.isEmpty, "RDD should not have any checkpoint files")
+      }
+    } catch {
+      case e: AssertionError =>
+        throw new Exception(s"PeriodicRDDCheckpointerSuite.checkCheckpoint failed with:\n" +
+          s"\t gIndex = $gIndex\n" +
+          s"\t iteration = $iteration\n" +
+          s"\t checkpointInterval = $checkpointInterval\n" +
+          s"\t rdd.isCheckpointed = ${rdd.isCheckpointed}\n" +
+          s"\t rdd.getCheckpointFile = ${rdd.getCheckpointFile.mkString(", ")}\n" +
+          s"  AssertionError message: ${e.getMessage}")
    }
+  }
+
+}

From 0b3dbc0c026902649e3948f6b8476d98bedef792 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Wed, 29 Jul 2015 18:59:05 -0700
Subject: [PATCH 2/4] Changed checkpointer constructor not to take initial
 data.

---
 .../mllib/impl/PeriodicCheckpointer.scala     | 34 ++-----------------
 .../impl/PeriodicGraphCheckpointer.scala      | 11 +++---
 .../mllib/impl/PeriodicRDDCheckpointer.scala  | 10 +++---
 .../impl/PeriodicGraphCheckpointerSuite.scala |  6 ++--
 .../impl/PeriodicRDDCheckpointerSuite.scala   |  4 +--
 5 files changed, 18 insertions(+), 47 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
index 4fe99139d6c7..a29bafed8d03 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
@@ -51,37 +51,11 @@ import org.apache.spark.storage.StorageLevel
  *  - This class removes checkpoint files once later Datasets have been checkpointed.
  *    However, references to the older Datasets will still return isCheckpointed = true.
  *
- * Example usage:
- * {{{
- *  val (data1, data2, data3, ...) = ...
- *  val cp = new PeriodicCheckpointer(data1, dir, 2)
- *  data1.count();
- *  // persisted: data1
- *  cp.update(data2)
- *  data2.count();
- *  // persisted: data1, data2
- *  // checkpointed: data2
- *  cp.update(data3)
- *  data3.count();
- *  // persisted: data1, data2, data3
- *  // checkpointed: data2
- *  cp.update(data4)
- *  data4.count();
- *  // persisted: data2, data3, data4
- *  // checkpointed: data4
- *  cp.update(data5)
- *  data5.count();
- *  // persisted: data3, data4, data5
- *  // checkpointed: data4
- * }}}
- *
- * @param currentData  Initial Dataset
  * @param checkpointInterval  Datasets will be checkpointed at this interval
  * @param sc  SparkContext for the Datasets given to this checkpointer
  * @tparam T  Dataset type, such as RDD[Double]
  */
 private[mllib] abstract class PeriodicCheckpointer[T](
-    var currentData: T,
     val checkpointInterval: Int,
     val sc: SparkContext) extends Logging {
 
@@ -94,10 +68,8 @@ private[mllib] abstract class PeriodicCheckpointer[T](
   /** Number of times [[update()]] has been called */
   private var updateCount = 0
 
-  update(currentData)
-
   /**
-   * Update [[currentData]] with a new Dataset. Handle persistence and checkpointing as needed.
+   * Update with a new Dataset. Handle persistence and checkpointing as needed.
    * Since this handles persistence and checkpointing, this should be called before the Dataset
    * has been materialized.
    *
@@ -124,15 +96,13 @@ private[mllib] abstract class PeriodicCheckpointer[T](
       var canDelete = true
       while (checkpointQueue.size > 1 && canDelete) {
         // Delete the oldest checkpoint only if the next checkpoint exists.
-        if (isCheckpointed(checkpointQueue.get(1).get)) {
+        if (isCheckpointed(checkpointQueue.head)) {
           removeCheckpointFile()
         } else {
           canDelete = false
         }
       }
     }
-
-    currentData = newData
   }
 
   /** Checkpoint the Dataset */
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
index a4317b182b36..bebd495d7608 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.impl
 
+import org.apache.spark.SparkContext
 import org.apache.spark.graphx.Graph
 import org.apache.spark.storage.StorageLevel
 
@@ -47,7 +48,7 @@ import org.apache.spark.storage.StorageLevel
  * Example usage:
  * {{{
  *  val (graph1, graph2, graph3, ...) = ...
- *  val cp = new PeriodicGraphCheckpointer(graph1, dir, 2)
+ *  val cp = new PeriodicGraphCheckpointer(2, sc)
  *  graph1.vertices.count(); graph1.edges.count()
  *  // persisted: graph1
  *  cp.updateGraph(graph2)
@@ -68,7 +69,6 @@ import org.apache.spark.storage.StorageLevel
  *  // checkpointed: graph4
  * }}}
  *
- * @param initGraph  Initial graph
  * @param checkpointInterval  Graphs will be checkpointed at this interval
  * @tparam VD  Vertex descriptor type
  * @tparam ED  Edge descriptor type
@@ -76,10 +76,9 @@ import org.apache.spark.storage.StorageLevel
  * TODO: Move this out of MLlib?
  */
 private[mllib] class PeriodicGraphCheckpointer[VD, ED](
-    initGraph: Graph[VD, ED],
-    checkpointInterval: Int)
-  extends PeriodicCheckpointer[Graph[VD, ED]](initGraph, checkpointInterval,
-    initGraph.vertices.sparkContext) {
+    checkpointInterval: Int,
+    sc: SparkContext)
+  extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {
 
   override def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint()
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
index 84df95b46097..42191fae74f7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.impl
 
+import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
 
@@ -47,7 +48,7 @@ import org.apache.spark.storage.StorageLevel
  * Example usage:
  * {{{
  *  val (rdd1, rdd2, rdd3, ...) = ...
- *  val cp = new PeriodicRDDCheckpointer(rdd1, dir, 2)
+ *  val cp = new PeriodicRDDCheckpointer(2, sc)
  *  rdd1.count();
  *  // persisted: rdd1
  *  cp.update(rdd2)
@@ -68,7 +69,6 @@ import org.apache.spark.storage.StorageLevel
  *  // checkpointed: rdd4
  * }}}
  *
- * @param initRDD  Initial RDD
  * @param checkpointInterval  RDDs will be checkpointed at this interval
  * @tparam T  RDD element type
  *
  * TODO: Move this out of MLlib?
  */
 private[mllib] class PeriodicRDDCheckpointer[T](
-    initRDD: RDD[T],
-    checkpointInterval: Int)
-  extends PeriodicCheckpointer[RDD[T]](initRDD, checkpointInterval, initRDD.sparkContext) {
+    checkpointInterval: Int,
+    sc: SparkContext)
+  extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {
 
   override def checkpoint(data: RDD[T]): Unit = data.checkpoint()
 
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
index d3da05ddcc0c..993cc99435fe 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
@@ -34,7 +34,8 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
     var graphsToCheck = Seq.empty[GraphToCheck]
 
     val graph1 = createGraph(sc)
-    val checkpointer = new PeriodicGraphCheckpointer(graph1, 10)
+    val checkpointer =
+      new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
     graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
     checkPersistence(graphsToCheck, 1)
 
@@ -55,7 +56,8 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
     var graphsToCheck = Seq.empty[GraphToCheck]
     sc.setCheckpointDir(path)
     val graph1 = createGraph(sc)
-    val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval)
+    val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
+      checkpointInterval, graph1.vertices.sparkContext)
     graph1.edges.count()
     graph1.vertices.count()
     graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
index f85d6691d451..c1c8ff5f5e0e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
@@ -34,7 +34,7 @@ class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCont
     var rddsToCheck = Seq.empty[RDDToCheck]
 
     val rdd1 = createRDD(sc)
-    val checkpointer = new PeriodicRDDCheckpointer(rdd1, 10)
+    val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext)
     rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
     checkPersistence(rddsToCheck, 1)
 
@@ -55,7 +55,7 @@ class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCont
     var rddsToCheck = Seq.empty[RDDToCheck]
     sc.setCheckpointDir(path)
     val rdd1 = createRDD(sc)
-    val checkpointer = new PeriodicRDDCheckpointer(rdd1, checkpointInterval)
+    val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext)
     rdd1.count()
     rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
     checkCheckpoint(rddsToCheck, 1, checkpointInterval)

From 32b23b870614efa67ed8f4a6dccf69d92727a40a Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Wed, 29 Jul 2015 19:33:34 -0700
Subject: [PATCH 3/4] Fixed usage of the checkpointer in LDA

---
 .../org/apache/spark/mllib/clustering/LDAOptimizer.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index a4c46c55f891..29ed990ebdea 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -142,8 +142,8 @@ final class EMLDAOptimizer extends LDAOptimizer {
     this.k = k
     this.vocabSize = docs.take(1).head._2.size
     this.checkpointInterval = lda.getCheckpointInterval
-    this.graphCheckpointer = new
-      PeriodicGraphCheckpointer[TopicCounts, TokenCount](graph, checkpointInterval)
+    this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
+      checkpointInterval, graph.vertices.sparkContext)
     this.globalTopicTotals = computeGlobalTopicTotals()
     this
   }

From d41902c085504ca30714d7665ee924c2c2a7fd91 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Wed, 29 Jul 2015 23:03:44 -0700
Subject: [PATCH 4/4] Fixed the checkpointer tests to call update() once more
 with the initial data (missed after the last commit), and made some of the
 checkpointer methods protected, as should have been done before.

---
 .../spark/mllib/impl/PeriodicCheckpointer.scala      | 10 +++++-----
 .../spark/mllib/impl/PeriodicGraphCheckpointer.scala | 12 +++++++-----
 .../spark/mllib/impl/PeriodicRDDCheckpointer.scala   | 10 +++++-----
 .../mllib/impl/PeriodicGraphCheckpointerSuite.scala  |  2 ++
 .../mllib/impl/PeriodicRDDCheckpointerSuite.scala    |  2 ++
 5 files changed, 21 insertions(+), 15 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
index a29bafed8d03..72d3aabc9b1f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala
@@ -106,22 +106,22 @@ private[mllib] abstract class PeriodicCheckpointer[T](
   }
 
   /** Checkpoint the Dataset */
-  def checkpoint(data: T): Unit
+  protected def checkpoint(data: T): Unit
 
   /** Return true iff the Dataset is checkpointed */
-  def isCheckpointed(data: T): Boolean
+  protected def isCheckpointed(data: T): Boolean
 
   /**
    * Persist the Dataset.
    * Note: This should handle checking the current [[StorageLevel]] of the Dataset.
    */
-  def persist(data: T): Unit
+  protected def persist(data: T): Unit
 
   /** Unpersist the Dataset */
-  def unpersist(data: T): Unit
+  protected def unpersist(data: T): Unit
 
   /** Get list of checkpoint files for this given Dataset */
-  def getCheckpointFiles(data: T): Iterable[String]
+  protected def getCheckpointFiles(data: T): Iterable[String]
 
   /**
    * Call this at the end to delete any remaining checkpoint files.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
index bebd495d7608..11a059536c50 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
@@ -80,17 +80,19 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED](
     sc: SparkContext)
   extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {
 
-  override def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint()
+  override protected def checkpoint(data: Graph[VD, ED]): Unit = data.checkpoint()
 
-  override def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed
+  override protected def isCheckpointed(data: Graph[VD, ED]): Boolean = data.isCheckpointed
 
-  override def persist(data: Graph[VD, ED]): Unit = {
+  override protected def persist(data: Graph[VD, ED]): Unit = {
     if (data.vertices.getStorageLevel == StorageLevel.NONE) {
       data.persist()
     }
   }
 
-  override def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false)
+  override protected def unpersist(data: Graph[VD, ED]): Unit = data.unpersist(blocking = false)
 
-  override def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = data.getCheckpointFiles
+  override protected def getCheckpointFiles(data: Graph[VD, ED]): Iterable[String] = {
+    data.getCheckpointFiles
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
index 42191fae74f7..f31ed2aa90a6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala
@@ -79,19 +79,19 @@ private[mllib] class PeriodicRDDCheckpointer[T](
     sc: SparkContext)
   extends PeriodicCheckpointer[RDD[T]](checkpointInterval, sc) {
 
-  override def checkpoint(data: RDD[T]): Unit = data.checkpoint()
+  override protected def checkpoint(data: RDD[T]): Unit = data.checkpoint()
 
-  override def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed
+  override protected def isCheckpointed(data: RDD[T]): Boolean = data.isCheckpointed
 
-  override def persist(data: RDD[T]): Unit = {
+  override protected def persist(data: RDD[T]): Unit = {
     if (data.getStorageLevel == StorageLevel.NONE) {
       data.persist()
     }
   }
 
-  override def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)
+  override protected def unpersist(data: RDD[T]): Unit = data.unpersist(blocking = false)
 
-  override def getCheckpointFiles(data: RDD[T]): Iterable[String] = {
+  override protected def getCheckpointFiles(data: RDD[T]): Iterable[String] = {
     data.getCheckpointFile.map(x => x)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
index 993cc99435fe..e331c7598918 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
@@ -36,6 +36,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
     val graph1 = createGraph(sc)
     val checkpointer =
       new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
+    checkpointer.update(graph1)
     graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
     checkPersistence(graphsToCheck, 1)
 
@@ -58,6 +59,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
     val graph1 = createGraph(sc)
     val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
       checkpointInterval, graph1.vertices.sparkContext)
+    checkpointer.update(graph1)
     graph1.edges.count()
     graph1.vertices.count()
     graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
index c1c8ff5f5e0e..b2a459a68b5f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
@@ -35,6 +35,7 @@ class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCont
 
     val rdd1 = createRDD(sc)
     val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext)
+    checkpointer.update(rdd1)
     rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
     checkPersistence(rddsToCheck, 1)
 
@@ -56,6 +57,7 @@ class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCont
     sc.setCheckpointDir(path)
     val rdd1 = createRDD(sc)
     val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext)
+    checkpointer.update(rdd1)
     rdd1.count()
     rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
     checkCheckpoint(rddsToCheck, 1, checkpointInterval)
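
---

Reviewer note (not part of the patches above): a minimal sketch of how the reworked API is meant
to be driven from an iterative algorithm, in the same spirit as EMLDAOptimizer. This assumes a
SparkContext `sc` whose checkpoint directory has already been set; `nextIteration` is a
hypothetical stand-in for whatever transformation derives each new RDD from the previous one in
the lineage. (PeriodicRDDCheckpointer is private[mllib], so this only compiles inside that
package.)

    import org.apache.spark.rdd.RDD
    import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer

    // Hypothetical per-iteration transformation; any RDD-to-RDD step would do.
    def nextIteration(prev: RDD[Double]): RDD[Double] = prev.map(_ + 1.0)

    // Checkpoint every 2nd update; requires sc.setCheckpointDir(...) to have been called.
    val checkpointer = new PeriodicRDDCheckpointer[Double](2, sc)
    var data: RDD[Double] = sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0))
    checkpointer.update(data)
    data.count()  // Materialize so that persisting (and any checkpointing) actually occurs.

    for (_ <- 1 to 10) {
      data = nextIteration(data)
      checkpointer.update(data)  // Must be called before the new RDD is materialized.
      data.count()
    }

    checkpointer.deleteAllCheckpoints()  // Remove any remaining checkpoint files at the end.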