[SPARK-8582][Core]Optimize checkpointing to avoid computing an RDD twice #9428
Changes from all commits
d863516
1a3055e
3c5b203
a829a7d
2f43ff3
5c42503
647162f
676317b
49248c7
93c8feb
@@ -22,7 +22,7 @@ import scala.reflect.ClassTag

import org.apache.hadoop.fs.Path

import org.apache.spark._
import org.apache.spark.util.SerializableConfiguration
import org.apache.spark.util.{CheckpointingIterator, SerializableConfiguration}

/**
 * An implementation of checkpointing that writes the RDD data to reliable storage.
@@ -31,12 +31,26 @@ import org.apache.spark.util.SerializableConfiguration

private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private val rdd: RDD[T])
  extends RDDCheckpointData[T](rdd) with Logging {

  import CheckpointState._

  // The directory to which the associated RDD has been checkpointed to
  // This is assumed to be a non-local path that points to some reliable storage
  private val cpDir: String =
    ReliableRDDCheckpointData.checkpointPath(rdd.context, rdd.id)
  private val cpDir: String = {
    val _cpDir = ReliableRDDCheckpointData.checkpointPath(rdd.context, rdd.id)
      .map(_.toString)
      .getOrElse { throw new SparkException("Checkpoint dir must be specified.") }
      .getOrElse {
        throw new SparkException("Checkpoint dir must be specified.")
      }
    val path = new Path(_cpDir)
    val fs = new Path(_cpDir).getFileSystem(rdd.context.hadoopConfiguration)
    if (!fs.mkdirs(path)) {
      throw new SparkException("Failed to create checkpoint path " + path)
Contributor: What if _cpDir exists before mkdirs() is called?

Member (Author): That's how it was before. If checkpointing fails at the first time,
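An aside on the exchange above: Hadoop's `FileSystem.mkdirs` behaves like `mkdir -p`, returning true when the directory already exists, so re-running checkpointing against an existing `cpDir` does not trip the exception added here. A minimal standalone sketch (not part of the PR; the local filesystem and path are only for illustration):

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

object MkdirsSemanticsSketch {
  def main(args: Array[String]): Unit = {
    // Local filesystem stands in for HDFS; the path is hypothetical.
    val fs = FileSystem.getLocal(new Configuration())
    val dir = new Path("/tmp/spark-checkpoint-demo")

    println(fs.mkdirs(dir)) // true: directory created (or already present)
    println(fs.mkdirs(dir)) // still true: mkdirs succeeds on an existing directory
  }
}
```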
    }
    _cpDir
  }

  private val broadcastedConf = rdd.context.broadcast(
    new SerializableConfiguration(rdd.context.hadoopConfiguration))

  /**
   * Return the directory to which this RDD was checkpointed.
@@ -50,6 +64,30 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
    }
  }
  /**
   * Return an Iterator that will checkpoint the data when consuming the original Iterator if
   * this RDD has not yet been checkpointed. Otherwise, just return the original Iterator.
   *
   * Note: this is called in executor.
   */
  def getCheckpointIterator(
      rdd: RDD[T],
      values: Iterator[T],
      context: TaskContext,
      partitionIndex: Int): Iterator[T] = {
    if (cpState == Initialized) {
      CheckpointingIterator[T](
        rdd,
        values,
        cpDir,
        broadcastedConf,
        partitionIndex,
        context)
    } else {
      values
    }
  }
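To make the behaviour described in the doc comment above concrete, here is a stripped-down, Spark-free sketch of the pattern that `getCheckpointIterator` relies on: wrap an iterator so every element is persisted as the action consumes it, then drain whatever the action did not consume. All names and the output path are hypothetical; the real `CheckpointingIterator` (added below) additionally handles serialization, temp-file renaming, locking, and task-completion callbacks.

```scala
import java.io.{BufferedWriter, FileWriter}

// Hypothetical simplification of the write-through idea behind CheckpointingIterator.
class WriteThroughIterator[T](values: Iterator[T], out: BufferedWriter) extends Iterator[T] {
  private var completed = false

  // Drain anything the consumer did not read and finalize the output.
  private def complete(): Unit = {
    if (!completed) {
      while (values.hasNext) out.write(values.next().toString + "\n")
      out.close()
      completed = true
    }
  }

  override def hasNext: Boolean = {
    val more = values.hasNext
    if (!more) complete() // exhausting the iterator finalizes the "checkpoint"
    more
  }

  override def next(): T = {
    val v = values.next()
    out.write(v.toString + "\n") // persist as a side effect of consumption
    v
  }
}

object WriteThroughIteratorSketch {
  def main(args: Array[String]): Unit = {
    val out = new BufferedWriter(new FileWriter("/tmp/write-through-demo.txt"))
    val it = new WriteThroughIterator((1 to 5).iterator, out)
    println(it.next())           // partial consumption (like take(1)) writes only what was read
    while (it.hasNext) it.next() // the PR drains the rest via a task-completion listener
  }
}
```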
  /**
   * Materialize this RDD and write its content to a reliable DFS.
   * This is called immediately after the first action invoked on this RDD has completed.
@@ -63,11 +101,19 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
      throw new SparkException(s"Failed to create checkpoint path $cpDir")
    }

    // Save to file, and reload it as an RDD
    val broadcastedConf = rdd.context.broadcast(
      new SerializableConfiguration(rdd.context.hadoopConfiguration))
    // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
    rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _)
    val checkpointedPartitionFiles = fs.listStatus(path).map(_.getPath.getName).toSet
    // Not all actions compute all partitions of the RDD (e.g. take). For correctness, we
    // must checkpoint any missing partitions.
    val missingPartitionIndices = rdd.partitions.map(_.index).filter { i =>
      !checkpointedPartitionFiles(ReliableCheckpointRDD.checkpointFileName(i))
    }
    if (missingPartitionIndices.nonEmpty) {
      rdd.context.runJob(
        rdd,
        ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _,
        missingPartitionIndices)
    }
    val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir)
    if (newRDD.partitions.length != rdd.partitions.length) {
      throw new SparkException(
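The subtle case handled by the new code in `doCheckpoint()` above is that actions such as `take` only run a job over some partitions, so iterator-based checkpointing can leave partition files missing; the patch lists the files already written and launches a job over just the missing partition indices. A runnable sketch of that backfill idea against the public API (the "already written" set is simulated here rather than read from the checkpoint directory):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object MissingPartitionsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
    val rdd = sc.parallelize(1 to 100, 4)

    // take(1) schedules a job over the first partition only, so an iterator-based
    // checkpoint would have written at most one partition file at this point.
    rdd.take(1)

    // Pretend only partition 0 was checkpointed (the real code inspects fs.listStatus).
    val alreadyWritten = Set(0)
    val missing = rdd.partitions.map(_.index).filterNot(alreadyWritten.contains)

    // Backfill only the missing partitions with a targeted job, mirroring
    // rdd.context.runJob(rdd, writeCheckpointFile _, missingPartitionIndices).
    sc.runJob(rdd, (iter: Iterator[Int]) => iter.size, missing)

    sc.stop()
  }
}
```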
@@ -0,0 +1,227 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.util

import java.io.IOException
import java.io.OutputStream
import java.util.concurrent.ConcurrentHashMap

import scala.reflect.ClassTag
import scala.util.control.{ControlThrowable, NonFatal}

import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{RDD, ReliableCheckpointRDD}
import org.apache.spark.storage.RDDBlockId

/**
 * Wrapper around an iterator which writes checkpoint data to HDFS while running action on
 * an RDD.
 *
 * @param id the unique id for a partition of an RDD
 * @param values the data to be checkpointed
 * @param fs the FileSystem to use
 * @param tempOutputPath the temp path to write the checkpoint data
 * @param finalOutputPath the final path to move the temp file to when finishing checkpointing
 * @param context the task context
 * @param blockSize the block size for writing the checkpoint data
 */
private[spark] class CheckpointingIterator[T: ClassTag](
    id: RDDBlockId,
    values: Iterator[T],
    fs: FileSystem,
    tempOutputPath: Path,
    finalOutputPath: Path,
    context: TaskContext,
    blockSize: Int) extends Iterator[T] with Logging {

  private[this] var completed = false
Contributor: does this need to be volatile? (do we ever have 2 threads in the same JVM computing the iterator?)
  // We don't know if the task is successful. So it's possible that we still checkpoint the
  // remaining values even if the task is failed.
  // TODO optimize the failure case if we can know the task status
  context.addTaskCompletionListener { _ => complete() }

  private[this] val fileOutputStream: OutputStream = {
    val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)
    if (blockSize < 0) {
      fs.create(tempOutputPath, false, bufferSize)
    } else {
      // This is mainly for testing purpose
      fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
    }
  }
  private[this] val serializeStream =
    SparkEnv.get.serializer.newInstance().serializeStream(fileOutputStream)

  /**
   * Called when this iterator is on the last element by `hasNext`.
   * This method will rename temporary output path to final output path of checkpoint data.
   */
  private[this] def complete(): Unit = {
    if (completed) {
      return
    }

    if (serializeStream == null) {
      // There is some exception when creating serializeStream, we only need to clean up the
      // resources.
      cleanup()
      return
    }

    while (values.hasNext) {
      serializeStream.writeObject(values.next)
    }

    completed = true
    CheckpointingIterator.releaseLockForPartition(id)
    serializeStream.close()

    if (!fs.rename(tempOutputPath, finalOutputPath)) {
      if (!fs.exists(finalOutputPath)) {
        logInfo("Deleting tempOutputPath " + tempOutputPath)
        fs.delete(tempOutputPath, false)
        throw new IOException("Checkpoint failed: failed to save output of task: "
          + context.attemptNumber + " and final output path does not exist")
      } else {
        // Some other copy of this task must've finished before us and renamed it
        logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
        fs.delete(tempOutputPath, false)
      }
    }
  }
  private[this] def cleanup(): Unit = {
    completed = true
    CheckpointingIterator.releaseLockForPartition(id)
    if (serializeStream != null) {
      serializeStream.close()
    }
    fs.delete(tempOutputPath, false)
  }

  private[this] def cleanupOnFailure[A](body: => A): A = {
    try {
      body
    } catch {
      case e: ControlThrowable => throw e
      case e: Throwable =>
Member (Author): Use

Member (Author): Actually, still need to handle ControlThrowable. Updated it.
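On the `ControlThrowable` point settled above: `scala.util.control.NonFatal` deliberately does not match `ControlThrowable`, so a handler written only as `case NonFatal(e) =>` would let control-flow exceptions (e.g. `Breaks.break()`) bypass `cleanup()`; hence the explicit `case e: ControlThrowable => throw e` before the catch-all. A small standalone illustration (not part of the diff):

```scala
import scala.util.control.{Breaks, ControlThrowable, NonFatal}

object ControlThrowableSketch {
  def main(args: Array[String]): Unit = {
    try {
      // break() throws a ControlThrowable subclass when used for control flow.
      Breaks.break()
    } catch {
      case NonFatal(e)         => println("NonFatal matched (never reached for break())")
      case e: ControlThrowable => println("ControlThrowable fell through NonFatal: " + e.getClass.getName)
    }
  }
}
```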
        try {
          cleanup()
        } catch {
          case NonFatal(e1) =>
            // Log `e1` since we should not override `e`
            logError(e1.getMessage, e1)
        }
        throw e
    }
  }
  override def hasNext: Boolean = cleanupOnFailure {
    val r = values.hasNext
    if (!r) {
      complete()
    }
    r
  }

  override def next(): T = cleanupOnFailure {
    val value = values.next()
    serializeStream.writeObject(value)
    value
  }
}
private[spark] object CheckpointingIterator {

  private val checkpointingRDDPartitions = new ConcurrentHashMap[RDDBlockId, RDDBlockId]()

  /**
   * Return true if the caller gets the lock to write the checkpoint file. Otherwise, the caller
   * should not do checkpointing.
   */
  private def acquireLockForPartition(id: RDDBlockId): Boolean = {
    checkpointingRDDPartitions.putIfAbsent(id, id) == null
  }

  /**
   * Release the lock to avoid memory leak.
   */
  private def releaseLockForPartition(id: RDDBlockId): Unit = {
    checkpointingRDDPartitions.remove(id)
  }

  /**
   * Create a `CheckpointingIterator` to wrap the original `Iterator` so that when the wrapper is
   * consumed, it will checkpoint the values. Even if the wrapper is not drained, we will still
   * drain the remaining values when a task is completed.
   *
   * If this method is called multiple times for the same partition of an `RDD`, only one `Iterator`
   * that gets the lock will be wrapped. For other `Iterator`s that don't get the lock or find the
   * partition has been checkpointed, we just return the original `Iterator`.
   */
  def apply[T: ClassTag](
Contributor: need big java doc to explain what it's doing. We should mention the implicit locking this does and why it's needed etc.
      rdd: RDD[T],
      values: Iterator[T],
      path: String,
      broadcastedConf: Broadcast[SerializableConfiguration],
      partitionIndex: Int,
      context: TaskContext,
      blockSize: Int = -1): Iterator[T] = {
    val id = RDDBlockId(rdd.id, partitionIndex)
    if (CheckpointingIterator.acquireLockForPartition(id)) {
      try {
        val outputDir = new Path(path)
        val fs = outputDir.getFileSystem(broadcastedConf.value.value)
        val finalOutputName = ReliableCheckpointRDD.checkpointFileName(partitionIndex)
        val finalOutputPath = new Path(outputDir, finalOutputName)
        if (fs.exists(finalOutputPath)) {
          // RDD has already been checkpointed by a previous task. So we don't need to checkpoint
          // again.
          return values
        }
        val tempOutputPath =
          new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptNumber)
        if (fs.exists(tempOutputPath)) {
          throw new IOException(s"Checkpoint failed: temporary path $tempOutputPath already exists")
        }

        new CheckpointingIterator(
          id,
          values,
          fs,
          tempOutputPath,
          finalOutputPath,
          context,
          blockSize)
      } catch {
        case e: ControlThrowable => throw e
        case e: Throwable =>
Contributor: same here
          releaseLockForPartition(id)
          throw e
      }
    } else {
      values // Iterator is being checkpointed, so just return the values
    }
  }
}
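Regarding the reviewer request above to document "the implicit locking this does": the per-partition claim in `acquireLockForPartition` rests on `ConcurrentHashMap.putIfAbsent` being atomic, so exactly one caller per JVM wraps the iterator for a given partition and every other caller falls through to the plain iterator. A self-contained sketch of that claim/release protocol (the string id is a stand-in for `RDDBlockId`):

```scala
import java.util.concurrent.ConcurrentHashMap

object PartitionClaimSketch {
  private val claims = new ConcurrentHashMap[String, String]()

  // Mirrors acquireLockForPartition: only the caller that sees null "wins".
  def tryClaim(id: String): Boolean = claims.putIfAbsent(id, id) == null

  // Mirrors releaseLockForPartition: always release to avoid leaking entries.
  def release(id: String): Unit = claims.remove(id)

  def main(args: Array[String]): Unit = {
    val id = "rdd_3_partition_0" // hypothetical stand-in for an RDDBlockId
    println(tryClaim(id)) // true: this caller should checkpoint the partition
    println(tryClaim(id)) // false: a concurrent attempt returns the original iterator
    release(id)           // done in complete()/cleanup() in the real code
    println(tryClaim(id)) // true again once released
  }
}
```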
(deleted comment) actually, this is fine. On second thought I think local checkpoint doesn't really have this problem since we set the storage level there anyway.
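For context on that closing remark: local checkpointing persists the RDD's blocks as a side effect of the first action (it adjusts the storage level), so it never needed the second pass that this PR removes for reliable checkpointing. A hedged sketch contrasting the two modes in a local-mode app (the checkpoint directory is hypothetical):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object CheckpointModesSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
    sc.setCheckpointDir("/tmp/spark-checkpoints") // hypothetical directory

    // Reliable checkpointing: before SPARK-8582, the first action triggered a second
    // job that recomputed the RDD just to write the checkpoint files.
    val reliable = sc.parallelize(1 to 100).map(_ * 2)
    reliable.checkpoint()
    reliable.count()

    // Local checkpointing: the blocks are persisted while the first action runs,
    // so no extra recomputation pass is needed (the point of the comment above).
    val local = sc.parallelize(1 to 100).map(_ * 2)
    local.localCheckpoint()
    local.count()

    sc.stop()
  }
}
```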