Commit ac3591d
[SPARK-8606] Prevent exceptions in RDD.getPreferredLocations() from crashing DAGScheduler
If `RDD.getPreferredLocations()` throws an exception, it may crash the DAGScheduler and SparkContext. This patch addresses this by adding a try-catch block.

Author: Josh Rosen <[email protected]>

Closes #7023 from JoshRosen/SPARK-8606 and squashes the following commits:

770b169 [Josh Rosen] Fix getPreferredLocations() DAGScheduler crash with try block.
44a9b55 [Josh Rosen] Add test of a buggy getPartitions() method
19aa9f7 [Josh Rosen] Add (failing) regression test for getPreferredLocations() DAGScheduler crash

(cherry picked from commit 0b5abbf)
Signed-off-by: Josh Rosen <[email protected]>
1 parent 88e303f commit ac3591d
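For context, here is a hypothetical sketch of the failure mode this commit guards against: a user-defined RDD whose `getPreferredLocations()` throws. The class and names below are illustrative only, not part of the patch; before this fix, submitting a job over such an RDD could take down the DAGScheduler event loop rather than merely failing the job.

```scala
import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// Hypothetical user RDD (not from the patch) whose preferred-location
// lookup throws while the scheduler is creating tasks.
class BuggyLocationsRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {

  override def getPartitions: Array[Partition] =
    Array(new Partition { override def index: Int = 0 })

  override def compute(split: Partition, context: TaskContext): Iterator[Int] =
    Iterator(1)

  // The scheduler calls this via getPreferredLocs() during task creation.
  override def getPreferredLocations(split: Partition): Seq[String] =
    throw new RuntimeException("buggy getPreferredLocations()")
}

// With this patch applied, the job fails with a SparkException and the
// SparkContext stays usable:
//   new BuggyLocationsRDD(sc).count()
```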

2 files changed, +51 -13 lines
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 20 additions & 13 deletions
```diff
@@ -849,20 +849,27 @@ class DAGScheduler(
       return
     }
 
-    val tasks: Seq[Task[_]] = if (stage.isShuffleMap) {
-      partitionsToCompute.map { id =>
-        val locs = getPreferredLocs(stage.rdd, id)
-        val part = stage.rdd.partitions(id)
-        new ShuffleMapTask(stage.id, taskBinary, part, locs)
-      }
-    } else {
-      val job = stage.resultOfJob.get
-      partitionsToCompute.map { id =>
-        val p: Int = job.partitions(id)
-        val part = stage.rdd.partitions(p)
-        val locs = getPreferredLocs(stage.rdd, p)
-        new ResultTask(stage.id, taskBinary, part, locs, id)
+    val tasks: Seq[Task[_]] = try {
+      if (stage.isShuffleMap) {
+        partitionsToCompute.map { id =>
+          val locs = getPreferredLocs(stage.rdd, id)
+          val part = stage.rdd.partitions(id)
+          new ShuffleMapTask(stage.id, taskBinary, part, locs)
+        }
+      } else {
+        val job = stage.resultOfJob.get
+        partitionsToCompute.map { id =>
+          val p: Int = job.partitions(id)
+          val part = stage.rdd.partitions(p)
+          val locs = getPreferredLocs(stage.rdd, p)
+          new ResultTask(stage.id, taskBinary, part, locs, id)
+        }
       }
+    } catch {
+      case NonFatal(e) =>
+        abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}")
+        runningStages -= stage
+        return
     }
 
     if (tasks.size > 0) {
```
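A note on the error handling above: `NonFatal` is the extractor from `scala.util.control.NonFatal`, which matches ordinary exceptions but deliberately does not match fatal JVM errors such as `OutOfMemoryError`, so those still propagate. A minimal sketch of the pattern, using a hypothetical helper unrelated to the patch:

```scala
import scala.util.control.NonFatal

// Hypothetical example of the NonFatal guard: recoverable exceptions are
// handled, while fatal errors (OutOfMemoryError, etc.) would not match.
def parsePort(s: String): Option[Int] =
  try {
    Some(s.toInt)
  } catch {
    case NonFatal(e) => None // e.g. NumberFormatException lands here
  }
```

In the scheduler this means a buggy `getPreferredLocations()` aborts only the affected stage, while truly fatal errors still surface.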

core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala

Lines changed: 31 additions & 0 deletions
```diff
@@ -738,6 +738,37 @@ class DAGSchedulerSuite
     assert(sc.parallelize(1 to 10, 2).first() === 1)
   }
 
+  test("getPartitions exceptions should not crash DAGScheduler and SparkContext (SPARK-8606)") {
+    val e1 = intercept[DAGSchedulerSuiteDummyException] {
+      val rdd = new MyRDD(sc, 2, Nil) {
+        override def getPartitions: Array[Partition] = {
+          throw new DAGSchedulerSuiteDummyException
+        }
+      }
+      rdd.reduceByKey(_ + _, 1).count()
+    }
+
+    // Make sure we can still run local commands as well as cluster commands.
+    assert(sc.parallelize(1 to 10, 2).count() === 10)
+    assert(sc.parallelize(1 to 10, 2).first() === 1)
+  }
+
+  test("getPreferredLocations errors should not crash DAGScheduler and SparkContext (SPARK-8606)") {
+    val e1 = intercept[SparkException] {
+      val rdd = new MyRDD(sc, 2, Nil) {
+        override def getPreferredLocations(split: Partition): Seq[String] = {
+          throw new DAGSchedulerSuiteDummyException
+        }
+      }
+      rdd.count()
+    }
+    assert(e1.getMessage.contains(classOf[DAGSchedulerSuiteDummyException].getName))
+
+    // Make sure we can still run local commands as well as cluster commands.
+    assert(sc.parallelize(1 to 10, 2).count() === 10)
+    assert(sc.parallelize(1 to 10, 2).first() === 1)
+  }
+
   test("accumulator not calculated for resubmitted result stage") {
     //just for register
     val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam)
```
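These tests rely on ScalaTest's `intercept`, which runs a block, fails the test unless the block throws the expected exception type, and returns the caught exception so its message can be inspected. A standalone sketch of the idiom (values hypothetical):

```scala
import org.scalatest.Assertions._

// intercept returns the thrown exception, or fails the assertion if the
// block completes normally or throws a different type.
val ex = intercept[ArithmeticException] {
  1 / 0
}
assert(ex.getMessage.contains("by zero"))
```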
