diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
index 9e3880714a79..a6ab989bbdd8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
@@ -63,6 +63,9 @@ class PartitionerAwareUnionRDD[T: ClassTag](
   require(rdds.forall(_.partitioner.isDefined))
   require(rdds.flatMap(_.partitioner).toSet.size == 1,
     "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))
+  require(rdds.map(_.partitioner.get.numPartitions).toSet.size == 1,
+    "Parent RDDs have different number of partitions: " +
+      rdds.map(_.partitioner.get.numPartitions))
 
   override val partitioner = rdds.head.partitioner
 
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 9f7ebae3e9af..7d85071dae49 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -38,7 +38,7 @@ import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, Utils}
+import org.apache.spark.util.{BoundedPriorityQueue, CheckpointingIterator, Utils}
 import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler, SamplingUtils}
 
@@ -238,11 +238,16 @@ abstract class RDD[T: ClassTag](
    * subclasses of RDD.
    */
   final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
-    if (storageLevel != StorageLevel.NONE) {
+    val iter = if (storageLevel != StorageLevel.NONE) {
       SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
     } else {
       computeOrReadCheckpoint(split, context)
     }
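+    // If checkpoint data is defined for this RDD, wrap the iterator so that the partition's
+    // elements are also written to the checkpoint file as they are consumed.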
+    if (checkpointData.isDefined) {
+      checkpointData.get.getCheckpointIterator(iter, context, split.index)
+    } else {
+      iter
+    }
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index 4f954363bed8..79ecaf65b6e5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -19,9 +19,11 @@ package org.apache.spark.rdd
 
 import scala.reflect.ClassTag
 
+import org.apache.hadoop.fs.FileSystem
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark._
+import org.apache.spark.util.{CheckpointingIterator, SerializableConfiguration}
 import org.apache.spark.util.SerializableConfiguration
 
 /**
@@ -44,6 +46,26 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
 
   import CheckpointState._
 
+  // SparkContext is transient in RDD, so the id and checkpointDir can't be read from it later.
+  // Keep copies of them here.
+  // The id of the RDD.
+  val rddId: Int = rdd.id
+
+  // The path the checkpoint data will be written to.
+  val checkpointDir = rdd.context.checkpointDir
+  @transient var checkpointPath: Path = null
+  @transient var fs: FileSystem = null
+  if (checkpointDir.isDefined) {
+    checkpointPath = new Path(checkpointDir.get, "rdd-" + rddId)
+    fs = checkpointPath.getFileSystem(rdd.context.hadoopConfiguration)
+    if (!fs.mkdirs(checkpointPath)) {
+      throw new SparkException("Failed to create checkpoint path " + checkpointPath)
+    }
+  }
+
+  val broadcastedConf = rdd.context.broadcast(
+    new SerializableConfiguration(rdd.context.hadoopConfiguration))
+
   // The checkpoint state of the associated RDD.
   private var cpState = Initialized
 
@@ -66,6 +88,27 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
     cpFile
   }
 
+  // Get the iterator used to write checkpoint data to HDFS while the RDD is being computed.
+  def getCheckpointIterator(
+      rddIterator: Iterator[T],
+      context: TaskContext,
+      partitionId: Int): Iterator[T] = {
+    RDDCheckpointData.synchronized {
+      if (cpState == Initialized) {
+        // Create the output path for the checkpoint
+        val path = new Path(checkpointDir.get, "rdd-" + rddId)
+        CheckpointingIterator[T](
+          rddIterator,
+          path.toString,
+          broadcastedConf,
+          partitionId,
+          context)
+      } else {
+        rddIterator
+      }
+    }
+  }
+
   /**
    * Materialize this RDD and write its content to a reliable DFS.
    * This is called immediately after the first action invoked on this RDD has completed.
@@ -82,25 +125,13 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
       }
     }
 
-    // Create the output path for the checkpoint
-    val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get
-    val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
-    if (!fs.mkdirs(path)) {
-      throw new SparkException(s"Failed to create checkpoint path $path")
-    }
-
-    // Save to file, and reload it as an RDD
-    val broadcastedConf = rdd.context.broadcast(
-      new SerializableConfiguration(rdd.context.hadoopConfiguration))
+    val path = checkpointPath
     val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
     if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
       rdd.context.cleaner.foreach { cleaner =>
-        cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
+        cleaner.registerRDDCheckpointDataForCleanup(newRDD, rddId)
       }
     }
-
-    // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
-    rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
     if (newRDD.partitions.length != rdd.partitions.length) {
       throw new SparkException(
         "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " +
@@ -114,7 +145,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
       rdd.markCheckpointed(newRDD)   // Update the RDD's dependencies and partitions
       cpState = Checkpointed
     }
-    logInfo(s"Done checkpointing RDD ${rdd.id} to $path, new parent is RDD ${newRDD.id}")
+    logInfo(s"Done checkpointing RDD ${rddId} to $path, new parent is RDD ${newRDD.id}")
   }
 
   def getPartitions: Array[Partition] = RDDCheckpointData.synchronized {
diff --git a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala
new file mode 100644
index 000000000000..5332c6f6c106
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.IOException
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.CheckpointRDD
+import org.apache.spark.serializer.SerializationStream
+
+/**
+ * Wrapper around an iterator that writes checkpoint data to HDFS while an action is running
+ * on the RDD, so that the RDD can be checkpointed without being computed a second time.
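+ *
+ * Elements are serialized to a temporary file as they are consumed. Once the underlying
+ * iterator is exhausted, the temporary file is renamed to the final checkpoint file for this
+ * partition; if the iterator is never fully consumed, no checkpoint file is produced.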
+ */
+private[spark] class CheckpointingIterator[A: ClassTag](
+    values: Iterator[A],
+    path: String,
+    broadcastedConf: Broadcast[SerializableConfiguration],
+    partitionId: Int,
+    context: TaskContext,
+    blockSize: Int = -1) extends Iterator[A] with Logging {
+
+  private val env = SparkEnv.get
+  private var fs: FileSystem = null
+  private val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
+  private var serializeStream: SerializationStream = null
+
+  private var finalOutputPath: Path = null
+  private var tempOutputPath: Path = null
+
+  /**
+   * Initialize this iterator by creating the temporary output path and the serializer instance.
+   */
+  def init(): this.type = {
+    val outputDir = new Path(path)
+    fs = outputDir.getFileSystem(broadcastedConf.value.value)
+
+    val finalOutputName = CheckpointRDD.splitIdToFile(partitionId)
+    finalOutputPath = new Path(outputDir, finalOutputName)
+    tempOutputPath =
+      new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptNumber)
+
+    if (fs.exists(tempOutputPath)) {
+      // More than one iterator over this partition of the RDD is being consumed.
+      // Don't checkpoint data in this iterator.
+      doCheckpoint = false
+      return this
+    }
+
+    val fileOutputStream = if (blockSize < 0) {
+      fs.create(tempOutputPath, false, bufferSize)
+    } else {
+      // This is mainly for testing purposes
+      fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
+    }
+    val serializer = env.serializer.newInstance()
+    serializeStream = serializer.serializeStream(fileOutputStream)
+    this
+  }
+
+  /**
+   * Called by `hasNext` once this iterator has been consumed to its last element.
+   * Renames the temporary output path to the final output path of the checkpoint data.
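+   * If the rename fails but the final output path already exists, another attempt of this task
+   * has already written the checkpoint data, so the temporary file is simply deleted.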
+   */
+  def completion(): Unit = {
+    if (!doCheckpoint) {
+      return
+    }
+
+    serializeStream.close()
+
+    if (!fs.rename(tempOutputPath, finalOutputPath)) {
+      if (!fs.exists(finalOutputPath)) {
+        logInfo("Deleting tempOutputPath " + tempOutputPath)
+        fs.delete(tempOutputPath, false)
+        throw new IOException("Checkpoint failed: failed to save output of task: " +
+          context.attemptNumber + " and final output path does not exist")
+      } else {
+        // Some other copy of this task must have finished before us and renamed it
+        logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
+        fs.delete(tempOutputPath, false)
+      }
+    }
+  }
+
+  def checkpointing(item: A): Unit = {
+    serializeStream.writeObject(item)
+  }
+
+  override def next(): A = {
+    val item = values.next()
+    if (doCheckpoint) {
+      checkpointing(item)
+    }
+    // If this was the last item, calling hasNext here finalizes the checkpoint output early.
+    hasNext
+    item
+  }
+
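+  // Whether this iterator should write checkpoint data. init() sets this to false when the
+  // temporary output file already exists, i.e. another iterator over this partition is already
+  // writing the checkpoint data. `completed` ensures completion() runs at most once.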
+  private[this] var doCheckpoint = true
+  private[this] var completed = false
+
+  override def hasNext: Boolean = {
+    val r = values.hasNext
+    if (!r && !completed) {
+      completed = true
+      completion()
+    }
+    r
+  }
+}
+
+private[spark] object CheckpointingIterator {
+  def apply[A: ClassTag](
+      values: Iterator[A],
+      path: String,
+      broadcastedConf: Broadcast[SerializableConfiguration],
+      partitionId: Int,
+      context: TaskContext,
+      blockSize: Int = -1): CheckpointingIterator[A] = {
+    new CheckpointingIterator[A](
+      values,
+      path,
+      broadcastedConf,
+      partitionId,
+      context,
+      blockSize).init()
+  }
+}
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index cc50e6d79a3e..a0eb2e45ec2a 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -359,7 +359,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
    * have large size.
    */
   def generateFatPairRDD(): RDD[(Int, Int)] = {
-    new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x)
+    new FatPairRDD(sc.makeRDD(1 to 100, 2), partitioner).mapValues(x => x)
   }
 
   /**
diff --git a/core/src/test/scala/org/apache/spark/util/CheckpointingIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/CheckpointingIteratorSuite.scala
new file mode 100644
index 000000000000..9df816cda0ee
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/CheckpointingIteratorSuite.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.File
+
+import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark._
+import org.apache.spark.rdd.CheckpointRDD
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkFunSuite
+
+class CheckpointingIteratorSuite extends SparkFunSuite with LocalSparkContext with Logging {
+  var checkpointDir: File = _
+  val partitioner = new HashPartitioner(2)
+
+  override def beforeEach() {
+    super.beforeEach()
+    checkpointDir = File.createTempFile("temp", "", Utils.createTempDir())
+    checkpointDir.delete()
+    sc = new SparkContext("local", "test")
+    sc.setCheckpointDir(checkpointDir.toString)
+  }
+
+  override def afterEach() {
+    super.afterEach()
+    Utils.deleteRecursively(checkpointDir)
+  }
+
+  test("basic test") {
+    val broadcastedConf = sc.broadcast(
+      new SerializableConfiguration(sc.hadoopConfiguration))
+
+    val context = new TaskContextImpl(0, 0, 0, 0, null)
+    val iter = List(1, 2, 3).iterator
+    val checkpointIter = CheckpointingIterator[Int](
+      iter,
+      checkpointDir.toString,
+      broadcastedConf,
+      0,
+      context)
+
+    assert(checkpointIter.hasNext)
+    assert(checkpointIter.next() === 1)
+
+    assert(checkpointIter.hasNext)
+    assert(checkpointIter.next() === 2)
+
+    assert(checkpointIter.hasNext)
+    assert(checkpointIter.next() === 3)
+
+    // Checkpoint data should be written out now.
+    val outputDir = new Path(checkpointDir.toString)
+    val finalOutputName = CheckpointRDD.splitIdToFile(0)
+    val finalPath = new Path(outputDir, finalOutputName)
+
+    val fs = outputDir.getFileSystem(sc.hadoopConfiguration)
+    assert(fs.exists(finalPath))
+
+    assert(!checkpointIter.hasNext)
+    assert(!checkpointIter.hasNext)
+  }
+
+  test("no checkpoint data output if iterator is not consumed to the last element") {
+    val broadcastedConf = sc.broadcast(
+      new SerializableConfiguration(sc.hadoopConfiguration))
+
+    val context = new TaskContextImpl(0, 0, 0, 0, null)
+    val iter = List(1, 2, 3).iterator
+    val checkpointIter = CheckpointingIterator[Int](
+      iter,
+      checkpointDir.toString,
+      broadcastedConf,
+      0,
+      context)
+
+    assert(checkpointIter.hasNext)
+    assert(checkpointIter.next() === 1)
+
+    assert(checkpointIter.hasNext)
+    assert(checkpointIter.next() === 2)
+
+    // Checkpoint data should not be written out yet.
+    val outputDir = new Path(checkpointDir.toString)
+    val finalOutputName = CheckpointRDD.splitIdToFile(0)
+    val finalPath = new Path(outputDir, finalOutputName)
+
+    val fs = outputDir.getFileSystem(sc.hadoopConfiguration)
+    assert(!fs.exists(finalPath))
+  }
+}