
Commit 676317b

Fix style and comments
1 parent 647162f commit 676317b

2 files changed: 37 additions & 35 deletions


core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala

Lines changed: 7 additions & 5 deletions
@@ -39,8 +39,8 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
     val _cpDir = ReliableRDDCheckpointData.checkpointPath(rdd.context, rdd.id)
       .map(_.toString)
       .getOrElse {
-      throw new SparkException("Checkpoint dir must be specified.")
-    }
+        throw new SparkException("Checkpoint dir must be specified.")
+      }
     val path = new Path(_cpDir)
     val fs = new Path(_cpDir).getFileSystem(rdd.context.hadoopConfiguration)
     if (!fs.mkdirs(path)) {
@@ -71,7 +71,10 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
    * Note: this is called in executor.
    */
   def getCheckpointIterator(
-      rdd: RDD[T], values: Iterator[T], context: TaskContext, partitionIndex: Int): Iterator[T] = {
+      rdd: RDD[T],
+      values: Iterator[T],
+      context: TaskContext,
+      partitionIndex: Int): Iterator[T] = {
     if (cpState == Initialized) {
       CheckpointingIterator[T](
         rdd,
@@ -100,12 +103,11 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
 
     val checkpointedPartitionFiles = fs.listStatus(path).map(_.getPath.getName).toSet
     // Not all actions compute all partitions of the RDD (e.g. take). For correctness, we
-    // must checkpoint any missing partitions. TODO: avoid running another job here (SPARK-8582).
+    // must checkpoint any missing partitions.
     val missingPartitionIndices = rdd.partitions.map(_.index).filter { i =>
       !checkpointedPartitionFiles(ReliableCheckpointRDD.checkpointFileName(i))
     }
     if (missingPartitionIndices.nonEmpty) {
-      // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
       rdd.context.runJob(
         rdd,
         ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _,
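The retained comment explains why this backfill job can still be needed: an action such as take may compute only some partitions, so only those partitions get checkpoint files from the iterator. A minimal, hypothetical sketch of that situation (the local master and checkpoint directory are assumptions for illustration, not part of this commit):

import org.apache.spark.{SparkConf, SparkContext}

object MissingPartitionsSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
    sc.setCheckpointDir("/tmp/checkpoint-sketch") // assumed location, illustration only
    val rdd = sc.parallelize(1 to 100, numSlices = 4)
    rdd.checkpoint()
    // take(1) may run a job over partition 0 only, so with iterator-based checkpointing
    // the other three partitions would have no checkpoint files yet ...
    rdd.take(1)
    // ... which is why the code above still runs a job for missingPartitionIndices.
    rdd.count()
    sc.stop()
  }
}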

core/src/main/scala/org/apache/spark/util/CheckpointingIterator.scala

Lines changed: 30 additions & 30 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.util
 
 import java.io.IOException
+import java.io.OutputStream
 import java.util.concurrent.ConcurrentHashMap
 
 import scala.reflect.ClassTag
@@ -53,14 +54,12 @@ private[spark] class CheckpointingIterator[T: ClassTag](
 
   private[this] var completed = false
 
-  context.addTaskCompletionListener { ctx =>
-    // We don't know if the task is successful. So it's possible that we still checkpoint the
-    // remaining values even if the task is failed.
-    // TODO optimize the failure case if we can know the task status
-    complete()
-  }
+  // We don't know if the task is successful. So it's possible that we still checkpoint the
+  // remaining values even if the task is failed.
+  // TODO optimize the failure case if we can know the task status
+  context.addTaskCompletionListener { _ => complete() }
 
-  private[this] val fileOutputStream = {
+  private[this] val fileOutputStream: OutputStream = {
     val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)
     if (blockSize < 0) {
       fs.create(tempOutputPath, false, bufferSize)
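For context, a hedged usage sketch of the completion-listener pattern shown above, mirroring how remaining values can still be drained when the task finishes. It assumes a local SparkContext and a Spark version where the `TaskContext => Unit` overload resolves as written in the diff; nothing here is part of the commit:

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object CompletionListenerSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("listener-sketch"))
    sc.parallelize(1 to 10, 2).mapPartitions { iter =>
      val ctx = TaskContext.get()
      // Runs on task completion whether the task succeeded or failed, which is why the
      // iterator above may still checkpoint remaining values after a failure.
      ctx.addTaskCompletionListener { _ => println(s"partition ${ctx.partitionId()} finished") }
      iter
    }.count()
    sc.stop()
  }
}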
@@ -74,7 +73,7 @@ private[spark] class CheckpointingIterator[T: ClassTag](
     SparkEnv.get.serializer.newInstance().serializeStream(fileOutputStream)
 
   /**
-   * Called when this iterator is on the latest element by `hasNext`.
+   * Called when this iterator is on the last element by `hasNext`.
    * This method will rename temporary output path to final output path of checkpoint data.
    */
   private[this] def complete(): Unit = {
@@ -120,13 +119,9 @@ private[spark] class CheckpointingIterator[T: ClassTag](
     fs.delete(tempOutputPath, false)
   }
 
-  override def hasNext: Boolean = {
+  private[this] def cleanupOnFailure[A](body: => A): A = {
     try {
-      val r = values.hasNext
-      if (!r) {
-        complete()
-      }
-      r
+      body
     } catch {
       case e: Throwable =>
         try {
@@ -140,22 +135,18 @@ private[spark] class CheckpointingIterator[T: ClassTag](
     }
   }
 
-  override def next(): T = {
-    try {
-      val value = values.next()
-      serializeStream.writeObject(value)
-      value
-    } catch {
-      case e: Throwable =>
-        try {
-          cleanup()
-        } catch {
-          case NonFatal(e1) =>
-            // Log `e1` since we should not override `e`
-            logError(e1.getMessage, e1)
-        }
-        throw e
+  override def hasNext: Boolean = cleanupOnFailure {
+    val r = values.hasNext
+    if (!r) {
+      complete()
     }
+    r
+  }
+
+  override def next(): T = cleanupOnFailure {
+    val value = values.next()
+    serializeStream.writeObject(value)
+    value
   }
 }
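The refactoring above factors the shared try/catch out of hasNext and next into cleanupOnFailure. A self-contained sketch of the same pattern, using a stand-in cleanup function and plain error output in place of the real cleanup()/logError (names here are illustrative, not from the commit):

import scala.util.control.NonFatal

// Wraps an iterator so that any failure in hasNext/next triggers a best-effort cleanup,
// without letting a cleanup failure mask the original exception.
class GuardedIterator[T](underlying: Iterator[T], cleanup: () => Unit) extends Iterator[T] {
  private def cleanupOnFailure[A](body: => A): A = {
    try {
      body
    } catch {
      case e: Throwable =>
        try cleanup() catch {
          case NonFatal(e1) => Console.err.println(s"cleanup failed: ${e1.getMessage}")
        }
        throw e
    }
  }

  override def hasNext: Boolean = cleanupOnFailure(underlying.hasNext)
  override def next(): T = cleanupOnFailure(underlying.next())
}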

@@ -178,6 +169,15 @@ private[spark] object CheckpointingIterator {
     checkpointingRDDPartitions.remove(id)
   }
 
+  /**
+   * Create a `CheckpointingIterator` to wrap the original `Iterator` so that when the wrapper is
+   * consumed, it will checkpoint the values. Even if the wrapper is not drained, we will still
+   * drain the remaining values when a task is completed.
+   *
+   * If this method is called multiple times for the same partition of an `RDD`, only one
+   * `Iterator` that gets the lock will be wrapped. For other `Iterator`s that don't get the lock
+   * or find the partition has been checkpointed, we just return the original `Iterator`.
+   */
   def apply[T: ClassTag](
       rdd: RDD[T],
       values: Iterator[T],
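The new scaladoc describes a first-caller-wins guard per RDD partition, but the bookkeeping itself is not visible in this hunk. The following is only a hedged sketch of how such a guard could look with a ConcurrentHashMap; the inFlight map and wrapIfFirst are hypothetical names, not the commit's implementation:

import java.util.concurrent.ConcurrentHashMap

object CheckpointGuardSketch {
  // Key: (rddId, partitionIndex). putIfAbsent returning null means this caller "got the lock".
  private val inFlight = new ConcurrentHashMap[(Int, Int), java.lang.Boolean]()

  def wrapIfFirst[T](
      rddId: Int,
      partitionIndex: Int,
      values: Iterator[T],
      wrap: Iterator[T] => Iterator[T]): Iterator[T] = {
    if (inFlight.putIfAbsent((rddId, partitionIndex), java.lang.Boolean.TRUE) == null) {
      wrap(values) // first iterator for this partition: checkpoint while iterating
    } else {
      values // another iterator already holds the lock: return the original values
    }
  }
}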
@@ -218,7 +218,7 @@ private[spark] object CheckpointingIterator {
           throw e
         }
       } else {
-        values
+        values // Iterator is being checkpointed, so just return the values
       }
 
     }
