From d8635168a9f01e3be2b53a27cc5918a1a0ed1612 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 26 Jun 2015 00:50:45 +0800
Subject: [PATCH 1/4] Add CheckpointingIterator to optimize checkpointing.

---
 .../spark/rdd/PartitionerAwareUnionRDD.scala  |   3 +
 .../main/scala/org/apache/spark/rdd/RDD.scala |   9 +-
 .../apache/spark/rdd/RDDCheckpointData.scala  |  55 +++++--
 .../spark/util/CheckpointingIterator.scala    | 141 ++++++++++++++++++
 .../org/apache/spark/CheckpointSuite.scala    |   2 +-
 5 files changed, 193 insertions(+), 17 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala

diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
index 9e3880714a79..a6ab989bbdd8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
@@ -63,6 +63,9 @@ class PartitionerAwareUnionRDD[T: ClassTag](
   require(rdds.forall(_.partitioner.isDefined))
   require(rdds.flatMap(_.partitioner).toSet.size == 1,
     "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))
+  require(rdds.map(_.partitioner.get.numPartitions).toSet.size == 1,
+    "Parent RDDs have different number of partitions: " +
+      rdds.map(_.partitioner.get.numPartitions))
 
   override val partitioner = rdds.head.partitioner
 
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 10610f4b6f1f..caae592284ee 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -38,7 +38,7 @@ import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, Utils}
+import org.apache.spark.util.{BoundedPriorityQueue, CheckpointingIterator, Utils}
 import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler, SamplingUtils}
 
@@ -238,11 +238,16 @@ abstract class RDD[T: ClassTag](
    * subclasses of RDD.
    */
   final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
-    if (storageLevel != StorageLevel.NONE) {
+    val iter = if (storageLevel != StorageLevel.NONE) {
       SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
     } else {
       computeOrReadCheckpoint(split, context)
     }
+    if (checkpointData.isDefined) {
+      checkpointData.get.getCheckpointIterator(iter, context, split.index)
+    } else {
+      iter
+    }
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index acbd31aacdf5..e9662913b2a6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path
 
 import org.apache.spark._
 import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask}
-import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.{CheckpointingIterator, SerializableConfiguration}
 
 /**
  * Enumeration to manage state transitions of an RDD through checkpointing
@@ -45,6 +45,22 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
 
   import CheckpointState._
 
+  // SparkContext is transient in RDD, so we can't get the id and checkpointDir from it later.
+  // Keep a copy of the id and checkpointDir here.
+  // The id of the RDD
+  val rddId: Int = rdd.id
+
+  // The path the checkpoint data will be written to.
+  val checkpointDir = rdd.context.checkpointDir.get
+  @transient val checkpointPath = new Path(checkpointDir, "rdd-" + rddId)
+  @transient val fs = checkpointPath.getFileSystem(rdd.context.hadoopConfiguration)
+  if (!fs.mkdirs(checkpointPath)) {
+    throw new SparkException("Failed to create checkpoint path " + checkpointPath)
+  }
+
+  val broadcastedConf = rdd.context.broadcast(
+    new SerializableConfiguration(rdd.context.hadoopConfiguration))
+
   // The checkpoint state of the associated RDD.
   var cpState = Initialized
 
@@ -71,6 +87,27 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
     RDDCheckpointData.synchronized { cpFile }
   }
 
+  // Get the iterator used to write checkpoint data to HDFS
+  def getCheckpointIterator(
+      rddIterator: Iterator[T],
+      context: TaskContext,
+      partitionId: Int): Iterator[T] = {
+    RDDCheckpointData.synchronized {
+      if (cpState == MarkedForCheckpoint) {
+        // Create the output path for the checkpoint
+        val path = new Path(checkpointDir, "rdd-" + rddId)
+        CheckpointingIterator[T, Iterator[T]](
+          rddIterator,
+          path.toString,
+          broadcastedConf,
+          partitionId,
+          context)
+      } else {
+        rddIterator
+      }
+    }
+  }
+
   // Do the checkpointing of the RDD. Called after the first job using that RDD is over.
   def doCheckpoint() {
     // If it is marked for checkpointing AND checkpointing is not already in progress,
@@ -83,23 +120,13 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
       }
     }
 
-    // Create the output path for the checkpoint
-    val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get
-    val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
-    if (!fs.mkdirs(path)) {
-      throw new SparkException("Failed to create checkpoint path " + path)
-    }
-
-    // Save to file, and reload it as an RDD
-    val broadcastedConf = rdd.context.broadcast(
-      new SerializableConfiguration(rdd.context.hadoopConfiguration))
+    val path = checkpointPath
     val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
     if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
       rdd.context.cleaner.foreach { cleaner =>
-        cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
+        cleaner.registerRDDCheckpointDataForCleanup(newRDD, rddId)
       }
     }
-    rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
     if (newRDD.partitions.length != rdd.partitions.length) {
       throw new SparkException(
         "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " +
@@ -113,7 +140,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
       rdd.markCheckpointed(newRDD)   // Update the RDD's dependencies and partitions
       cpState = Checkpointed
     }
-    logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
+    logInfo("Done checkpointing RDD " + rddId + " to " + path + ", new parent is RDD " + newRDD.id)
   }
 
   // Get preferred location of a split after checkpointing
diff --git a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala
new file mode 100644
index 000000000000..e7599ee2d261
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.IOException
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.CheckpointRDD
+import org.apache.spark.serializer.SerializationStream
+
+/**
+ * Wrapper around an iterator that writes checkpoint data to HDFS while running an action on
+ * an RDD, to support RDD checkpointing.
+ */
+private[spark] class CheckpointingIterator[A: ClassTag, +I <: Iterator[A]](
+    sub: I,
+    path: String,
+    broadcastedConf: Broadcast[SerializableConfiguration],
+    partitionId: Int,
+    context: TaskContext,
+    blockSize: Int = -1) extends Iterator[A] with Logging {
+
+  val env = SparkEnv.get
+  var fs: FileSystem = null
+  val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
+  var serializeStream: SerializationStream = null
+
+  var finalOutputPath: Path = null
+  var tempOutputPath: Path = null
+
+  def init(): this.type = {
+    val outputDir = new Path(path)
+    fs = outputDir.getFileSystem(broadcastedConf.value.value)
+
+    val finalOutputName = CheckpointRDD.splitIdToFile(partitionId)
+    finalOutputPath = new Path(outputDir, finalOutputName)
+    tempOutputPath =
+      new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptNumber)
+
+    if (fs.exists(tempOutputPath)) {
+      // More than one iterator of this RDD is being consumed.
+      // Don't checkpoint data in this iterator.
+      doCheckpoint = false
+      return this
+    }
+
+    val fileOutputStream = if (blockSize < 0) {
+      fs.create(tempOutputPath, false, bufferSize)
+    } else {
+      // This is mainly for testing purposes
+      fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
+    }
+    val serializer = env.serializer.newInstance()
+    serializeStream = serializer.serializeStream(fileOutputStream)
+    this
+  }
+
+  def completion(): Unit = {
+    if (!doCheckpoint) {
+      return
+    }
+
+    serializeStream.close()
+
+    if (!fs.rename(tempOutputPath, finalOutputPath)) {
+      if (!fs.exists(finalOutputPath)) {
+        logInfo("Deleting tempOutputPath " + tempOutputPath)
+        fs.delete(tempOutputPath, false)
+        throw new IOException("Checkpoint failed: failed to save output of task: "
+          + context.attemptNumber + " and final output path does not exist")
+      } else {
+        // Some other copy of this task must've finished before us and renamed it
+        logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
+        fs.delete(tempOutputPath, false)
+      }
+    }
+  }
+
+  def checkpointing(item: A): Unit = {
+    serializeStream.writeObject(item)
+  }
+
+  override def next(): A = {
+    val item = sub.next()
+    if (doCheckpoint) {
+      checkpointing(item)
+    }
+    item
+  }
+
+  private[this] var doCheckpoint = true
+  private[this] var completed = false
+
+  override def hasNext: Boolean = {
+    val r = sub.hasNext
+    if (!r && !completed) {
+      completed = true
+      completion()
+    }
+    r
+  }
+}
+
+private[spark] object CheckpointingIterator {
+  def apply[A: ClassTag, I <: Iterator[A]](
+      sub: I,
+      path: String,
+      broadcastedConf: Broadcast[SerializableConfiguration],
+      partitionId: Int,
+      context: TaskContext,
+      blockSize: Int = -1) : CheckpointingIterator[A, I] = {
+    new CheckpointingIterator[A, I](
+      sub,
+      path,
+      broadcastedConf,
+      partitionId,
+      context,
+      blockSize).init()
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index d1761a48babb..e27162c228e9 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -359,7 +359,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
    * have large size.
    */
   def generateFatPairRDD(): RDD[(Int, Int)] = {
-    new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x)
+    new FatPairRDD(sc.makeRDD(1 to 100, 2), partitioner).mapValues(x => x)
   }
 
   /**
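Note for readers of the series: the essence of patch 1 is a write-through iterator. RDD.iterator() now wraps the partition's iterator so that every element is also serialized to the checkpoint file while the first action consumes it, which is why doCheckpoint() can drop the runJob() call that previously recomputed the whole RDD just to write it out. Below is a minimal, self-contained sketch of that pattern; the names (WriteThroughIterator, write, close) are illustrative only, and a plain callback stands in for the HDFS stream and Spark serializer used by the real CheckpointingIterator.

    import scala.collection.mutable.ArrayBuffer

    // A write-through iterator: each element handed to the consumer is also passed to a
    // side-effecting `write` callback, and `close` runs once the underlying data is drained.
    class WriteThroughIterator[A](sub: Iterator[A], write: A => Unit, close: () => Unit)
      extends Iterator[A] {

      private var closed = false

      override def hasNext: Boolean = {
        val more = sub.hasNext
        if (!more && !closed) {
          closed = true   // finalize the side output exactly once, when the data is exhausted
          close()
        }
        more
      }

      override def next(): A = {
        val item = sub.next()
        write(item)       // copy the element to the side output as it streams past
        item
      }
    }

    object WriteThroughIteratorDemo extends App {
      val copied = ArrayBuffer.empty[Int]
      var finished = false
      val wrapped = new WriteThroughIterator[Int](Iterator(1, 2, 3), copied += _, () => finished = true)
      println(wrapped.sum)  // 6: the action's result, computed in a single pass
      println(copied)       // ArrayBuffer(1, 2, 3): the "checkpoint" written during that same pass
      println(finished)     // true: the side output was finalized when the iterator was drained
    }

The same single-pass idea is what CheckpointingIterator.hasNext relies on: once the underlying iterator is exhausted, completion() closes the temporary checkpoint file and renames it into place.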
From 1a3055ea6ca67fcb23c1188b7c3344c726b054b3 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 26 Jun 2015 01:10:52 +0800
Subject: [PATCH 2/4] Fix scala style.

---
 .../org/apache/spark/util/CheckpointingIterator.scala | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala
index e7599ee2d261..b80de6708942 100644
--- a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala
+++ b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala
@@ -40,16 +40,16 @@ private[spark] class CheckpointingIterator[A: ClassTag, +I <: Iterator[A]](
     partitionId: Int,
     context: TaskContext,
     blockSize: Int = -1) extends Iterator[A] with Logging {
-  
+
   val env = SparkEnv.get
-  var fs: FileSystem = null 
+  var fs: FileSystem = null
   val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
   var serializeStream: SerializationStream = null
-  
+
   var finalOutputPath: Path = null
   var tempOutputPath: Path = null
 
-  def init(): this.type = { 
+  def init(): this.type = {
     val outputDir = new Path(path)
     fs = outputDir.getFileSystem(broadcastedConf.value.value)
 
@@ -100,7 +100,7 @@ private[spark] class CheckpointingIterator[A: ClassTag, +I <: Iterator[A]](
   def checkpointing(item: A): Unit = {
     serializeStream.writeObject(item)
   }
-  
+
   override def next(): A = {
     val item = sub.next()
     if (doCheckpoint) {

From 3c5b203fd2b85f4110795a1fc6ca3e289ca0d837 Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Fri, 26 Jun 2015 17:49:01 +0800
Subject: [PATCH 3/4] Write checkpoint data to disk when the iterator reaches its end.

---
 .../apache/spark/rdd/RDDCheckpointData.scala   | 17 +++++++++++------
 .../spark/util/CheckpointingIterator.scala     |  2 ++
 2 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index e9662913b2a6..f75f3736e1f4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -19,6 +19,7 @@ package org.apache.spark.rdd
 
 import scala.reflect.ClassTag
 
+import org.apache.hadoop.fs.FileSystem
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark._
@@ -51,11 +52,15 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
   val rddId: Int = rdd.id
 
   // The path the checkpoint data will be written to.
-  val checkpointDir = rdd.context.checkpointDir.get
-  @transient val checkpointPath = new Path(checkpointDir, "rdd-" + rddId)
-  @transient val fs = checkpointPath.getFileSystem(rdd.context.hadoopConfiguration)
-  if (!fs.mkdirs(checkpointPath)) {
-    throw new SparkException("Failed to create checkpoint path " + checkpointPath)
+  val checkpointDir = rdd.context.checkpointDir
+  @transient var checkpointPath: Path = null
+  @transient var fs: FileSystem = null
+  if (checkpointDir.isDefined) {
+    checkpointPath = new Path(checkpointDir.get, "rdd-" + rddId)
+    fs = checkpointPath.getFileSystem(rdd.context.hadoopConfiguration)
+    if (!fs.mkdirs(checkpointPath)) {
+      throw new SparkException("Failed to create checkpoint path " + checkpointPath)
+    }
   }
 
   val broadcastedConf = rdd.context.broadcast(
@@ -95,7 +100,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
     RDDCheckpointData.synchronized {
       if (cpState == MarkedForCheckpoint) {
         // Create the output path for the checkpoint
-        val path = new Path(checkpointDir, "rdd-" + rddId)
+        val path = new Path(checkpointDir.get, "rdd-" + rddId)
         CheckpointingIterator[T, Iterator[T]](
           rddIterator,
           path.toString,
diff --git a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala
index b80de6708942..61f699b3c127 100644
--- a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala
+++ b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala
@@ -106,6 +106,8 @@ private[spark] class CheckpointingIterator[A: ClassTag, +I <: Iterator[A]](
     if (doCheckpoint) {
       checkpointing(item)
     }
+    // If this is the last item, calling hasNext will write to the final output early.
+    hasNext
     item
   }
 

From 2f43ff3c6d1a4a428e5cbe8f4a4e4347274fc95c Mon Sep 17 00:00:00 2001
From: Liang-Chi Hsieh
Date: Sat, 4 Jul 2015 00:23:16 +0800
Subject: [PATCH 4/4] Fix scala style.

---
 .../org/apache/spark/util/CheckpointingIteratorSuite.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/core/src/test/scala/org/apache/spark/util/CheckpointingIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/CheckpointingIteratorSuite.scala
index 96a0e9c682fa..9df816cda0ee 100644
--- a/core/src/test/scala/org/apache/spark/util/CheckpointingIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/CheckpointingIteratorSuite.scala
@@ -51,7 +51,7 @@ class CheckpointingIteratorSuite extends SparkFunSuite with LocalSparkContext wi
     val context = new TaskContextImpl(0, 0, 0, 0, null)
     val iter = List(1, 2, 3).iterator
     val checkpoingIter = CheckpointingIterator[Int](
-      iter, 
+      iter,
       checkpointDir.toString,
       broadcastedConf,
       0,
@@ -85,7 +85,7 @@ class CheckpointingIteratorSuite extends SparkFunSuite with LocalSparkContext wi
     val context = new TaskContextImpl(0, 0, 0, 0, null)
     val iter = List(1, 2, 3).iterator
     val checkpoingIter = CheckpointingIterator[Int](
-      iter, 
+      iter,
       checkpointDir.toString,
       broadcastedConf,
      0,
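Taken together, the series changes RDD checkpointing from "run the action, then run a second job that recomputes every partition and writes it to HDFS" into "stream each partition to the checkpoint directory while the first action computes it". A rough sketch of how the feature is exercised from user code, assuming a local master and a scratch checkpoint directory (both illustrative, not part of the patch):

    import org.apache.spark.{SparkConf, SparkContext}

    object CheckpointUsageSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setMaster("local[2]").setAppName("checkpoint-sketch")
        val sc = new SparkContext(conf)
        sc.setCheckpointDir("/tmp/checkpoint-sketch")   // illustrative path

        val rdd = sc.parallelize(1 to 1000, 4).map(_ * 2)
        rdd.checkpoint()   // mark the RDD; with this series, data is written during the next action

        // This single count() both computes the result and streams each partition to the
        // checkpoint directory, so doCheckpoint() no longer launches a second job that
        // recomputes the RDD just to write it out.
        println(rdd.count())

        sc.stop()
      }
    }

Patch 3 additionally forces a hasNext call at the end of next(), so a consumer that stops exactly at the last element without asking for more still gets its partition's checkpoint file closed and renamed into place.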