
Commit 60b541e

tdas authored and Andrew Or committed
[SPARK-12004] Preserve the RDD partitioner through RDD checkpointing
The solution is to save the RDD partitioner in a separate file in the RDD checkpoint directory, that is, `<checkpoint dir>/_partitioner`. In most cases, failing to recover the partitioner does not affect correctness; it only reduces performance. So this solution makes a best-effort attempt to save and recover the partitioner: if either step fails, checkpointing itself is not affected. This makes the patch safe and backward compatible.

Author: Tathagata Das <[email protected]>

Closes #9983 from tdas/SPARK-12004.
1 parent 2cef1cd commit 60b541e
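For context, here is a minimal sketch (not part of this commit) of the recovery path the patch improves. It is illustrative only: it assumes a local SparkContext, a hypothetical checkpoint directory under /tmp, and it must live in the org.apache.spark package because SparkContext.checkpointFile is a private[spark] API (the new test in CheckpointSuite uses it the same way).

// Hypothetical sketch, not from the patch.
package org.apache.spark

import org.apache.spark.rdd.RDD

object PartitionerRecoverySketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setAppName("partitioner-recovery-sketch").setMaster("local[2]"))
    sc.setCheckpointDir("/tmp/spark-checkpoint-sketch")  // illustrative path

    // A pair RDD with a known partitioner.
    val rdd = sc.makeRDD(1 to 100).map(_ -> 1).partitionBy(new HashPartitioner(4))
    rdd.checkpoint()
    rdd.count()  // materializes the checkpoint; with this patch, _partitioner is also written

    // Rebuild the RDD purely from the checkpoint files, as Spark Streaming does on
    // driver restart. Before this patch, recovered.partitioner was always None, so a
    // later reduceByKey/join had to re-shuffle the recovered data.
    val recovered: RDD[(Int, Int)] = sc.checkpointFile[(Int, Int)](rdd.getCheckpointFile.get)
    assert(recovered.partitioner == rdd.partitioner)

    sc.stop()
  }
}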

File tree: 3 files changed, +173 −31 lines


core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala

Lines changed: 116 additions & 6 deletions
@@ -20,21 +20,22 @@ package org.apache.spark.rdd
 import java.io.IOException
 
 import scala.reflect.ClassTag
+import scala.util.control.NonFatal
 
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark._
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 /**
  * An RDD that reads from checkpoint files previously written to reliable storage.
  */
 private[spark] class ReliableCheckpointRDD[T: ClassTag](
     sc: SparkContext,
-    val checkpointPath: String)
-  extends CheckpointRDD[T](sc) {
+    val checkpointPath: String,
+    _partitioner: Option[Partitioner] = None
+  ) extends CheckpointRDD[T](sc) {
 
   @transient private val hadoopConf = sc.hadoopConfiguration
   @transient private val cpath = new Path(checkpointPath)
@@ -47,7 +48,13 @@ private[spark] class ReliableCheckpointRDD[T: ClassTag](
   /**
    * Return the path of the checkpoint directory this RDD reads data from.
    */
-  override def getCheckpointFile: Option[String] = Some(checkpointPath)
+  override val getCheckpointFile: Option[String] = Some(checkpointPath)
+
+  override val partitioner: Option[Partitioner] = {
+    _partitioner.orElse {
+      ReliableCheckpointRDD.readCheckpointedPartitionerFile(context, checkpointPath)
+    }
+  }
 
   /**
    * Return partitions described by the files in the checkpoint directory.
@@ -100,10 +107,52 @@ private[spark] object ReliableCheckpointRDD extends Logging {
     "part-%05d".format(partitionIndex)
   }
 
+  private def checkpointPartitionerFileName(): String = {
+    "_partitioner"
+  }
+
+  /**
+   * Write RDD to checkpoint files and return a ReliableCheckpointRDD representing the RDD.
+   */
+  def writeRDDToCheckpointDirectory[T: ClassTag](
+      originalRDD: RDD[T],
+      checkpointDir: String,
+      blockSize: Int = -1): ReliableCheckpointRDD[T] = {
+
+    val sc = originalRDD.sparkContext
+
+    // Create the output path for the checkpoint
+    val checkpointDirPath = new Path(checkpointDir)
+    val fs = checkpointDirPath.getFileSystem(sc.hadoopConfiguration)
+    if (!fs.mkdirs(checkpointDirPath)) {
+      throw new SparkException(s"Failed to create checkpoint path $checkpointDirPath")
+    }
+
+    // Save to file, and reload it as an RDD
+    val broadcastedConf = sc.broadcast(
+      new SerializableConfiguration(sc.hadoopConfiguration))
+    // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
+    sc.runJob(originalRDD,
+      writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _)
+
+    if (originalRDD.partitioner.nonEmpty) {
+      writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath)
+    }
+
+    val newRDD = new ReliableCheckpointRDD[T](
+      sc, checkpointDirPath.toString, originalRDD.partitioner)
+    if (newRDD.partitions.length != originalRDD.partitions.length) {
+      throw new SparkException(
+        s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
+          s"number of partitions from original RDD $originalRDD(${originalRDD.partitions.length})")
+    }
+    newRDD
+  }
+
   /**
-   * Write this partition's values to a checkpoint file.
+   * Write a RDD partition's data to a checkpoint file.
    */
-  def writeCheckpointFile[T: ClassTag](
+  def writePartitionToCheckpointFile[T: ClassTag](
       path: String,
       broadcastedConf: Broadcast[SerializableConfiguration],
       blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
@@ -151,6 +200,67 @@ private[spark] object ReliableCheckpointRDD extends Logging {
     }
   }
 
+  /**
+   * Write a partitioner to the given RDD checkpoint directory. This is done on a best-effort
+   * basis; any exception while writing the partitioner is caught, logged and ignored.
+   */
+  private def writePartitionerToCheckpointDir(
+    sc: SparkContext, partitioner: Partitioner, checkpointDirPath: Path): Unit = {
+    try {
+      val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName)
+      val bufferSize = sc.conf.getInt("spark.buffer.size", 65536)
+      val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration)
+      val fileOutputStream = fs.create(partitionerFilePath, false, bufferSize)
+      val serializer = SparkEnv.get.serializer.newInstance()
+      val serializeStream = serializer.serializeStream(fileOutputStream)
+      Utils.tryWithSafeFinally {
+        serializeStream.writeObject(partitioner)
+      } {
+        serializeStream.close()
+      }
+      logDebug(s"Written partitioner to $partitionerFilePath")
+    } catch {
+      case NonFatal(e) =>
+        logWarning(s"Error writing partitioner $partitioner to $checkpointDirPath")
+    }
+  }
+
+
+  /**
+   * Read a partitioner from the given RDD checkpoint directory, if it exists.
+   * This is done on a best-effort basis; any exception while reading the partitioner is
+   * caught, logged and ignored.
+   */
+  private def readCheckpointedPartitionerFile(
+      sc: SparkContext,
+      checkpointDirPath: String): Option[Partitioner] = {
+    try {
+      val bufferSize = sc.conf.getInt("spark.buffer.size", 65536)
+      val partitionerFilePath = new Path(checkpointDirPath, checkpointPartitionerFileName)
+      val fs = partitionerFilePath.getFileSystem(sc.hadoopConfiguration)
+      if (fs.exists(partitionerFilePath)) {
+        val fileInputStream = fs.open(partitionerFilePath, bufferSize)
+        val serializer = SparkEnv.get.serializer.newInstance()
+        val deserializeStream = serializer.deserializeStream(fileInputStream)
+        val partitioner = Utils.tryWithSafeFinally[Partitioner] {
+          deserializeStream.readObject[Partitioner]
+        } {
+          deserializeStream.close()
+        }
+        logDebug(s"Read partitioner from $partitionerFilePath")
+        Some(partitioner)
+      } else {
+        logDebug("No partitioner file")
+        None
+      }
+    } catch {
+      case NonFatal(e) =>
+        logWarning(s"Error reading partitioner from $checkpointDirPath, " +
+            s"partitioner will not be recovered which may lead to performance loss", e)
+        None
+    }
+  }
+
   /**
    * Read the content of the specified checkpoint file.
    */
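For orientation (not part of the diff itself), a checkpoint directory written by writeRDDToCheckpointDirectory now contains roughly the following files. The part file names come from checkpointFileName above, one per partition; _partitioner is present only when the original RDD had a partitioner and the best-effort write succeeded:

    <checkpoint dir>/part-00000
    <checkpoint dir>/part-00001
    ...
    <checkpoint dir>/_partitioner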

core/src/main/scala/org/apache/spark/rdd/ReliableRDDCheckpointData.scala

Lines changed: 1 addition & 20 deletions
@@ -55,25 +55,7 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
    * This is called immediately after the first action invoked on this RDD has completed.
    */
   protected override def doCheckpoint(): CheckpointRDD[T] = {
-
-    // Create the output path for the checkpoint
-    val path = new Path(cpDir)
-    val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
-    if (!fs.mkdirs(path)) {
-      throw new SparkException(s"Failed to create checkpoint path $cpDir")
-    }
-
-    // Save to file, and reload it as an RDD
-    val broadcastedConf = rdd.context.broadcast(
-      new SerializableConfiguration(rdd.context.hadoopConfiguration))
-    // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
-    rdd.context.runJob(rdd, ReliableCheckpointRDD.writeCheckpointFile[T](cpDir, broadcastedConf) _)
-    val newRDD = new ReliableCheckpointRDD[T](rdd.context, cpDir)
-    if (newRDD.partitions.length != rdd.partitions.length) {
-      throw new SparkException(
-        s"Checkpoint RDD $newRDD(${newRDD.partitions.length}) has different " +
-          s"number of partitions from original RDD $rdd(${rdd.partitions.length})")
-    }
+    val newRDD = ReliableCheckpointRDD.writeRDDToCheckpointDirectory(rdd, cpDir)
 
     // Optionally clean our checkpoint files if the reference is out of scope
     if (rdd.conf.getBoolean("spark.cleaner.referenceTracking.cleanCheckpoints", false)) {
@@ -83,7 +65,6 @@ private[spark] class ReliableRDDCheckpointData[T: ClassTag](@transient private v
     }
 
     logInfo(s"Done checkpointing RDD ${rdd.id} to $cpDir, new parent is RDD ${newRDD.id}")
-
     newRDD
   }
 

core/src/test/scala/org/apache/spark/CheckpointSuite.scala

Lines changed: 56 additions & 5 deletions
@@ -21,7 +21,8 @@ import java.io.File
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.CheckpointSuite._
+import org.apache.hadoop.fs.Path
+
 import org.apache.spark.rdd._
 import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
 import org.apache.spark.util.Utils
@@ -74,8 +75,10 @@ trait RDDCheckpointTester { self: SparkFunSuite =>
 
     // Test whether the checkpoint file has been created
     if (reliableCheckpoint) {
-      assert(
-        collectFunc(sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)) === result)
+      assert(operatedRDD.getCheckpointFile.nonEmpty)
+      val recoveredRDD = sparkContext.checkpointFile[U](operatedRDD.getCheckpointFile.get)
+      assert(collectFunc(recoveredRDD) === result)
+      assert(recoveredRDD.partitioner === operatedRDD.partitioner)
     }
 
     // Test whether dependencies have been changed from its earlier parent RDD
@@ -211,9 +214,14 @@ trait RDDCheckpointTester { self: SparkFunSuite =>
   }
 
   /** Run a test twice, once for local checkpointing and once for reliable checkpointing. */
-  protected def runTest(name: String)(body: Boolean => Unit): Unit = {
+  protected def runTest(
+      name: String,
+      skipLocalCheckpoint: Boolean = false
+    )(body: Boolean => Unit): Unit = {
     test(name + " [reliable checkpoint]")(body(true))
-    test(name + " [local checkpoint]")(body(false))
+    if (!skipLocalCheckpoint) {
+      test(name + " [local checkpoint]")(body(false))
+    }
   }
 
   /**
@@ -264,6 +272,49 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
     assert(flatMappedRDD.collect() === result)
   }
 
+  runTest("checkpointing partitioners", skipLocalCheckpoint = true) { _: Boolean =>
+
+    def testPartitionerCheckpointing(
+        partitioner: Partitioner,
+        corruptPartitionerFile: Boolean = false
+      ): Unit = {
+      val rddWithPartitioner = sc.makeRDD(1 to 4).map { _ -> 1 }.partitionBy(partitioner)
+      rddWithPartitioner.checkpoint()
+      rddWithPartitioner.count()
+      assert(rddWithPartitioner.getCheckpointFile.get.nonEmpty,
+        "checkpointing was not successful")
+
+      if (corruptPartitionerFile) {
+        // Overwrite the partitioner file with garbage data
+        val checkpointDir = new Path(rddWithPartitioner.getCheckpointFile.get)
+        val fs = checkpointDir.getFileSystem(sc.hadoopConfiguration)
+        val partitionerFile = fs.listStatus(checkpointDir)
+          .find(_.getPath.getName.contains("partitioner"))
+          .map(_.getPath)
+        require(partitionerFile.nonEmpty, "could not find the partitioner file for testing")
+        val output = fs.create(partitionerFile.get, true)
+        output.write(100)
+        output.close()
+      }
+
+      val newRDD = sc.checkpointFile[(Int, Int)](rddWithPartitioner.getCheckpointFile.get)
+      assert(newRDD.collect().toSet === rddWithPartitioner.collect().toSet, "RDD not recovered")
+
+      if (!corruptPartitionerFile) {
+        assert(newRDD.partitioner != None, "partitioner not recovered")
+        assert(newRDD.partitioner === rddWithPartitioner.partitioner,
+          "recovered partitioner does not match")
+      } else {
+        assert(newRDD.partitioner == None, "partitioner unexpectedly recovered")
+      }
+    }
+
+    testPartitionerCheckpointing(partitioner)
+
+    // Test that corrupted partitioner file does not prevent recovery of RDD
+    testPartitionerCheckpointing(partitioner, corruptPartitionerFile = true)
+  }
+
   runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean =>
     testRDD(_.map(x => x.toString), reliableCheckpoint)
     testRDD(_.flatMap(x => 1 to x), reliableCheckpoint)
