From d8635168a9f01e3be2b53a27cc5918a1a0ed1612 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 26 Jun 2015 00:50:45 +0800 Subject: [PATCH 1/8] Add CheckpointingIterator to optimize checkpointing. --- .../spark/rdd/PartitionerAwareUnionRDD.scala | 3 + .../main/scala/org/apache/spark/rdd/RDD.scala | 9 +- .../apache/spark/rdd/RDDCheckpointData.scala | 55 +++++-- .../spark/util/CheckpointingIterator.scala | 141 ++++++++++++++++++ .../org/apache/spark/CheckpointSuite.scala | 2 +- 5 files changed, 193 insertions(+), 17 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 9e3880714a79..a6ab989bbdd8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -63,6 +63,9 @@ class PartitionerAwareUnionRDD[T: ClassTag]( require(rdds.forall(_.partitioner.isDefined)) require(rdds.flatMap(_.partitioner).toSet.size == 1, "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner)) + require(rdds.map(_.partitioner.get.numPartitions).toSet.size == 1, + "Parent RDDs have different number of partitions: " + + rdds.map(_.partitioner.get.numPartitions)) override val partitioner = rdds.head.partitioner diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 10610f4b6f1f..caae592284ee 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -38,7 +38,7 @@ import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{BoundedPriorityQueue, Utils} +import org.apache.spark.util.{BoundedPriorityQueue, CheckpointingIterator, Utils} import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler, SamplingUtils} @@ -238,11 +238,16 @@ abstract class RDD[T: ClassTag]( * subclasses of RDD. 
*/ final def iterator(split: Partition, context: TaskContext): Iterator[T] = { - if (storageLevel != StorageLevel.NONE) { + val iter = if (storageLevel != StorageLevel.NONE) { SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel) } else { computeOrReadCheckpoint(split, context) } + if (checkpointData.isDefined) { + checkpointData.get.getCheckpointIterator(iter, context, split.index) + } else { + iter + } } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index acbd31aacdf5..e9662913b2a6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask} -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.{CheckpointingIterator, SerializableConfiguration} /** * Enumeration to manage state transitions of an RDD through checkpointing @@ -45,6 +45,22 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) import CheckpointState._ + // Because SparkContext is transient in RDD, so we can't get the id and checkpointDir later. + // So keep a copy of the id and checkpointDir. + // The id of RDD + val rddId: Int = rdd.id + + // The path the checkpoint data will write to. + val checkpointDir = rdd.context.checkpointDir.get + @transient val checkpointPath = new Path(checkpointDir, "rdd-" + rddId) + @transient val fs = checkpointPath.getFileSystem(rdd.context.hadoopConfiguration) + if (!fs.mkdirs(checkpointPath)) { + throw new SparkException("Failed to create checkpoint path " + checkpointPath) + } + + val broadcastedConf = rdd.context.broadcast( + new SerializableConfiguration(rdd.context.hadoopConfiguration)) + // The checkpoint state of the associated RDD. var cpState = Initialized @@ -71,6 +87,27 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) RDDCheckpointData.synchronized { cpFile } } + // Get the iterator used to write checkpoint data to HDFS + def getCheckpointIterator( + rddIterator: Iterator[T], + context: TaskContext, + partitionId: Int): Iterator[T] = { + RDDCheckpointData.synchronized { + if (cpState == MarkedForCheckpoint) { + // Create the output path for the checkpoint + val path = new Path(checkpointDir, "rdd-" + rddId) + CheckpointingIterator[T, Iterator[T]]( + rddIterator, + path.toString, + broadcastedConf, + partitionId, + context) + } else { + rddIterator + } + } + } + // Do the checkpointing of the RDD. Called after the first job using that RDD is over. 
def doCheckpoint() { // If it is marked for checkpointing AND checkpointing is not already in progress, @@ -83,23 +120,13 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) } } - // Create the output path for the checkpoint - val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get - val fs = path.getFileSystem(rdd.context.hadoopConfiguration) - if (!fs.mkdirs(path)) { - throw new SparkException("Failed to create checkpoint path " + path) - } - - // Save to file, and reload it as an RDD - val broadcastedConf = rdd.context.broadcast( - new SerializableConfiguration(rdd.context.hadoopConfiguration)) + val path = checkpointPath val newRDD = new CheckpointRDD[T](rdd.context, path.toString) if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) { rdd.context.cleaner.foreach { cleaner => - cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id) + cleaner.registerRDDCheckpointDataForCleanup(newRDD, rddId) } } - rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _) if (newRDD.partitions.length != rdd.partitions.length) { throw new SparkException( "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.length + ") has different " + @@ -113,7 +140,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions cpState = Checkpointed } - logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id) + logInfo("Done checkpointing RDD " + rddId + " to " + path + ", new parent is RDD " + newRDD.id) } // Get preferred location of a split after checkpointing diff --git a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala new file mode 100644 index 000000000000..e7599ee2d261 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala @@ -0,0 +1,141 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.io.IOException + +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path + +import org.apache.spark._ +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.CheckpointRDD +import org.apache.spark.serializer.SerializationStream + +/** + * Wrapper around an iterator which writes checkpoint data to HDFS while running action on + * a RDD to support checkpointing RDD. 
+ */ +private[spark] class CheckpointingIterator[A: ClassTag, +I <: Iterator[A]]( + sub: I, + path: String, + broadcastedConf: Broadcast[SerializableConfiguration], + partitionId: Int, + context: TaskContext, + blockSize: Int = -1) extends Iterator[A] with Logging { + + val env = SparkEnv.get + var fs: FileSystem = null + val bufferSize = env.conf.getInt("spark.buffer.size", 65536) + var serializeStream: SerializationStream = null + + var finalOutputPath: Path = null + var tempOutputPath: Path = null + + def init(): this.type = { + val outputDir = new Path(path) + fs = outputDir.getFileSystem(broadcastedConf.value.value) + + val finalOutputName = CheckpointRDD.splitIdToFile(partitionId) + finalOutputPath = new Path(outputDir, finalOutputName) + tempOutputPath = + new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptNumber) + + if (fs.exists(tempOutputPath)) { + // There are more than one iterator of the RDD is consumed. + // Don't checkpoint data in this iterator. + doCheckpoint = false + return this + } + + val fileOutputStream = if (blockSize < 0) { + fs.create(tempOutputPath, false, bufferSize) + } else { + // This is mainly for testing purpose + fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) + } + val serializer = env.serializer.newInstance() + serializeStream = serializer.serializeStream(fileOutputStream) + this + } + + def completion(): Unit = { + if (!doCheckpoint) { + return + } + + serializeStream.close() + + if (!fs.rename(tempOutputPath, finalOutputPath)) { + if (!fs.exists(finalOutputPath)) { + logInfo("Deleting tempOutputPath " + tempOutputPath) + fs.delete(tempOutputPath, false) + throw new IOException("Checkpoint failed: failed to save output of task: " + + context.attemptNumber + " and final output path does not exist") + } else { + // Some other copy of this task must've finished before us and renamed it + logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it") + fs.delete(tempOutputPath, false) + } + } + } + + def checkpointing(item: A): Unit = { + serializeStream.writeObject(item) + } + + override def next(): A = { + val item = sub.next() + if (doCheckpoint) { + checkpointing(item) + } + item + } + + private[this] var doCheckpoint = true + private[this] var completed = false + + override def hasNext: Boolean = { + val r = sub.hasNext + if (!r && !completed) { + completed = true + completion() + } + r + } +} + +private[spark] object CheckpointingIterator { + def apply[A: ClassTag, I <: Iterator[A]]( + sub: I, + path: String, + broadcastedConf: Broadcast[SerializableConfiguration], + partitionId: Int, + context: TaskContext, + blockSize: Int = -1) : CheckpointingIterator[A, I] = { + new CheckpointingIterator[A, I]( + sub, + path, + broadcastedConf, + partitionId, + context, + blockSize).init() + } +} diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index d1761a48babb..e27162c228e9 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -359,7 +359,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging * have large size. 
*/ def generateFatPairRDD(): RDD[(Int, Int)] = { - new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x) + new FatPairRDD(sc.makeRDD(1 to 100, 2), partitioner).mapValues(x => x) } /** From 1a3055ea6ca67fcb23c1188b7c3344c726b054b3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 26 Jun 2015 01:10:52 +0800 Subject: [PATCH 2/8] Fix scala style. --- .../org/apache/spark/util/CheckpointingIterator.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala index e7599ee2d261..b80de6708942 100644 --- a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala @@ -40,16 +40,16 @@ private[spark] class CheckpointingIterator[A: ClassTag, +I <: Iterator[A]]( partitionId: Int, context: TaskContext, blockSize: Int = -1) extends Iterator[A] with Logging { - + val env = SparkEnv.get - var fs: FileSystem = null + var fs: FileSystem = null val bufferSize = env.conf.getInt("spark.buffer.size", 65536) var serializeStream: SerializationStream = null - + var finalOutputPath: Path = null var tempOutputPath: Path = null - def init(): this.type = { + def init(): this.type = { val outputDir = new Path(path) fs = outputDir.getFileSystem(broadcastedConf.value.value) @@ -100,7 +100,7 @@ private[spark] class CheckpointingIterator[A: ClassTag, +I <: Iterator[A]]( def checkpointing(item: A): Unit = { serializeStream.writeObject(item) } - + override def next(): A = { val item = sub.next() if (doCheckpoint) { From 3c5b203fd2b85f4110795a1fc6ca3e289ca0d837 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 26 Jun 2015 17:49:01 +0800 Subject: [PATCH 3/8] Write checkpoint data to disk if it is at the end of iterator. --- .../apache/spark/rdd/RDDCheckpointData.scala | 17 +++++++++++------ .../spark/util/CheckpointingIterator.scala | 2 ++ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index e9662913b2a6..f75f3736e1f4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -19,6 +19,7 @@ package org.apache.spark.rdd import scala.reflect.ClassTag +import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path import org.apache.spark._ @@ -51,11 +52,15 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) val rddId: Int = rdd.id // The path the checkpoint data will write to. 
- val checkpointDir = rdd.context.checkpointDir.get - @transient val checkpointPath = new Path(checkpointDir, "rdd-" + rddId) - @transient val fs = checkpointPath.getFileSystem(rdd.context.hadoopConfiguration) - if (!fs.mkdirs(checkpointPath)) { - throw new SparkException("Failed to create checkpoint path " + checkpointPath) + val checkpointDir = rdd.context.checkpointDir + @transient var checkpointPath: Path = null + @transient var fs: FileSystem = null + if (checkpointDir.isDefined) { + checkpointPath = new Path(checkpointDir.get, "rdd-" + rddId) + fs = checkpointPath.getFileSystem(rdd.context.hadoopConfiguration) + if (!fs.mkdirs(checkpointPath)) { + throw new SparkException("Failed to create checkpoint path " + checkpointPath) + } } val broadcastedConf = rdd.context.broadcast( @@ -95,7 +100,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T]) RDDCheckpointData.synchronized { if (cpState == MarkedForCheckpoint) { // Create the output path for the checkpoint - val path = new Path(checkpointDir, "rdd-" + rddId) + val path = new Path(checkpointDir.get, "rdd-" + rddId) CheckpointingIterator[T, Iterator[T]]( rddIterator, path.toString, diff --git a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala index b80de6708942..61f699b3c127 100644 --- a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala @@ -106,6 +106,8 @@ private[spark] class CheckpointingIterator[A: ClassTag, +I <: Iterator[A]]( if (doCheckpoint) { checkpointing(item) } + // If this the latest item, call hasNext will write to final output early. + hasNext item } From 2f43ff3c6d1a4a428e5cbe8f4a4e4347274fc95c Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 4 Jul 2015 00:23:16 +0800 Subject: [PATCH 4/8] Fix scala style. 
--- .../org/apache/spark/util/CheckpointingIteratorSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/util/CheckpointingIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/CheckpointingIteratorSuite.scala index 96a0e9c682fa..9df816cda0ee 100644 --- a/core/src/test/scala/org/apache/spark/util/CheckpointingIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/CheckpointingIteratorSuite.scala @@ -51,7 +51,7 @@ class CheckpointingIteratorSuite extends SparkFunSuite with LocalSparkContext wi val context = new TaskContextImpl(0, 0, 0, 0, null) val iter = List(1, 2, 3).iterator val checkpoingIter = CheckpointingIterator[Int]( - iter, + iter, checkpointDir.toString, broadcastedConf, 0, @@ -85,7 +85,7 @@ class CheckpointingIteratorSuite extends SparkFunSuite with LocalSparkContext wi val context = new TaskContextImpl(0, 0, 0, 0, null) val iter = List(1, 2, 3).iterator val checkpoingIter = CheckpointingIterator[Int]( - iter, + iter, checkpointDir.toString, broadcastedConf, 0, From 647162fd67d745108ee4816c13d38b71bc71cd59 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 2 Nov 2015 17:33:22 -0800 Subject: [PATCH 5/8] Fix the corner cases in CheckpointingIterator --- .../main/scala/org/apache/spark/rdd/RDD.scala | 10 +- .../spark/rdd/ReliableCheckpointRDD.scala | 2 +- .../spark/rdd/ReliableRDDCheckpointData.scala | 62 ++++- .../spark/util/CheckpointingIterator.scala | 218 ++++++++++++------ .../org/apache/spark/CheckpointSuite.scala | 32 +++ .../util/CheckpointingIteratorSuite.scala | 108 --------- 6 files changed, 236 insertions(+), 196 deletions(-) delete mode 100644 core/src/test/scala/org/apache/spark/util/CheckpointingIteratorSuite.scala diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 5f12a4c4c543..2e6b157a74aa 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -38,7 +38,7 @@ import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{BoundedPriorityQueue, CheckpointingIterator, Utils} +import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler, SamplingUtils} @@ -263,11 +263,9 @@ abstract class RDD[T: ClassTag]( } else { computeOrReadCheckpoint(split, context) } - if (checkpointData.isDefined) { - checkpointData.get.getCheckpointIterator(iter, context, split.index) - } else { - iter - } + checkpointData.collect { case data: ReliableRDDCheckpointData[T] => + data.getCheckpointIterator(this, iter, context, split.index) + }.getOrElse(iter) } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index a69be6a068bb..bef80ebef195 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -96,7 +96,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { /** * Return the checkpoint file name for the given partition. 
*/ - private def checkpointFileName(partitionIndex: Int): String = { + def checkpointFileName(partitionIndex: Int): String = { "part-%05d".format(partitionIndex) } diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala index 91cad6662e4d..52e83b131d02 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.fs.Path import org.apache.spark._ -import org.apache.spark.util.SerializableConfiguration +import org.apache.spark.util.{CheckpointingIterator, SerializableConfiguration} /** * An implementation of checkpointing that writes the RDD data to reliable storage. @@ -31,12 +31,26 @@ import org.apache.spark.util.SerializableConfiguration private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T]) extends RDDCheckpointData[T](rdd) with Logging { + import CheckpointState._ + // The directory to which the associated RDD has been checkpointed to // This is assumed to be a non-local path that points to some reliable storage - private val cpDir: String = - ReliableRDDCheckpointData.checkpointPath(rdd.context, rdd.id) + private val cpDir: String = { + val _cpDir = ReliableRDDCheckpointData.checkpointPath(rdd.context, rdd.id) .map(_.toString) - .getOrElse { throw new SparkException("Checkpoint dir must be specified.") } + .getOrElse { + throw new SparkException("Checkpoint dir must be specified.") + } + val path = new Path(_cpDir) + val fs = new Path(_cpDir).getFileSystem(rdd.context.hadoopConfiguration) + if (!fs.mkdirs(path)) { + throw new SparkException("Failed to create checkpoint path " + path) + } + _cpDir + } + + private val broadcastedConf = rdd.context.broadcast( + new SerializableConfiguration(rdd.context.hadoopConfiguration)) /** * Return the directory to which this RDD was checkpointed. @@ -50,6 +64,27 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v } } + /** + * Return an Iterator that will checkpoint the data when consuming the original Iterator if + * this RDD has not yet been checkpointed. Otherwise, just return the original Iterator. + * + * Note: this is called in executor. + */ + def getCheckpointIterator( + rdd: RDD[T], values: Iterator[T], context: TaskContext, partitionIndex: Int): Iterator[T] = { + if (cpState == Initialized) { + CheckpointingIterator[T]( + rdd, + values, + cpDir, + broadcastedConf, + partitionIndex, + context) + } else { + values + } + } + /** * Materialize this RDD and write its content to a reliable DFS. * This is called immediately after the first action invoked on this RDD has completed. @@ -63,11 +98,20 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v throw new SparkException(s"Failed to create checkpoint path $cpDir") } - // Save to file, and reload it as an RDD - val broadcastedConf = rdd.context.broadcast( - new SerializableConfiguration(rdd.context.hadoopConfiguration)) - // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) - rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _) + val checkpointedPartitionFiles = fs.listStatus(path).map(_.getPath.getName).toSet + // Not all actions compute all partitions of the RDD (e.g. take). 
For correctness, we + // must checkpoint any missing partitions. TODO: avoid running another job here (SPARK-8582). + val missingPartitionIndices = rdd.partitions.map(_.index).filter { i => + !checkpointedPartitionFiles(ReliableCheckpointRDD.checkpointFileName(i)) + } + if (missingPartitionIndices.nonEmpty) { + // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) + rdd.context.runJob( + rdd, + ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _, + missingPartitionIndices) + } + val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir) if (newRDD.partitions.length != rdd.partitions.length) { throw new SparkException( diff --git a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala index 5332c6f6c106..591479c00e60 100644 --- a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala @@ -18,77 +18,83 @@ package org.apache.spark.util import java.io.IOException +import java.util.concurrent.ConcurrentHashMap import scala.reflect.ClassTag +import scala.util.control.NonFatal -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark._ import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.CheckpointRDD -import org.apache.spark.serializer.SerializationStream +import org.apache.spark.rdd.{RDD, ReliableCheckpointRDD} +import org.apache.spark.storage.RDDBlockId /** * Wrapper around an iterator which writes checkpoint data to HDFS while running action on - * a RDD to support checkpointing RDD. + * an RDD. + * + * @param id the unique id for a partition of an RDD + * @param values the data to be checkpointed + * @param fs the FileSystem to use + * @param tempOutputPath the temp path to write the checkpoint data + * @param finalOutputPath the final path to move the temp file to when finishing checkpointing + * @param context the task context + * @param blockSize the block size for writing the checkpoint data */ -private[spark] class CheckpointingIterator[A: ClassTag]( - values: Iterator[A], - path: String, - broadcastedConf: Broadcast[SerializableConfiguration], - partitionId: Int, +private[spark] class CheckpointingIterator[T: ClassTag]( + id: RDDBlockId, + values: Iterator[T], + fs: FileSystem, + tempOutputPath: Path, + finalOutputPath: Path, context: TaskContext, - blockSize: Int = -1) extends Iterator[A] with Logging { - - private val env = SparkEnv.get - private var fs: FileSystem = null - private val bufferSize = env.conf.getInt("spark.buffer.size", 65536) - private var serializeStream: SerializationStream = null + blockSize: Int) extends Iterator[T] with Logging { - private var finalOutputPath: Path = null - private var tempOutputPath: Path = null + private[this] var completed = false - /** - * Initialize this iterator by creating temporary output path and serializer instance. - * - */ - def init(): this.type = { - val outputDir = new Path(path) - fs = outputDir.getFileSystem(broadcastedConf.value.value) - - val finalOutputName = CheckpointRDD.splitIdToFile(partitionId) - finalOutputPath = new Path(outputDir, finalOutputName) - tempOutputPath = - new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptNumber) - - if (fs.exists(tempOutputPath)) { - // There are more than one iterator of the RDD is consumed. 
- // Don't checkpoint data in this iterator. - doCheckpoint = false - return this - } + context.addTaskCompletionListener { ctx => + // We don't know if the task is successful. So it's possible that we still checkpoint the + // remaining values even if the task is failed. + // TODO optimize the failure case if we can know the task status + complete() + } - val fileOutputStream = if (blockSize < 0) { + private[this] val fileOutputStream = { + val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) + if (blockSize < 0) { fs.create(tempOutputPath, false, bufferSize) } else { // This is mainly for testing purpose fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) } - val serializer = env.serializer.newInstance() - serializeStream = serializer.serializeStream(fileOutputStream) - this } + private[this] val serializeStream = + SparkEnv.get.serializer.newInstance().serializeStream(fileOutputStream) + /** * Called when this iterator is on the latest element by `hasNext`. * This method will rename temporary output path to final output path of checkpoint data. */ - def completion(): Unit = { - if (!doCheckpoint) { + private[this] def complete(): Unit = { + if (completed) { return } + if (serializeStream == null) { + // There is some exception when creating serializeStream, we only need to clean up the + // resources. + cleanup() + return + } + + while (values.hasNext) { + serializeStream.writeObject(values.next) + } + + completed = true + CheckpointingIterator.releaseLockForPartition(id) serializeStream.close() if (!fs.rename(tempOutputPath, finalOutputPath)) { @@ -105,47 +111,115 @@ private[spark] class CheckpointingIterator[A: ClassTag]( } } - def checkpointing(item: A): Unit = { - serializeStream.writeObject(item) + private[this] def cleanup(): Unit = { + completed = true + CheckpointingIterator.releaseLockForPartition(id) + if (serializeStream != null) { + serializeStream.close() + } + fs.delete(tempOutputPath, false) } - override def next(): A = { - val item = values.next() - if (doCheckpoint) { - checkpointing(item) + override def hasNext: Boolean = { + try { + val r = values.hasNext + if (!r) { + complete() + } + r + } catch { + case e: Throwable => + try { + cleanup() + } catch { + case NonFatal(e1) => + // Log `e1` since we should not override `e` + logError(e1.getMessage, e1) + } + throw e } - // If this the latest item, call hasNext will write to final output early. - hasNext - item } - private[this] var doCheckpoint = true - private[this] var completed = false - - override def hasNext: Boolean = { - val r = values.hasNext - if (!r && !completed) { - completed = true - completion() + override def next(): T = { + try { + val value = values.next() + serializeStream.writeObject(value) + value + } catch { + case e: Throwable => + try { + cleanup() + } catch { + case NonFatal(e1) => + // Log `e1` since we should not override `e` + logError(e1.getMessage, e1) + } + throw e } - r } } private[spark] object CheckpointingIterator { - def apply[A: ClassTag]( - values: Iterator[A], + + private val checkpointingRDDPartitions = new ConcurrentHashMap[RDDBlockId, RDDBlockId]() + + /** + * Return true if the caller gets the lock to write the checkpoint file. Otherwise, the caller + * should not do checkpointing. + */ + private def acquireLockForPartition(id: RDDBlockId): Boolean = { + checkpointingRDDPartitions.putIfAbsent(id, id) == null + } + + /** + * Release the lock to avoid memory leak. 
+ */ + private def releaseLockForPartition(id: RDDBlockId): Unit = { + checkpointingRDDPartitions.remove(id) + } + + def apply[T: ClassTag]( + rdd: RDD[T], + values: Iterator[T], path: String, broadcastedConf: Broadcast[SerializableConfiguration], - partitionId: Int, + partitionIndex: Int, context: TaskContext, - blockSize: Int = -1) : CheckpointingIterator[A] = { - new CheckpointingIterator[A]( - values, - path, - broadcastedConf, - partitionId, - context, - blockSize).init() + blockSize: Int = -1): Iterator[T] = { + val id = RDDBlockId(rdd.id, partitionIndex) + if (CheckpointingIterator.acquireLockForPartition(id)) { + try { + val outputDir = new Path(path) + val fs = outputDir.getFileSystem(broadcastedConf.value.value) + val finalOutputName = ReliableCheckpointRDD.checkpointFileName(partitionIndex) + val finalOutputPath = new Path(outputDir, finalOutputName) + if (fs.exists(finalOutputPath)) { + // RDD has already been checkpointed by a previous task. So we don't need to checkpoint + // again. + return values + } + val tempOutputPath = + new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptNumber) + if (fs.exists(tempOutputPath)) { + throw new IOException(s"Checkpoint failed: temporary path $tempOutputPath already exists") + } + + new CheckpointingIterator( + id, + values, + fs, + tempOutputPath, + finalOutputPath, + context, + blockSize) + } catch { + case e: Throwable => + releaseLockForPartition(id) + throw e + } + } else { + values + } + } } diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 119e5fc28e41..5aaf54093373 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -22,6 +22,7 @@ import java.io.File import scala.reflect.ClassTag import org.apache.spark.rdd._ +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils @@ -251,6 +252,37 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging assert(rdd.partitions.size === 0) } + runTest("SPARK-8582: checkpointing should only launch one job") { reliableCheckpoint: Boolean => + @volatile var jobCounter = 0 + sc.addSparkListener(new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobCounter += 1 + } + }) + val rdd = sc.parallelize(1 to 100, 10) + checkpoint(rdd, reliableCheckpoint) + assert(rdd.collect() === (1 to 100)) + sc.listenerBus.waitUntilEmpty(10000) + assert(jobCounter === 1) + } + + runTest("checkpointing without draining Iterators") { reliableCheckpoint: Boolean => + @volatile var jobCounter = 0 + sc.addSparkListener(new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobCounter += 1 + } + }) + val rdd = sc.parallelize(1 to 100, 10) + checkpoint(rdd, reliableCheckpoint) + assert(rdd.take(5) === (1 to 5)) + sc.listenerBus.waitUntilEmpty(10000) + // Because `take(5)` only consumes the first partition, there should be another job to + // checkpoint other partitions. + assert(jobCounter === 2) + assert(rdd.collect() === (1 to 100)) + } + // Utility test methods /** Checkpoint the RDD either locally or reliably. 
*/ diff --git a/core/src/test/scala/org/apache/spark/util/CheckpointingIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/CheckpointingIteratorSuite.scala deleted file mode 100644 index 9df816cda0ee..000000000000 --- a/core/src/test/scala/org/apache/spark/util/CheckpointingIteratorSuite.scala +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import java.io.File - -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.fs.Path - -import org.apache.spark._ -import org.apache.spark.rdd.CheckpointRDD -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkFunSuite - -class CheckpointingIteratorSuite extends SparkFunSuite with LocalSparkContext with Logging { - var checkpointDir: File = _ - val partitioner = new HashPartitioner(2) - - override def beforeEach() { - super.beforeEach() - checkpointDir = File.createTempFile("temp", "", Utils.createTempDir()) - checkpointDir.delete() - sc = new SparkContext("local", "test") - sc.setCheckpointDir(checkpointDir.toString) - } - - override def afterEach() { - super.afterEach() - Utils.deleteRecursively(checkpointDir) - } - - test("basic test") { - val broadcastedConf = sc.broadcast( - new SerializableConfiguration(sc.hadoopConfiguration)) - - val context = new TaskContextImpl(0, 0, 0, 0, null) - val iter = List(1, 2, 3).iterator - val checkpoingIter = CheckpointingIterator[Int]( - iter, - checkpointDir.toString, - broadcastedConf, - 0, - context) - - assert(checkpoingIter.hasNext) - assert(checkpoingIter.next() === 1) - - assert(checkpoingIter.hasNext) - assert(checkpoingIter.next() === 2) - - assert(checkpoingIter.hasNext) - assert(checkpoingIter.next() === 3) - - // checkpoint data should be written out now. - val outputDir = new Path(checkpointDir.toString) - val finalOutputName = CheckpointRDD.splitIdToFile(0) - val finalPath = new Path(outputDir, finalOutputName) - - val fs = outputDir.getFileSystem(sc.hadoopConfiguration) - assert(fs.exists(finalPath)) - - assert(!checkpoingIter.hasNext) - assert(!checkpoingIter.hasNext) - } - - test("no checkpoing data output if iterator is not consumed to latest element") { - val broadcastedConf = sc.broadcast( - new SerializableConfiguration(sc.hadoopConfiguration)) - - val context = new TaskContextImpl(0, 0, 0, 0, null) - val iter = List(1, 2, 3).iterator - val checkpoingIter = CheckpointingIterator[Int]( - iter, - checkpointDir.toString, - broadcastedConf, - 0, - context) - - assert(checkpoingIter.hasNext) - assert(checkpoingIter.next() === 1) - - assert(checkpoingIter.hasNext) - assert(checkpoingIter.next() === 2) - - // checkpoint data should not be written out yet. 
- val outputDir = new Path(checkpointDir.toString) - val finalOutputName = CheckpointRDD.splitIdToFile(0) - val finalPath = new Path(outputDir, finalOutputName) - - val fs = outputDir.getFileSystem(sc.hadoopConfiguration) - assert(!fs.exists(finalPath)) - } -} From 676317bf1e0db0e43ccaf3dccafe43ad812f2147 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 9 Nov 2015 14:36:57 -0800 Subject: [PATCH 6/8] Fix style and comments --- .../spark/rdd/ReliableRDDCheckpointData.scala | 12 ++-- .../spark/util/CheckpointingIterator.scala | 60 +++++++++---------- 2 files changed, 37 insertions(+), 35 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala index 52e83b131d02..536f361344ce 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala @@ -39,8 +39,8 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v val _cpDir = ReliableRDDCheckpointData.checkpointPath(rdd.context, rdd.id) .map(_.toString) .getOrElse { - throw new SparkException("Checkpoint dir must be specified.") - } + throw new SparkException("Checkpoint dir must be specified.") + } val path = new Path(_cpDir) val fs = new Path(_cpDir).getFileSystem(rdd.context.hadoopConfiguration) if (!fs.mkdirs(path)) { @@ -71,7 +71,10 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v * Note: this is called in executor. */ def getCheckpointIterator( - rdd: RDD[T], values: Iterator[T], context: TaskContext, partitionIndex: Int): Iterator[T] = { + rdd: RDD[T], + values: Iterator[T], + context: TaskContext, + partitionIndex: Int): Iterator[T] = { if (cpState == Initialized) { CheckpointingIterator[T]( rdd, @@ -100,12 +103,11 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v val checkpointedPartitionFiles = fs.listStatus(path).map(_.getPath.getName).toSet // Not all actions compute all partitions of the RDD (e.g. take). For correctness, we - // must checkpoint any missing partitions. TODO: avoid running another job here (SPARK-8582). + // must checkpoint any missing partitions. val missingPartitionIndices = rdd.partitions.map(_.index).filter { i => !checkpointedPartitionFiles(ReliableCheckpointRDD.checkpointFileName(i)) } if (missingPartitionIndices.nonEmpty) { - // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582) rdd.context.runJob( rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _, diff --git a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala index 591479c00e60..b2046f72a0c9 100644 --- a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import java.io.IOException +import java.io.OutputStream import java.util.concurrent.ConcurrentHashMap import scala.reflect.ClassTag @@ -53,14 +54,12 @@ private[spark] class CheckpointingIterator[T: ClassTag]( private[this] var completed = false - context.addTaskCompletionListener { ctx => - // We don't know if the task is successful. So it's possible that we still checkpoint the - // remaining values even if the task is failed. 
- // TODO optimize the failure case if we can know the task status - complete() - } + // We don't know if the task is successful. So it's possible that we still checkpoint the + // remaining values even if the task is failed. + // TODO optimize the failure case if we can know the task status + context.addTaskCompletionListener { _ => complete() } - private[this] val fileOutputStream = { + private[this] val fileOutputStream: OutputStream = { val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536) if (blockSize < 0) { fs.create(tempOutputPath, false, bufferSize) @@ -74,7 +73,7 @@ private[spark] class CheckpointingIterator[T: ClassTag]( SparkEnv.get.serializer.newInstance().serializeStream(fileOutputStream) /** - * Called when this iterator is on the latest element by `hasNext`. + * Called when this iterator is on the last element by `hasNext`. * This method will rename temporary output path to final output path of checkpoint data. */ private[this] def complete(): Unit = { @@ -120,13 +119,9 @@ private[spark] class CheckpointingIterator[T: ClassTag]( fs.delete(tempOutputPath, false) } - override def hasNext: Boolean = { + private[this] def cleanupOnFailure[A](body: => A): A = { try { - val r = values.hasNext - if (!r) { - complete() - } - r + body } catch { case e: Throwable => try { @@ -140,22 +135,18 @@ private[spark] class CheckpointingIterator[T: ClassTag]( } } - override def next(): T = { - try { - val value = values.next() - serializeStream.writeObject(value) - value - } catch { - case e: Throwable => - try { - cleanup() - } catch { - case NonFatal(e1) => - // Log `e1` since we should not override `e` - logError(e1.getMessage, e1) - } - throw e + override def hasNext: Boolean = cleanupOnFailure { + val r = values.hasNext + if (!r) { + complete() } + r + } + + override def next(): T = cleanupOnFailure { + val value = values.next() + serializeStream.writeObject(value) + value } } @@ -178,6 +169,15 @@ private[spark] object CheckpointingIterator { checkpointingRDDPartitions.remove(id) } + /** + * Create a `CheckpointingIterator` to wrap the original `Iterator` so that when the wrapper is + * consumed, it will checkpoint the values. Even if the wrapper is not drained, we will still + * drain the remaining values when a task is completed. + * + * If this method is called multiple times for the same partition of an `RDD`, only one `Iterator` + * that gets the lock will be wrapped. For other `Iterator`s that don't get the lock or find the + * partition has been checkpointed, we just return the original `Iterator`. 
+ */ def apply[T: ClassTag]( rdd: RDD[T], values: Iterator[T], @@ -218,7 +218,7 @@ private[spark] object CheckpointingIterator { throw e } } else { - values + values // Iterator is being checkpointed, so just return the values } } From 49248c7cf4de14e42c9eeb731621cc346fcfb29e Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 9 Nov 2015 15:03:28 -0800 Subject: [PATCH 7/8] Handle ControlThrowable --- .../scala/org/apache/spark/util/CheckpointingIterator.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala index b2046f72a0c9..f56316b11054 100644 --- a/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala +++ b/core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala @@ -22,7 +22,7 @@ import java.io.OutputStream import java.util.concurrent.ConcurrentHashMap import scala.reflect.ClassTag -import scala.util.control.NonFatal +import scala.util.control.{ControlThrowable, NonFatal} import org.apache.hadoop.fs.{FileSystem, Path} @@ -123,6 +123,7 @@ private[spark] class CheckpointingIterator[T: ClassTag]( try { body } catch { + case e: ControlThrowable => throw e case e: Throwable => try { cleanup() @@ -213,6 +214,7 @@ private[spark] object CheckpointingIterator { context, blockSize) } catch { + case e: ControlThrowable => throw e case e: Throwable => releaseLockForPartition(id) throw e From 93c8febb3ef643208fe9630da2ef10dbd9607675 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Mon, 9 Nov 2015 16:25:30 -0800 Subject: [PATCH 8/8] Add a failure test for CheckpointingIterator --- .../org/apache/spark/CheckpointSuite.scala | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 5aaf54093373..e3c7c39920f3 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -283,6 +283,15 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging assert(rdd.collect() === (1 to 100)) } + runTest("call RDD.iterator lazily") { reliableCheckpoint: Boolean => + val parCollection = sc.makeRDD(1 to 10, 1) + checkpoint(parCollection, reliableCheckpoint) + val lazyRDD = new LazyRDD(parCollection) + checkpoint(lazyRDD, reliableCheckpoint) + lazyRDD.take(5) + assert(lazyRDD.collect() === (1 to 10)) + } + // Utility test methods /** Checkpoint the RDD either locally or reliably. */ @@ -501,6 +510,20 @@ class FatRDD(parent: RDD[Int]) extends RDD[Int](parent) { } } +class LazyRDD(parent: RDD[Int]) extends RDD[Int](parent) { + + protected def getPartitions: Array[Partition] = parent.partitions + + def compute(split: Partition, context: TaskContext): Iterator[Int] = new Iterator[Int] { + + lazy val iter = parent.iterator(split, context) + + override def hasNext: Boolean = iter.hasNext + + override def next(): Int = iter.next() + } +} + /** Pair RDD that has large serialized size. */ class FatPairRDD(parent: RDD[Int], _partitioner: Partitioner) extends RDD[(Int, Int)](parent) { val bigData = new Array[Byte](100000)