
Commit e1bbf1a

[SPARK-8606] Prevent exceptions in RDD.getPreferredLocations() from crashing DAGScheduler
If `RDD.getPreferredLocations()` throws an exception it may crash the DAGScheduler and SparkContext. This patch addresses this by adding a try-catch block.

Author: Josh Rosen <[email protected]>

Closes #7023 from JoshRosen/SPARK-8606 and squashes the following commits:

770b169 [Josh Rosen] Fix getPreferredLocations() DAGScheduler crash with try block.
44a9b55 [Josh Rosen] Add test of a buggy getPartitions() method
19aa9f7 [Josh Rosen] Add (failing) regression test for getPreferredLocations() DAGScheduler crash

(cherry picked from commit 0b5abbf)
Signed-off-by: Josh Rosen <[email protected]>
1 parent a2dbb48 commit e1bbf1a
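
The fix below guards task creation with Scala's `NonFatal` extractor. As a quick aside (not part of the patch), here is a minimal sketch of what `NonFatal` does and does not match; the `describe` helper is hypothetical and only for illustration:

```scala
import scala.util.control.NonFatal

// Minimal sketch (not Spark code): NonFatal matches ordinary exceptions but
// deliberately does not match fatal JVM-level throwables such as
// VirtualMachineError (e.g. OutOfMemoryError), ThreadDeath, LinkageError,
// InterruptedException, and ControlThrowable, so those still propagate
// instead of being swallowed by the scheduler's catch block.
def describe(t: Throwable): String = t match {
  case NonFatal(e) => s"non-fatal, handled: ${e.getClass.getSimpleName}"
  case fatal       => s"fatal, rethrown: ${fatal.getClass.getSimpleName}"
}

println(describe(new RuntimeException("boom")))     // non-fatal, handled: RuntimeException
println(describe(new OutOfMemoryError("no heap")))  // fatal, rethrown: OutOfMemoryError
```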

File tree

2 files changed: +53 -15 lines changed


core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 22 additions & 15 deletions
@@ -886,22 +886,29 @@ class DAGScheduler(
        return
     }
 
-    val tasks: Seq[Task[_]] = stage match {
-      case stage: ShuffleMapStage =>
-        partitionsToCompute.map { id =>
-          val locs = getPreferredLocs(stage.rdd, id)
-          val part = stage.rdd.partitions(id)
-          new ShuffleMapTask(stage.id, taskBinary, part, locs)
-        }
+    val tasks: Seq[Task[_]] = try {
+      stage match {
+        case stage: ShuffleMapStage =>
+          partitionsToCompute.map { id =>
+            val locs = getPreferredLocs(stage.rdd, id)
+            val part = stage.rdd.partitions(id)
+            new ShuffleMapTask(stage.id, taskBinary, part, locs)
+          }
 
-      case stage: ResultStage =>
-        val job = stage.resultOfJob.get
-        partitionsToCompute.map { id =>
-          val p: Int = job.partitions(id)
-          val part = stage.rdd.partitions(p)
-          val locs = getPreferredLocs(stage.rdd, p)
-          new ResultTask(stage.id, taskBinary, part, locs, id)
-        }
+        case stage: ResultStage =>
+          val job = stage.resultOfJob.get
+          partitionsToCompute.map { id =>
+            val p: Int = job.partitions(id)
+            val part = stage.rdd.partitions(p)
+            val locs = getPreferredLocs(stage.rdd, p)
+            new ResultTask(stage.id, taskBinary, part, locs, id)
+          }
+      }
+    } catch {
+      case NonFatal(e) =>
+        abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}")
+        runningStages -= stage
+        return
     }
 
     if (tasks.size > 0) {
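
To make the control-flow change above concrete, here is a minimal, self-contained sketch (not Spark code; `Stage`, `createTasks`, and `submitStages` are hypothetical names) of the pattern the patch adopts: a failure while building one stage's tasks aborts only that stage and returns early, instead of escaping and taking down the whole scheduling loop:

```scala
import scala.util.control.NonFatal

// Hypothetical stand-ins for DAGScheduler internals, for illustration only.
final case class Stage(id: Int, buggy: Boolean)

def createTasks(stage: Stage): Seq[String] =
  if (stage.buggy) throw new IllegalStateException(s"getPreferredLocations failed in stage ${stage.id}")
  else Seq(s"task-${stage.id}-0", s"task-${stage.id}-1")

def submitStages(stages: Seq[Stage]): Unit =
  stages.foreach { stage =>
    val tasks =
      try createTasks(stage)
      catch {
        case NonFatal(e) =>
          // Analogous to abortStage(...); runningStages -= stage; return
          println(s"Aborting stage ${stage.id}: task creation failed: $e")
          Seq.empty[String]
      }
    if (tasks.nonEmpty) println(s"Submitting ${tasks.size} tasks for stage ${stage.id}")
  }

submitStages(Seq(Stage(0, buggy = false), Stage(1, buggy = true), Stage(2, buggy = false)))
// Stage 1 fails, but stages 0 and 2 are still submitted and the loop survives.
```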

core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala

Lines changed: 31 additions & 0 deletions
@@ -757,6 +757,37 @@ class DAGSchedulerSuite
     assert(sc.parallelize(1 to 10, 2).first() === 1)
   }
 
+  test("getPartitions exceptions should not crash DAGScheduler and SparkContext (SPARK-8606)") {
+    val e1 = intercept[DAGSchedulerSuiteDummyException] {
+      val rdd = new MyRDD(sc, 2, Nil) {
+        override def getPartitions: Array[Partition] = {
+          throw new DAGSchedulerSuiteDummyException
+        }
+      }
+      rdd.reduceByKey(_ + _, 1).count()
+    }
+
+    // Make sure we can still run local commands as well as cluster commands.
+    assert(sc.parallelize(1 to 10, 2).count() === 10)
+    assert(sc.parallelize(1 to 10, 2).first() === 1)
+  }
+
+  test("getPreferredLocations errors should not crash DAGScheduler and SparkContext (SPARK-8606)") {
+    val e1 = intercept[SparkException] {
+      val rdd = new MyRDD(sc, 2, Nil) {
+        override def getPreferredLocations(split: Partition): Seq[String] = {
+          throw new DAGSchedulerSuiteDummyException
+        }
+      }
+      rdd.count()
+    }
+    assert(e1.getMessage.contains(classOf[DAGSchedulerSuiteDummyException].getName))
+
+    // Make sure we can still run local commands as well as cluster commands.
+    assert(sc.parallelize(1 to 10, 2).count() === 10)
+    assert(sc.parallelize(1 to 10, 2).first() === 1)
+  }
+
   test("accumulator not calculated for resubmitted result stage") {
     // just for register
     val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam)
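
The new regression tests above exercise the fix through `MyRDD`, the suite's test helper. For context, a rough user-level reproduction might look like the sketch below (hypothetical code, not part of the patch, assuming a live `SparkContext` named `sc`, e.g. in `spark-shell`): before this change, the exception escaping `getPreferredLocations()` could bring down the DAGScheduler and SparkContext, whereas afterwards it surfaces as a `SparkException` and the context remains usable.

```scala
import org.apache.spark.{Partition, SparkContext, SparkException, TaskContext}
import org.apache.spark.rdd.RDD

// Hypothetical RDD whose preferred-location lookup always fails.
class BrokenLocationsRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {
  override def getPartitions: Array[Partition] =
    Array(new Partition { override def index: Int = 0 })
  override def compute(split: Partition, context: TaskContext): Iterator[Int] =
    Iterator(1, 2, 3)
  override def getPreferredLocations(split: Partition): Seq[String] =
    throw new IllegalStateException("boom")
}

try {
  new BrokenLocationsRDD(sc).count()
} catch {
  case e: SparkException =>
    println(s"Job failed, but the SparkContext survived: ${e.getMessage}")
}

// The context is still usable for further jobs.
println(sc.parallelize(1 to 10, 2).count())  // 10
```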
