
Commit 824be91

Refactor
1 parent d69f775 commit 824be91

2 files changed: +27 additions, -37 deletions

core/src/main/scala/org/apache/spark/CheckpointManager.scala

Lines changed: 23 additions & 32 deletions
@@ -20,49 +20,40 @@ package org.apache.spark
 import scala.collection.mutable
 import scala.reflect.ClassTag
 
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.Path
-
-import org.apache.spark.rdd.{ReliableRDDCheckpointData, ReliableCheckpointRDD, RDD}
+import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData, ReliableCheckpointRDD}
 import org.apache.spark.storage._
 
 private[spark] class CheckpointManager extends Logging {
 
   /** Keys of RDD partitions that are being checkpointed. */
   private val checkpointingRDDPartitions = new mutable.HashSet[RDDBlockId]
 
-  /** Gets or computes an RDD partition. Used by RDD.iterator() when an RDD is about to be
-    * checkpointed. */
-  def getOrCompute[T: ClassTag](
+  /**
+   * Checkpoint the RDD partition. If it is already being checkpointed, just wait until the
+   * checkpointing finishes.
+   */
+  def doCheckpoint[T: ClassTag](
       rdd: RDD[T],
       checkpointData: ReliableRDDCheckpointData[T],
       partition: Partition,
-      context: TaskContext): Iterator[T] = {
-    val conf = checkpointData.broadcastedConf.value.value
-    val path =
-      new Path(checkpointData.cpDir, ReliableCheckpointRDD.checkpointFileName(partition.index))
+      context: TaskContext): Unit = {
+    val hadoopConf = checkpointData.broadcastedConf.value.value
     val key = RDDBlockId(rdd.id, partition.index)
     logDebug(s"Looking for partition $key")
-    if (checkpointData.isCheckpointed) {
-      // TODO how to know we should checkpoint
-      return new InterruptibleIterator[T](context,
-        ReliableCheckpointRDD.readCheckpointFile(path, conf, context))
-    } else {
-      // Acquire a lock for loading this partition
-      // If another thread already holds the lock, wait for it to finish and return its results
-      val checkpoint = acquireLockForPartition[T](rdd, partition, key, context)
-      if (checkpoint.isDefined) {
-        return new InterruptibleIterator[T](context, checkpoint.get)
-      }
-    }
 
-    // Otherwise, we have to load the partition ourselves
     try {
-      logInfo(s"Partition $key not found, computing it")
-      val computedValues = rdd.computeOrReadCache(partition, context)
-      ReliableCheckpointRDD.writeCheckpointFile(
-        context, computedValues, checkpointData.cpDir, conf, partition.index)
-      rdd.computeOrReadCache(partition, context)
+      // Acquire a lock for checkpointing this partition.
+      // If another thread already holds the lock, wait for it to finish.
+      if (acquireLockForPartition[T](rdd, partition, key, context)) {
+        // Acquired the lock. We have to compute the partition ourselves.
+        logInfo(s"Partition $key not found, computing it")
+        val computedValues = rdd.computeOrReadCache(partition, context)
+        // TODO Some operators may use the same partition of an RDD in different executors,
+        // such as `cartesian`. `writeCheckpointFile` already handles this corner case for
+        // speculation. However, we may need to optimize for this case in the future.
+        ReliableCheckpointRDD.writeCheckpointFile(
+          context, computedValues, checkpointData.cpDir, hadoopConf, partition.index)
+      }
     } finally {
       checkpointingRDDPartitions.synchronized {
        checkpointingRDDPartitions.remove(key)
@@ -82,22 +73,22 @@ private[spark] class CheckpointManager extends Logging {
       rdd: RDD[T],
       partition: Partition,
       id: RDDBlockId,
-      context: TaskContext): Option[Iterator[T]] = {
+      context: TaskContext): Boolean = {
     checkpointingRDDPartitions.synchronized {
       if (!checkpointingRDDPartitions.contains(id)) {
         // If the partition is free, acquire its lock to compute its value
         checkpointingRDDPartitions.add(id)
-        return None
+        true
       } else {
         // Otherwise, wait for another thread to finish and return its result
         logInfo(s"Another thread is checkpointing $id, waiting for it to finish...")
         while (checkpointingRDDPartitions.contains(id)) {
           checkpointingRDDPartitions.wait()
         }
         logInfo(s"Finished waiting for $id")
+        false
       }
     }
-    Some(rdd.computeOrReadCache(partition, context))
   }
 
 }
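The core of the refactor is the new contract of acquireLockForPartition: instead of returning Option[Iterator[T]], it returns true when the calling thread won the race and must compute and write the partition, and false when another thread finished the checkpoint while we waited. The checkpointingRDDPartitions set doubles as the registry of in-flight partitions and as the monitor that losing threads wait() on; the finally block removes the key, and a notifyAll() on the same set (presumably just past the end of the visible hunk) is what wakes the waiters. The change also drops a wasteful double computation in the old getOrCompute, which called rdd.computeOrReadCache twice for the same partition. Below is a minimal standalone sketch of the locking pattern; the names PartitionLock, tryAcquire, and release are illustrative, not code from this commit:

import scala.collection.mutable

// Sketch of the coordination pattern in CheckpointManager: a synchronized
// HashSet serves both as the set of in-flight keys and as the monitor that
// waiting threads block on.
class PartitionLock[K] {
  private val inFlight = new mutable.HashSet[K]

  // Returns true if the caller acquired the key and must do the work;
  // false if another thread finished the work while we waited.
  def tryAcquire(key: K): Boolean = inFlight.synchronized {
    if (inFlight.add(key)) {
      true // we won the race: the caller computes and writes the checkpoint
    } else {
      while (inFlight.contains(key)) {
        inFlight.wait() // releases the monitor until release() calls notifyAll()
      }
      false // another thread completed the checkpoint
    }
  }

  // Call from a finally block so a failed computation cannot strand waiters.
  def release(key: K): Unit = inFlight.synchronized {
    inFlight.remove(key)
    inFlight.notifyAll() // wake every thread blocked in tryAcquire
  }
}

In doCheckpoint terms, if (tryAcquire(key)) guards the compute-and-write branch, and release(key) corresponds to the finally block's remove plus the notifyAll that has to accompany it.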

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 4 additions & 5 deletions
@@ -258,11 +258,10 @@ abstract class RDD[T: ClassTag](
    * subclasses of RDD.
    */
   final def iterator(split: Partition, context: TaskContext): Iterator[T] = {
-    if (!isCheckpointedAndMaterialized) {
-      if (checkpointData.exists(_.isInstanceOf[ReliableRDDCheckpointData[T]])) {
-        return SparkEnv.get.checkpointManager.getOrCompute(
-          this, checkpointData.get.asInstanceOf[ReliableRDDCheckpointData[T]], split, context)
-      }
+    if (!isCheckpointedAndMaterialized &&
+        checkpointData.exists(_.isInstanceOf[ReliableRDDCheckpointData[T]])) {
+      SparkEnv.get.checkpointManager.doCheckpoint(
+        this, checkpointData.get.asInstanceOf[ReliableRDDCheckpointData[T]], split, context)
     }
     computeOrReadCache(split, context)
   }
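For context, this code path is reached through the ordinary public Spark API. The snippet below is our illustration, not part of the commit; it uses only standard calls (setCheckpointDir, checkpoint, count). With this patch, the first action after checkpoint() drives iterator(), which checkpoints each reliable partition inline via CheckpointManager.doCheckpoint before computeOrReadCache produces the rows.

import org.apache.spark.{SparkConf, SparkContext}

object CheckpointDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("checkpoint-demo").setMaster("local[2]")
    val sc = new SparkContext(conf)
    sc.setCheckpointDir("/tmp/spark-checkpoints") // reliable checkpoints need a checkpoint dir
    val doubled = sc.parallelize(1 to 1000, 4).map(_ * 2)
    doubled.checkpoint() // marks the RDD; nothing is written yet
    // The first action materializes every partition; under this patch, iterator()
    // triggers CheckpointManager.doCheckpoint per partition as a side effect.
    println(doubled.count())
    sc.stop()
  }
}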
