@@ -20,21 +20,22 @@ package org.apache.spark.rdd
 import java.io.IOException
 
 import scala.reflect.ClassTag
+import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark._
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 /**
  * An RDD that reads from checkpoint files previously written to reliable storage.
  */
 private[spark] class ReliableCheckpointRDD[T: ClassTag](
     sc: SparkContext,
-    val checkpointPath: String)
-  extends CheckpointRDD[T](sc) {
+    val checkpointPath: String,
+    _partitioner: Option[Partitioner] = None
+  ) extends CheckpointRDD[T](sc) {
 
   @transient private val hadoopConf = sc.hadoopConfiguration
   @transient private val cpath = new Path(checkpointPath)
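The constructor gains an optional `_partitioner` so the write path can hand the checkpoint RDD the partitioner it already knows, instead of forcing a later read of the partitioner file. A minimal sketch of constructing it this way, assuming Spark-internal access (the class is `private[spark]`); the path and `HashPartitioner` are illustrative:

```scala
import org.apache.spark.{HashPartitioner, SparkContext}
import org.apache.spark.rdd.ReliableCheckpointRDD

// Illustrative only: a caller that already knows the partitioner passes it
// explicitly, so `partitioner` below never needs the file-based fallback.
def restore(sc: SparkContext): ReliableCheckpointRDD[(Int, String)] = {
  new ReliableCheckpointRDD[(Int, String)](
    sc,
    "/tmp/checkpoints/rdd-42",      // hypothetical checkpoint directory
    Some(new HashPartitioner(8)))   // partitioner known at write time
}
```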
@@ -47,7 +48,13 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag](
   /**
    * Return the path of the checkpoint directory this RDD reads data from.
    */
-  override def getCheckpointFile: Option[String] = Some(checkpointPath)
+  override val getCheckpointFile: Option[String] = Some(checkpointPath)
+
+  override val partitioner: Option[Partitioner] = {
+    _partitioner.orElse {
+      ReliableCheckpointRDD.readCheckpointedPartitionerFile(context, checkpointPath)
+    }
+  }
 
   /**
    * Return partitions described by the files in the checkpoint directory.
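The `partitioner` override is the read-side payoff: an RDD recovered from checkpoint files (for example after a Spark Streaming driver restart) previously came back with no partitioner, so co-partitioned operations against it had to reshuffle. A hedged sketch of the benefit, again assuming Spark-internal access and an active SparkContext `sc`; the directory is hypothetical:

```scala
// Sketch: read a checkpoint back. With this patch the recovered RDD reports
// the partitioner that was written to the "_partitioner" file, so the join
// below gets a narrow (OneToOne) dependency on it instead of a new shuffle.
val recovered = new ReliableCheckpointRDD[(Int, Int)](
  sc, "/tmp/checkpoints/rdd-42")          // written out earlier

val other = sc.parallelize(1 to 100).map(i => (i, i.toString))
  .partitionBy(recovered.partitioner.get) // co-partition the other side
val joined = recovered.join(other)        // no shuffle of `recovered`
```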
@@ -100,10 +107,52 @@ private[spark] object ReliableCheckpointRDD extends Logging {
     "part-%05d".format(partitionIndex)
   }
 
+  private def checkpointPartitionerFileName(): String = {
+    "_partitioner"
+  }
+
+  /**
+   * Write RDD to checkpoint files and return a ReliableCheckpointRDD representing the RDD.
+   */
+  def writeRDDToCheckpointDirectory[T: ClassTag](
+      originalRDD: RDD[T],
+      checkpointDir: String,
+      blockSize: Int = -1): ReliableCheckpointRDD[T] = {
+
+    val sc = originalRDD.sparkContext
+
+    // Create the output path for the checkpoint
+    val checkpointDirPath = new Path(checkpointDir)
+    val fs = checkpointDirPath.getFileSystem(sc.hadoopConfiguration)
+    if (!fs.mkdirs(checkpointDirPath)) {
+      throw new SparkException(s"Failed to create checkpoint path $checkpointDirPath")
+    }
+
+    // Save to file, and reload it as an RDD
+    val broadcastedConf = sc.broadcast(
+      new SerializableConfiguration(sc.hadoopConfiguration))
+    // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
+    sc.runJob(originalRDD,
+      writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _)
+
+    if (originalRDD.partitioner.nonEmpty) {
+      writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath)
+    }
+
+    val newRDD = new ReliableCheckpointRDD[T](
+      sc, checkpointDirPath.toString, originalRDD.partitioner)
+    if (newRDD.partitions.length != originalRDD.partitions.length) {
+      throw new SparkException(
+        s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
+          s"number of partitions from original RDD $originalRDD(${originalRDD.partitions.length})")
+    }
+    newRDD
+  }
+
   /**
-   * Write this partition's values to a checkpoint file.
+   * Write an RDD partition's data to a checkpoint file.
    */
-  def writeCheckpointFile[T: ClassTag](
+  def writePartitionToCheckpointFile[T: ClassTag](
       path: String,
       broadcastedConf: Broadcast[SerializableConfiguration],
       blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
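`writeRDDToCheckpointDirectory` now centralizes the whole write path: create the directory, run a job that writes every partition via `writePartitionToCheckpointFile`, persist the partitioner when there is one, and verify the partition count survived the round trip. In Spark this is intended to be called by the checkpointing machinery rather than user code; a hypothetical direct call, assuming Spark-internal access and an illustrative directory:

```scala
// Hypothetical driver-side call; the path is made up.
val rdd = sc.parallelize(1 to 1000, numSlices = 4).map(i => (i % 10, i))
val checkpointed = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(
  rdd, "/tmp/checkpoints/demo")

// The helper guarantees the partition count survived the round trip...
assert(checkpointed.partitions.length == rdd.partitions.length)
// ...and threads the original partitioner through (None here: no partitionBy).
assert(checkpointed.partitioner == rdd.partitioner)
```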
@@ -151,6 +200,67 @@ private[spark] object ReliableCheckpointRDD extends Logging {
     }
   }
 
+  /**
+   * Write a partitioner to the given RDD checkpoint directory. This is done on a best-effort
+   * basis; any exception while writing the partitioner is caught, logged and ignored.
+   */
+  private def writePartitionerToCheckpointDir(
+      sc: SparkContext, partitioner: Partitioner, checkpointDirPath: Path): Unit = {
+    try {
+      val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName)
+      val bufferSize = sc.conf.getInt("spark.buffer.size", 65536)
+      val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration)
+      val fileOutputStream = fs.create(partitionerFilePath, false, bufferSize)
+      val serializer = SparkEnv.get.serializer.newInstance()
+      val serializeStream = serializer.serializeStream(fileOutputStream)
+      Utils.tryWithSafeFinally {
+        serializeStream.writeObject(partitioner)
+      } {
+        serializeStream.close()
+      }
+      logDebug(s"Written partitioner to $partitionerFilePath")
+    } catch {
+      case NonFatal(e) =>
+        logWarning(s"Error writing partitioner $partitioner to $checkpointDirPath", e)
+    }
+  }
+
+
+  /**
+   * Read a partitioner from the given RDD checkpoint directory, if it exists.
+   * This is done on a best-effort basis; any exception while reading the partitioner is
+   * caught, logged and ignored.
+   */
+  private def readCheckpointedPartitionerFile(
+      sc: SparkContext,
+      checkpointDirPath: String): Option[Partitioner] = {
+    try {
+      val bufferSize = sc.conf.getInt("spark.buffer.size", 65536)
+      val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName)
+      val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration)
+      if (fs.exists(partitionerFilePath)) {
+        val fileInputStream = fs.open(partitionerFilePath, bufferSize)
+        val serializer = SparkEnv.get.serializer.newInstance()
+        val deserializeStream = serializer.deserializeStream(fileInputStream)
+        val partitioner = Utils.tryWithSafeFinally[Partitioner] {
+          deserializeStream.readObject[Partitioner]
+        } {
+          deserializeStream.close()
+        }
+        logDebug(s"Read partitioner from $partitionerFilePath")
+        Some(partitioner)
+      } else {
+        logDebug("No partitioner file")
+        None
+      }
+    } catch {
+      case NonFatal(e) =>
+        logWarning(s"Error reading partitioner from $checkpointDirPath, " +
+          s"partitioner will not be recovered which may lead to performance loss", e)
+        None
+    }
+  }
+
   /**
    * Read the content of the specified checkpoint file.
    */
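Both partitioner helpers are deliberately best-effort: failing to write or read the `_partitioner` file costs at most a reshuffle downstream, never correctness, so non-fatal exceptions are logged and swallowed. A self-contained sketch of the same pattern, substituting plain Java serialization for `SparkEnv.get.serializer` and a local `File` for Hadoop's `FileSystem` (all names are illustrative):

```scala
import java.io._
import scala.util.control.NonFatal
import org.apache.spark.{HashPartitioner, Partitioner}

// Best-effort write: any non-fatal failure is logged (here: printed) and ignored.
def writePartitionerBestEffort(p: Partitioner, file: File): Unit = {
  try {
    val out = new ObjectOutputStream(new FileOutputStream(file))
    try out.writeObject(p) finally out.close()
  } catch {
    case NonFatal(e) => println(s"Could not write partitioner: $e")
  }
}

// Best-effort read: a missing, corrupt, or unreadable file means "no partitioner".
def readPartitionerBestEffort(file: File): Option[Partitioner] = {
  try {
    if (!file.exists()) {
      None
    } else {
      val in = new ObjectInputStream(new FileInputStream(file))
      try Some(in.readObject().asInstanceOf[Partitioner]) finally in.close()
    }
  } catch {
    case NonFatal(_) => None
  }
}

// Round-trip check: HashPartitioner equality is by number of partitions.
val f = File.createTempFile("_partitioner", ".bin")
writePartitionerBestEffort(new HashPartitioner(4), f)
assert(readPartitionerBestEffort(f).contains(new HashPartitioner(4)))
```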