@@ -20,6 +20,7 @@ package org.apache.spark.rdd
2020import java .io .{FileNotFoundException , IOException }
2121import java .util .concurrent .TimeUnit
2222
23+ import scala .collection .mutable
2324import scala .reflect .ClassTag
2425import scala .util .control .NonFatal
2526
@@ -28,7 +29,7 @@ import org.apache.hadoop.fs.Path
2829import org .apache .spark ._
2930import org .apache .spark .broadcast .Broadcast
3031import org .apache .spark .internal .Logging
31- import org .apache .spark .internal .config .{BUFFER_SIZE , CHECKPOINT_COMPRESS }
32+ import org .apache .spark .internal .config .{BUFFER_SIZE , CACHE_CHECKPOINT_PREFERRED_LOCS , CHECKPOINT_COMPRESS }
3233import org .apache .spark .io .CompressionCodec
3334import org .apache .spark .util .{SerializableConfiguration , Utils }
3435
@@ -82,14 +83,28 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag](
8283 Array .tabulate(inputFiles.length)(i => new CheckpointRDDPartition (i))
8384 }
8485
// Cache of preferred locations of checkpointed files, keyed by partition index.
// Declared `@transient lazy` so that cached entries are never written out when this RDD
// instance is serialized; the map is re-created empty on first access after deserialization.
@transient private[spark] lazy val cachedPreferredLocations: mutable.HashMap[Int, Seq[String]] =
  mutable.HashMap.empty
89+
/**
 * Return the locations of the checkpoint file associated with the given partition.
 *
 * When `CACHE_CHECKPOINT_PREFERRED_LOCS` is enabled, the hosts resolved for a partition are
 * stored in `cachedPreferredLocations` under the partition index and reused on later calls,
 * so the file system is queried at most once per partition. When it is disabled, the file
 * system is queried on every call and the cache is neither read nor written.
 *
 * @param split the partition whose checkpoint file should be located
 * @return the hosts of the first block of the partition's checkpoint file, excluding
 *         "localhost"; empty if the file system reports no block locations
 */
protected override def getPreferredLocations(split: Partition): Seq[String] = {
  // Ask the file system where the first block of this partition's checkpoint file lives.
  def lookupBlockHosts(): Seq[String] = {
    val status = fs.getFileStatus(
      new Path(checkpointPath, ReliableCheckpointRDD.checkpointFileName(split.index)))
    val locations = fs.getFileBlockLocations(status, 0, status.getLen)
    locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
  }

  if (SparkEnv.get.conf.get(CACHE_CHECKPOINT_PREFERRED_LOCS)) {
    // Single lookup: the by-name argument is evaluated only on a cache miss, and nothing is
    // stored if the file system call throws.
    cachedPreferredLocations.getOrElseUpdate(split.index, lookupBlockHosts())
  } else {
    lookupBlockHosts()
  }
}
94109
95110 /**
0 commit comments