Skip to content

Commit a350945

Browse files
committed
Cache the preferred locations of a checkpointed RDD.
1 parent db9e0fd commit a350945

File tree

3 files changed

+50
-6
lines changed

3 files changed

+50
-6
lines changed

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,13 @@ package object config {
783783
.booleanConf
784784
.createWithDefault(false)
785785

786+
// Internal boolean switch controlling whether the preferred locations looked up for a
// checkpointed RDD are cached. Off unless explicitly enabled.
private[spark] val CACHE_CHECKPOINT_PREFERRED_LOCS =
  ConfigBuilder("spark.rdd.checkpoint.cachePreferredLocs")
    .internal()
    .doc("Whether to cache preferred locations of checkpointed RDD.")
    .booleanConf
    .createWithDefault(false)
792+
786793
private[spark] val SHUFFLE_ACCURATE_BLOCK_THRESHOLD =
787794
ConfigBuilder("spark.shuffle.accurateBlockThreshold")
788795
.doc("Threshold in bytes above which the size of shuffle blocks in " +

core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.rdd
2020
import java.io.{FileNotFoundException, IOException}
2121
import java.util.concurrent.TimeUnit
2222

23+
import scala.collection.mutable
2324
import scala.reflect.ClassTag
2425
import scala.util.control.NonFatal
2526

@@ -28,7 +29,7 @@ import org.apache.hadoop.fs.Path
2829
import org.apache.spark._
2930
import org.apache.spark.broadcast.Broadcast
3031
import org.apache.spark.internal.Logging
31-
import org.apache.spark.internal.config.{BUFFER_SIZE, CHECKPOINT_COMPRESS}
32+
import org.apache.spark.internal.config.{BUFFER_SIZE, CACHE_CHECKPOINT_PREFERRED_LOCS, CHECKPOINT_COMPRESS}
3233
import org.apache.spark.io.CompressionCodec
3334
import org.apache.spark.util.{SerializableConfiguration, Utils}
3435

@@ -82,14 +83,28 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag](
8283
Array.tabulate(inputFiles.length)(i => new CheckpointRDDPartition(i))
8384
}
8485

86+
// Cache of preferred locations of checkpointed files, keyed by partition index.
// The map is only a local memo of file system lookups, so it is marked @transient to keep
// its entries out of serialized copies of this RDD. `lazy` makes a deserialized copy start
// with a fresh empty map instead of a null field.
@transient private[spark] lazy val cachedPreferredLocations: mutable.HashMap[Int, Seq[String]] =
  mutable.HashMap.empty
89+
8590
/**
 * Look up the hosts that store the checkpoint file backing the given partition.
 *
 * When `CACHE_CHECKPOINT_PREFERRED_LOCS` is enabled, the answer for each partition index is
 * remembered in `cachedPreferredLocations` after the first lookup and served from there
 * afterwards; otherwise the file system is queried on every call.
 *
 * @param split partition whose checkpoint file should be located
 * @return host names reported for the first block of the file, with "localhost" removed;
 *         empty if the file system reports no block locations
 */
protected override def getPreferredLocations(split: Partition): Seq[String] = {
  val cachingEnabled = SparkEnv.get.conf.get(CACHE_CHECKPOINT_PREFERRED_LOCS)
  val partitionId = split.index

  // Ask the file system where the first block of this partition's checkpoint file lives.
  def lookupHosts(): Seq[String] = {
    val file = new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(partitionId))
    val fileStatus = fs.getFileStatus(file)
    fs.getFileBlockLocations(fileStatus, 0, fileStatus.getLen)
      .headOption
      .toList
      .flatMap(_.getHosts)
      .filter(_ != "localhost")
  }

  if (cachingEnabled) {
    // A miss runs the lookup once and stores the result before returning it.
    cachedPreferredLocations.getOrElseUpdate(partitionId, lookupHosts())
  } else {
    lookupHosts()
  }
}
94109

95110
/**

core/src/test/scala/org/apache/spark/CheckpointSuite.scala

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ object CheckpointSuite {
584584
}
585585
}
586586

587-
class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext {
587+
class CheckpointStorageSuite extends SparkFunSuite with LocalSparkContext {
588588

589589
test("checkpoint compression") {
590590
withTempDir { checkpointDir =>
@@ -618,4 +618,26 @@ class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext {
618618
assert(rdd.collect().toSeq === (1 to 20))
619619
}
620620
}
621+
622+
test("cache checkpoint preferred location") {
  withTempDir { checkpointDir =>
    // Enable the preferred-location cache and keep the UI off for the test context.
    val sparkConf = new SparkConf()
      .set("spark.rdd.checkpoint.cachePreferredLocs", "true")
      .set(UI_ENABLED.key, "false")
    sc = new SparkContext("local", "test", sparkConf)
    sc.setCheckpointDir(checkpointDir.toString)

    val rdd = sc.makeRDD(1 to 20, numSlices = 1)
    rdd.checkpoint()
    assert(rdd.collect().toSeq === (1 to 20))

    // Verify that the RDD now reads from a reliable checkpoint.
    val parent = rdd.firstParent
    assert(parent.isInstanceOf[ReliableCheckpointRDD[_]])
    val checkpointed = parent.asInstanceOf[ReliableCheckpointRDD[_]]
    val locationCache = checkpointed.cachedPreferredLocations

    // Before any explicit lookup the cache is expected to hold nothing for partition 0.
    assert(!locationCache.contains(0))

    // One lookup must populate the cache with exactly the value it returned.
    val firstPartition = checkpointed.partitions(0)
    val resolved = checkpointed.preferredLocations(firstPartition)
    assert(locationCache.contains(0))
    assert(resolved == locationCache(0))
  }
}
621643
}

0 commit comments

Comments
 (0)