@@ -63,6 +63,9 @@ class PartitionerAwareUnionRDD[T: ClassTag](
require(rdds.forall(_.partitioner.isDefined))
require(rdds.flatMap(_.partitioner).toSet.size == 1,
"Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))
require(rdds.map(_.partitioner.get.numPartitions).toSet.size == 1,
"Parent RDDs have different number of partitions: " +
rdds.map(_.partitioner.get.numPartitions))

override val partitioner = rdds.head.partitioner
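The existing require already rejects parents whose partitioners differ, so the new partition-count check matters when the partitioners compare equal but their numPartitions do not — something a custom partitioner can produce. A hypothetical sketch (the class below is illustrative, not part of this patch):

```scala
import org.apache.spark.Partitioner

// A custom partitioner whose equals() deliberately ignores numPartitions: two
// instances with different partition counts pass the "same partitioner" check,
// and only the new numPartitions require rejects the combination.
class KeyModPartitioner(partitions: Int) extends Partitioner {
  override def numPartitions: Int = partitions
  override def getPartition(key: Any): Int =
    ((key.hashCode % partitions) + partitions) % partitions
  override def equals(other: Any): Boolean = other.isInstanceOf[KeyModPartitioner]
  override def hashCode(): Int = classOf[KeyModPartitioner].hashCode
}
```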

5 changes: 4 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -258,11 +258,14 @@ abstract class RDD[T: ClassTag](
* subclasses of RDD.
*/
final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
val iter = if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
} else {
computeOrReadCheckpoint(split, context)
}
checkpointData.collect { case data: ReliableRDDCheckpointData[T] =>
data.getCheckpointIterator(this, iter, context, split.index)
Contributor:
(deleted comment) actually, this is fine. On second thought I think local checkpoint doesn't really have this problem since we set the storage level there anyway.

}.getOrElse(iter)
}
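A hedged driver-side sketch of what this wiring enables (the checkpoint directory is an assumed example path): once an RDD is marked for checkpointing, the data that streams through iterator() during the first action can be written out as the checkpoint on the fly, instead of recomputing the whole RDD in a separate job.

```scala
// Assumes an existing SparkContext `sc`; the path is illustrative.
sc.setCheckpointDir("/tmp/checkpoints")
val rdd = sc.parallelize(1 to 1000, 4).map(_ * 2)
rdd.checkpoint()   // marks the RDD; checkpoint data is not written yet
rdd.count()        // partitions stream through the checkpointing iterator here
```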

/**
@@ -96,7 +96,7 @@ private[spark] object ReliableCheckpointRDD extends Logging {
/**
* Return the checkpoint file name for the given partition.
*/
private def checkpointFileName(partitionIndex: Int): String = {
def checkpointFileName(partitionIndex: Int): String = {
"part-%05d".format(partitionIndex)
}
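Widening this from private lets CheckpointingIterator and the missing-partition scan in ReliableRDDCheckpointData reuse exactly the same file names. A worked example of the (unchanged) format string:

```scala
assert("part-%05d".format(3) == "part-00003")
assert("part-%05d".format(42) == "part-00042")
```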

@@ -22,7 +22,7 @@ import scala.reflect.ClassTag
import org.apache.hadoop.fs.Path

import org.apache.spark._
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.{CheckpointingIterator, SerializableConfiguration}

/**
* An implementation of checkpointing that writes the RDD data to reliable storage.
@@ -31,12 +31,26 @@ import org.apache.spark.util.SerializableConfiguration
private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T])
extends RDDCheckpointData[T](rdd) with Logging {

import CheckpointState._

// The directory to which the associated RDD has been checkpointed to
// This is assumed to be a non-local path that points to some reliable storage
private val cpDir: String =
ReliableRDDCheckpointData.checkpointPath(rdd.context, rdd.id)
private val cpDir: String = {
val _cpDir = ReliableRDDCheckpointData.checkpointPath(rdd.context, rdd.id)
.map(_.toString)
.getOrElse { throw new SparkException("Checkpoint dir must be specified.") }
.getOrElse {
throw new SparkException("Checkpoint dir must be specified.")
}
val path = new Path(_cpDir)
val fs = new Path(_cpDir).getFileSystem(rdd.context.hadoopConfiguration)
if (!fs.mkdirs(path)) {
throw new SparkException("Failed to create checkpoint path " + path)
Contributor:
What if _cpDir exists before mkdirs() is called?

Member Author:
That's how it was before.

If checkpointing fails the first time, _cpDir won't be deleted. The user may then try to checkpoint again, so we should allow checkpointing the same RDD even if _cpDir already exists.

}
_cpDir
}
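On the reviewer's question above: a pre-existing directory is only a problem if mkdirs fails for it. A minimal sketch, assuming HDFS/local-filesystem semantics where FileSystem.mkdirs behaves like `mkdir -p` and returns true for a directory that already exists, so a leftover _cpDir from an earlier failed attempt does not trip the check:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

def ensureCheckpointDir(dir: String): Unit = {
  val path = new Path(dir)
  val fs = path.getFileSystem(new Configuration())
  if (!fs.mkdirs(path)) {
    throw new RuntimeException("Failed to create checkpoint path " + path)
  }
  // Calling mkdirs again on the now-existing directory is expected to return true.
  assert(fs.mkdirs(path), "mkdirs should be idempotent for an existing directory")
}
```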

private val broadcastedConf = rdd.context.broadcast(
new SerializableConfiguration(rdd.context.hadoopConfiguration))

/**
* Return the directory to which this RDD was checkpointed.
@@ -50,6 +64,30 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
}
}

/**
* Return an Iterator that checkpoints the data as the original Iterator is consumed, if this
* RDD has not yet been checkpointed. Otherwise, just return the original Iterator.
*
* Note: this is called on the executor.
*/
def getCheckpointIterator(
rdd: RDD[T],
values: Iterator[T],
context: TaskContext,
partitionIndex: Int): Iterator[T] = {
if (cpState == Initialized) {
CheckpointingIterator[T](
rdd,
values,
cpDir,
broadcastedConf,
partitionIndex,
context)
} else {
values
}
}

/**
* Materialize this RDD and write its content to a reliable DFS.
* This is called immediately after the first action invoked on this RDD has completed.
@@ -63,11 +101,19 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
throw new SparkException(s"Failed to create checkpoint path $cpDir")
}

// Save to file, and reload it as an RDD
val broadcastedConf = rdd.context.broadcast(
new SerializableConfiguration(rdd.context.hadoopConfiguration))
// TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _)
val checkpointedPartitionFiles = fs.listStatus(path).map(_.getPath.getName).toSet
// Not all actions compute all partitions of the RDD (e.g. take). For correctness, we
// must checkpoint any missing partitions.
val missingPartitionIndices = rdd.partitions.map(_.index).filter { i =>
!checkpointedPartitionFiles(ReliableCheckpointRDD.checkpointFileName(i))
}
if (missingPartitionIndices.nonEmpty) {
rdd.context.runJob(
rdd,
ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _,
missingPartitionIndices)
}
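A hypothetical driver-side scenario that this missing-partition pass covers (path and partition counts are illustrative):

```scala
sc.setCheckpointDir("/tmp/checkpoints")
val rdd = sc.parallelize(1 to 100, 4)
rdd.checkpoint()
rdd.take(1)   // usually runs a job over partition 0 only, so only part-00000
              // is written by the checkpointing iterator
// doCheckpoint() then finds indices 1, 2 and 3 missing and runs
// writeCheckpointFile over just those partitions.
```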

val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir)
if (newRDD.partitions.length != rdd.partitions.length) {
throw new SparkException(
227 changes: 227 additions & 0 deletions core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala
@@ -0,0 +1,227 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util

import java.io.IOException
import java.io.OutputStream
import java.util.concurrent.ConcurrentHashMap

import scala.reflect.ClassTag
import scala.util.control.{ControlThrowable, NonFatal}

import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{RDD, ReliableCheckpointRDD}
import org.apache.spark.storage.RDDBlockId

/**
* Wrapper around an iterator which writes checkpoint data to HDFS while running an action on
* an RDD.
*
* @param id the unique id for a partition of an RDD
* @param values the data to be checkpointed
* @param fs the FileSystem to use
* @param tempOutputPath the temp path to write the checkpoint data
* @param finalOutputPath the final path to move the temp file to when finishing checkpointing
* @param context the task context
* @param blockSize the block size for writing the checkpoint data
*/
private[spark] class CheckpointingIterator[T: ClassTag](
id: RDDBlockId,
values: Iterator[T],
fs: FileSystem,
tempOutputPath: Path,
finalOutputPath: Path,
context: TaskContext,
blockSize: Int) extends Iterator[T] with Logging {

private[this] var completed = false
Contributor:
does this need to be volatile? (do we ever have 2 threads in the same JVM computing the iterator?)

Member Author:
> does this need to be volatile? (do we ever have 2 threads in the same JVM computing the iterator?)

An Iterator is not thread-safe, so callers must already provide their own synchronization to use it from multiple threads. We don't need volatile here.


// We don't know whether the task succeeded, so it's possible that we still checkpoint the
// remaining values even if the task has failed.
// TODO: optimize the failure case if we can learn the task status
context.addTaskCompletionListener { _ => complete() }

private[this] val fileOutputStream: OutputStream = {
val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)
if (blockSize < 0) {
fs.create(tempOutputPath, false, bufferSize)
} else {
// This is mainly for testing purposes
fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
}
}

private[this] val serializeStream =
SparkEnv.get.serializer.newInstance().serializeStream(fileOutputStream)

/**
* Called by `hasNext` once the underlying iterator is exhausted, and from the task-completion
* listener. Writes any remaining values, then renames the temporary output path to the final
* output path of the checkpoint data.
*/
private[this] def complete(): Unit = {
if (completed) {
return
}

if (serializeStream == null) {
// An exception occurred while creating serializeStream, so we only need to clean up the
// resources.
cleanup()
return
}

while (values.hasNext) {
serializeStream.writeObject(values.next)
}

completed = true
CheckpointingIterator.releaseLockForPartition(id)
serializeStream.close()

if (!fs.rename(tempOutputPath, finalOutputPath)) {
if (!fs.exists(finalOutputPath)) {
logInfo("Deleting tempOutputPath " + tempOutputPath)
fs.delete(tempOutputPath, false)
throw new IOException("Checkpoint failed: failed to save output of task: "
+ context.attemptNumber + " and final output path does not exist")
} else {
// Some other copy of this task must've finished before us and renamed it
logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
fs.delete(tempOutputPath, false)
}
}
}
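For context, the temporary file renamed here is named per task attempt (see apply() below), so retried or speculative attempts never collide before the rename. A small illustration of that naming, using a hypothetical helper:

```scala
// Hypothetical helper mirroring the naming used in CheckpointingIterator.apply().
def tempName(finalName: String, attemptNumber: Int): String =
  "." + finalName + "-attempt-" + attemptNumber

assert(tempName("part-00003", 0) == ".part-00003-attempt-0")
```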

private[this] def cleanup(): Unit = {
completed = true
CheckpointingIterator.releaseLockForPartition(id)
if (serializeStream != null) {
serializeStream.close()
}
fs.delete(tempOutputPath, false)
}

private[this] def cleanupOnFailure[A](body: => A): A = {
try {
body
} catch {
case e: ControlThrowable => throw e
case e: Throwable =>
Contributor:
NonFatal

Member Author:
Use Throwable here because it will be rethrown later. It's better to clean up for fatal errors as well.

Member Author:
Actually, we still need to handle ControlThrowable. Updated it.

try {
cleanup()
} catch {
case NonFatal(e1) =>
// Log `e1` since we should not override `e`
logError(e1.getMessage, e1)
}
throw e
}
}

override def hasNext: Boolean = cleanupOnFailure {
val r = values.hasNext
if (!r) {
complete()
}
r
}

override def next(): T = cleanupOnFailure {
val value = values.next()
serializeStream.writeObject(value)
value
}
}
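A stripped-down, Spark-free sketch of the write-through pattern the class above implements: tee every element into a sink while iterating, and finalize the sink once the underlying iterator is exhausted (all names here are illustrative):

```scala
class TeeIterator[T](
    underlying: Iterator[T],
    sink: T => Unit,          // stands in for serializeStream.writeObject
    onComplete: () => Unit    // stands in for the rename-to-final-path step
  ) extends Iterator[T] {

  private var done = false

  override def hasNext: Boolean = {
    val more = underlying.hasNext
    if (!more && !done) { done = true; onComplete() }
    more
  }

  override def next(): T = {
    val v = underlying.next()
    sink(v)   // every consumed element is also written to the sink
    v
  }
}
```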

private[spark] object CheckpointingIterator {

private val checkpointingRDDPartitions = new ConcurrentHashMap[RDDBlockId, RDDBlockId]()

/**
* Return true if the caller gets the lock to write the checkpoint file. Otherwise, the caller
* should not do checkpointing.
*/
private def acquireLockForPartition(id: RDDBlockId): Boolean = {
checkpointingRDDPartitions.putIfAbsent(id, id) == null
}

/**
* Release the lock to avoid a memory leak.
*/
private def releaseLockForPartition(id: RDDBlockId): Unit = {
checkpointingRDDPartitions.remove(id)
}
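A self-contained sketch of the try-lock idiom used by this companion object: putIfAbsent returns null only for the first caller, so exactly one iterator per partition ends up writing the checkpoint file.

```scala
import java.util.concurrent.ConcurrentHashMap

object TryLockDemo {
  private val inProgress = new ConcurrentHashMap[String, String]()

  def tryLock(key: String): Boolean = inProgress.putIfAbsent(key, key) == null
  def unlock(key: String): Unit = inProgress.remove(key)

  def main(args: Array[String]): Unit = {
    assert(tryLock("rdd_0_0"))    // first caller acquires the "lock"
    assert(!tryLock("rdd_0_0"))   // second caller falls back to the plain iterator
    unlock("rdd_0_0")
  }
}
```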

/**
* Create a `CheckpointingIterator` wrapping the original `Iterator` so that the values are
* checkpointed as the wrapper is consumed. Even if the wrapper is not fully drained, the
* remaining values are still drained and written when the task completes.
*
* If this method is called multiple times for the same partition of an `RDD`, only the
* `Iterator` that acquires the lock is wrapped. For `Iterator`s that do not get the lock, or
* when the partition has already been checkpointed, the original `Iterator` is returned as-is.
*/
def apply[T: ClassTag](
Contributor:
This needs a bigger javadoc to explain what it's doing. We should mention the implicit locking this does and why it's needed, etc.

rdd: RDD[T],
values: Iterator[T],
path: String,
broadcastedConf: Broadcast[SerializableConfiguration],
partitionIndex: Int,
context: TaskContext,
blockSize: Int = -1): Iterator[T] = {
val id = RDDBlockId(rdd.id, partitionIndex)
if (CheckpointingIterator.acquireLockForPartition(id)) {
try {
val outputDir = new Path(path)
val fs = outputDir.getFileSystem(broadcastedConf.value.value)
val finalOutputName = ReliableCheckpointRDD.checkpointFileName(partitionIndex)
val finalOutputPath = new Path(outputDir, finalOutputName)
if (fs.exists(finalOutputPath)) {
// This partition has already been checkpointed by a previous task, so we don't need to
// checkpoint it again.
return values
}
val tempOutputPath =
new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptNumber)
if (fs.exists(tempOutputPath)) {
throw new IOException(s"Checkpoint failed: temporary path $tempOutputPath already exists")
}

new CheckpointingIterator(
id,
values,
fs,
tempOutputPath,
finalOutputPath,
context,
blockSize)
} catch {
case e: ControlThrowable => throw e
case e: Throwable =>
Contributor:
same here

releaseLockForPartition(id)
throw e
}
} else {
values // Another iterator is already checkpointing this partition, so just return the values
}

}
}