@@ -18,6 +18,7 @@
 package org.apache.spark.util
 
 import java.io.IOException
+import java.io.OutputStream
 import java.util.concurrent.ConcurrentHashMap
 
 import scala.reflect.ClassTag
@@ -53,14 +54,12 @@ private[spark] class CheckpointingIterator[T: ClassTag](
 
   private[this] var completed = false
 
-  context.addTaskCompletionListener { ctx =>
-    // We don't know if the task is successful. So it's possible that we still checkpoint the
-    // remaining values even if the task is failed.
-    // TODO optimize the failure case if we can know the task status
-    complete()
-  }
+  // We don't know if the task is successful. So it's possible that we still checkpoint the
+  // remaining values even if the task is failed.
+  // TODO optimize the failure case if we can know the task status
+  context.addTaskCompletionListener { _ => complete() }
 
-  private[this] val fileOutputStream = {
+  private[this] val fileOutputStream: OutputStream = {
     val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)
     if (blockSize < 0) {
       fs.create(tempOutputPath, false, bufferSize)
@@ -74,7 +73,7 @@ private[spark] class CheckpointingIterator[T: ClassTag](
     SparkEnv.get.serializer.newInstance().serializeStream(fileOutputStream)
 
   /**
-   * Called when this iterator is on the latest element by `hasNext`.
+   * Called when this iterator is on the last element by `hasNext`.
    * This method will rename temporary output path to final output path of checkpoint data.
    */
   private[this] def complete(): Unit = {
@@ -120,13 +119,9 @@ private[spark] class CheckpointingIterator[T: ClassTag](
     fs.delete(tempOutputPath, false)
   }
 
-  override def hasNext: Boolean = {
+  private[this] def cleanupOnFailure[A](body: => A): A = {
     try {
-      val r = values.hasNext
-      if (!r) {
-        complete()
-      }
-      r
+      body
     } catch {
       case e: Throwable =>
         try {
@@ -140,22 +135,18 @@ private[spark] class CheckpointingIterator[T: ClassTag](
     }
   }
 
-  override def next(): T = {
-    try {
-      val value = values.next()
-      serializeStream.writeObject(value)
-      value
-    } catch {
-      case e: Throwable =>
-        try {
-          cleanup()
-        } catch {
-          case NonFatal(e1) =>
-            // Log `e1` since we should not override `e`
-            logError(e1.getMessage, e1)
-        }
-        throw e
+  override def hasNext: Boolean = cleanupOnFailure {
+    val r = values.hasNext
+    if (!r) {
+      complete()
     }
+    r
+  }
+
+  override def next(): T = cleanupOnFailure {
+    val value = values.next()
+    serializeStream.writeObject(value)
+    value
   }
 }
 
@@ -178,6 +169,15 @@ private[spark] object CheckpointingIterator {
     checkpointingRDDPartitions.remove(id)
   }
 
+  /**
+   * Create a `CheckpointingIterator` to wrap the original `Iterator` so that when the wrapper is
+   * consumed, it will checkpoint the values. Even if the wrapper is not drained, we will still
+   * drain the remaining values when a task is completed.
+   *
+   * If this method is called multiple times for the same partition of an `RDD`, only one
+   * `Iterator` that gets the lock will be wrapped. For other `Iterator`s that don't get the lock
+   * or find the partition has been checkpointed, we just return the original `Iterator`.
+   */
   def apply[T: ClassTag](
       rdd: RDD[T],
       values: Iterator[T],
@@ -218,7 +218,7 @@ private[spark] object CheckpointingIterator {
         throw e
       }
     } else {
-      values
+      values // Iterator is being checkpointed, so just return the values
    }

  }
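
The main refactoring here is the `cleanupOnFailure` helper, which uses a by-name parameter so that `hasNext` and `next` can share the same failure-handling path instead of duplicating the try/catch/cleanup block. Below is a minimal standalone sketch of that pattern; `SafeIterator` and its `cleanup` method are hypothetical stand-ins for `CheckpointingIterator`'s temp-file handling, and suppressing the cleanup error is used in place of Spark's logging, so this is an illustration of the technique rather than the actual Spark code.

```scala
// Standalone illustration of the by-name wrapper pattern used in this change.
class SafeIterator[T](underlying: Iterator[T]) extends Iterator[T] {

  // Hypothetical cleanup step; the real class deletes its temporary checkpoint file here.
  private[this] def cleanup(): Unit = println("cleaning up")

  // Run `body`; if it throws, attempt cleanup but always rethrow the original error.
  private[this] def cleanupOnFailure[A](body: => A): A = {
    try {
      body
    } catch {
      case e: Throwable =>
        try {
          cleanup()
        } catch {
          case e1: Throwable =>
            // Don't let a cleanup failure hide the original exception.
            e.addSuppressed(e1)
        }
        throw e
    }
  }

  override def hasNext: Boolean = cleanupOnFailure(underlying.hasNext)

  override def next(): T = cleanupOnFailure(underlying.next())
}
```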