
Commit 2e68066

tomwhite authored and squito committed
[SPARK-8625] [CORE] Propagate user exceptions in tasks back to driver
This allows clients to retrieve the original exception from the cause field of the SparkException that is thrown by the driver. If the original exception is not in fact Serializable then it will not be returned, but the message and stacktrace will be. (All Java Throwables implement the Serializable interface, but this is no guarantee that a particular implementation can actually be serialized.)

Author: Tom White <[email protected]>

Closes #7014 from tomwhite/propagate-user-exceptions.
1 parent 3ecb379 commit 2e68066

12 files changed: +165 -34 lines changed
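
For context, here is a minimal driver-side sketch of what this change enables (the exception class, app name, and RDD contents are illustrative, not part of the commit): once a job fails, the SparkException thrown by the driver carries the original task exception in its cause field, provided that exception can actually be serialized.

    import org.apache.spark.{SparkConf, SparkContext, SparkException}

    // Hypothetical user exception; RuntimeException is Serializable, so the cause can survive
    // the trip from executor to driver.
    class MyUserException(msg: String) extends RuntimeException(msg)

    object CausePropagationExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("example").setMaster("local[2]"))
        try {
          sc.parallelize(1 to 10).foreach { i =>
            if (i == 7) throw new MyUserException("bad record: " + i)
          }
        } catch {
          case e: SparkException =>
            // After this change, getCause holds the original task exception; if that exception
            // had not been serializable, getCause would be null and only the message and stack
            // trace text would be available via e.getMessage.
            e.getCause match {
              case cause: MyUserException => println("recovered cause: " + cause.getMessage)
              case other => println("cause not recovered: " + other)
            }
        } finally {
          sc.stop()
        }
      }
    }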

core/src/main/scala/org/apache/spark/TaskEndReason.scala

Lines changed: 42 additions & 2 deletions

@@ -17,6 +17,8 @@
 
 package org.apache.spark
 
+import java.io.{IOException, ObjectInputStream, ObjectOutputStream}
+
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.storage.BlockManagerId
@@ -90,18 +92,37 @@ case class FetchFailed(
  *
  * `fullStackTrace` is a better representation of the stack trace because it contains the whole
  * stack trace including the exception and its causes
+ *
+ * `exception` is the actual exception that caused the task to fail. It may be `None` in
+ * the case that the exception is not in fact serializable. If a task fails more than
+ * once (due to retries), `exception` is that one that caused the last failure.
  */
 @DeveloperApi
 case class ExceptionFailure(
     className: String,
     description: String,
     stackTrace: Array[StackTraceElement],
     fullStackTrace: String,
-    metrics: Option[TaskMetrics])
+    metrics: Option[TaskMetrics],
+    private val exceptionWrapper: Option[ThrowableSerializationWrapper])
   extends TaskFailedReason {
 
+  /**
+   * `preserveCause` is used to keep the exception itself so it is available to the
+   * driver. This may be set to `false` in the event that the exception is not in fact
+   * serializable.
+   */
+  private[spark] def this(e: Throwable, metrics: Option[TaskMetrics], preserveCause: Boolean) {
+    this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics,
+      if (preserveCause) Some(new ThrowableSerializationWrapper(e)) else None)
+  }
+
   private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) {
-    this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics)
+    this(e, metrics, preserveCause = true)
+  }
+
+  def exception: Option[Throwable] = exceptionWrapper.flatMap {
+    (w: ThrowableSerializationWrapper) => Option(w.exception)
   }
 
   override def toErrorString: String =
@@ -127,6 +148,25 @@ case class ExceptionFailure(
   }
 }
 
+/**
+ * A class for recovering from exceptions when deserializing a Throwable that was
+ * thrown in user task code. If the Throwable cannot be deserialized it will be null,
+ * but the stacktrace and message will be preserved correctly in SparkException.
+ */
+private[spark] class ThrowableSerializationWrapper(var exception: Throwable) extends
+    Serializable with Logging {
+  private def writeObject(out: ObjectOutputStream): Unit = {
+    out.writeObject(exception)
+  }
+  private def readObject(in: ObjectInputStream): Unit = {
+    try {
+      exception = in.readObject().asInstanceOf[Throwable]
+    } catch {
+      case e : Exception => log.warn("Task exception could not be deserialized", e)
+    }
+  }
+}
+
 /**
  * :: DeveloperApi ::
  * The task finished successfully, but the result was lost from the executor's block manager before
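
The important detail in ThrowableSerializationWrapper above is that deserialization failures are caught inside its own readObject, so the enclosing ExceptionFailure still deserializes (with a null exception) instead of failing the whole status update. A self-contained sketch of that pattern, with println standing in for Spark's Logging and all names illustrative:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

    // The Throwable is written normally, but any failure while reading it back is swallowed,
    // leaving `exception` null rather than aborting deserialization of the containing object.
    class ThrowableWrapperSketch(var exception: Throwable) extends Serializable {
      private def writeObject(out: ObjectOutputStream): Unit = out.writeObject(exception)
      private def readObject(in: ObjectInputStream): Unit = {
        try {
          exception = in.readObject().asInstanceOf[Throwable]
        } catch {
          case e: Exception => println("could not deserialize wrapped exception: " + e)
        }
      }
    }

    object WrapperRoundTrip {
      def main(args: Array[String]): Unit = {
        val bytes = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(bytes)
        out.writeObject(new ThrowableWrapperSketch(new RuntimeException("boom")))
        out.close()
        val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
        val back = in.readObject().asInstanceOf[ThrowableWrapperSketch]
        println(back.exception) // the RuntimeException, or null if it could not be read back
      }
    }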

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 11 additions & 3 deletions

@@ -17,7 +17,7 @@
 
 package org.apache.spark.executor
 
-import java.io.File
+import java.io.{File, NotSerializableException}
 import java.lang.management.ManagementFactory
 import java.net.URL
 import java.nio.ByteBuffer
@@ -305,8 +305,16 @@ private[spark] class Executor(
             m
           }
         }
-        val taskEndReason = new ExceptionFailure(t, metrics)
-        execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(taskEndReason))
+        val serializedTaskEndReason = {
+          try {
+            ser.serialize(new ExceptionFailure(t, metrics))
+          } catch {
+            case _: NotSerializableException =>
+              // t is not serializable so just send the stacktrace
+              ser.serialize(new ExceptionFailure(t, metrics, false))
+          }
+        }
+        execBackend.statusUpdate(taskId, TaskState.FAILED, serializedTaskEndReason)
 
         // Don't forcibly exit unless the exception was inherently fatal, to avoid
         // stopping other tasks unnecessarily.
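
The executor-side change is a serialize-then-fall-back pattern: first try to serialize the failure with its cause attached, and if the cause itself is not serializable, retry without it so the driver still receives the class name, message, and stack trace. A minimal sketch of the same idea outside Spark (javaSerialize and FailureReport are illustrative stand-ins for ser.serialize and ExceptionFailure):

    import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

    object SerializeWithFallback {
      // Stand-in for ser.serialize(...): Java-serialize any object to a byte array.
      def javaSerialize(obj: AnyRef): Array[Byte] = {
        val bytes = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(bytes)
        out.writeObject(obj)
        out.close()
        bytes.toByteArray
      }

      // Hypothetical failure report; the optional cause may or may not be serializable.
      case class FailureReport(className: String, message: String, cause: Option[Throwable])

      def serializeFailure(t: Throwable): Array[Byte] = {
        try {
          // Preferred path: keep the original exception as the cause.
          javaSerialize(FailureReport(t.getClass.getName, t.getMessage, Some(t)))
        } catch {
          case _: NotSerializableException =>
            // Fallback: the exception cannot be serialized, so send only its description.
            javaSerialize(FailureReport(t.getClass.getName, t.getMessage, None))
        }
      }
    }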

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 27 additions & 17 deletions

@@ -200,8 +200,8 @@ class DAGScheduler(
 
   // Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
   // cancellation of the job itself.
-  def taskSetFailed(taskSet: TaskSet, reason: String): Unit = {
-    eventProcessLoop.post(TaskSetFailed(taskSet, reason))
+  def taskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable]): Unit = {
+    eventProcessLoop.post(TaskSetFailed(taskSet, reason, exception))
   }
 
   private[scheduler]
@@ -677,8 +677,11 @@ class DAGScheduler(
     submitWaitingStages()
   }
 
-  private[scheduler] def handleTaskSetFailed(taskSet: TaskSet, reason: String) {
-    stageIdToStage.get(taskSet.stageId).foreach {abortStage(_, reason) }
+  private[scheduler] def handleTaskSetFailed(
+      taskSet: TaskSet,
+      reason: String,
+      exception: Option[Throwable]): Unit = {
+    stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason, exception) }
     submitWaitingStages()
   }
 
@@ -762,7 +765,7 @@ class DAGScheduler(
         }
       }
     } else {
-      abortStage(stage, "No active job for stage " + stage.id)
+      abortStage(stage, "No active job for stage " + stage.id, None)
     }
   }
 
@@ -816,7 +819,7 @@ class DAGScheduler(
       case NonFatal(e) =>
         stage.makeNewStageAttempt(partitionsToCompute.size)
         listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
-        abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}")
+        abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e))
        runningStages -= stage
        return
    }
@@ -845,13 +848,13 @@ class DAGScheduler(
    } catch {
      // In the case of a failure during serialization, abort the stage.
      case e: NotSerializableException =>
-        abortStage(stage, "Task not serializable: " + e.toString)
+        abortStage(stage, "Task not serializable: " + e.toString, Some(e))
        runningStages -= stage
 
        // Abort execution
        return
      case NonFatal(e) =>
-        abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}")
+        abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}", Some(e))
        runningStages -= stage
        return
    }
@@ -878,7 +881,7 @@ class DAGScheduler(
      }
    } catch {
      case NonFatal(e) =>
-        abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}")
+        abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}", Some(e))
        runningStages -= stage
        return
    }
@@ -1098,7 +1101,8 @@ class DAGScheduler(
        }
 
        if (disallowStageRetryForTest) {
-          abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+          abortStage(failedStage, "Fetch failure will not retry stage due to testing config",
+            None)
        } else if (failedStages.isEmpty) {
          // Don't schedule an event to resubmit failed stages if failed isn't empty, because
          // in that case the event will already have been scheduled.
@@ -1126,7 +1130,7 @@ class DAGScheduler(
      case commitDenied: TaskCommitDenied =>
        // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
 
-      case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) =>
+      case exceptionFailure: ExceptionFailure =>
        // Do nothing here, left up to the TaskScheduler to decide how to handle user failures
 
      case TaskResultLost =>
@@ -1235,7 +1239,10 @@ class DAGScheduler(
   * Aborts all jobs depending on a particular Stage. This is called in response to a task set
   * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
   */
-  private[scheduler] def abortStage(failedStage: Stage, reason: String) {
+  private[scheduler] def abortStage(
+      failedStage: Stage,
+      reason: String,
+      exception: Option[Throwable]): Unit = {
    if (!stageIdToStage.contains(failedStage.id)) {
      // Skip all the actions if the stage has been removed.
      return
@@ -1244,16 +1251,19 @@ class DAGScheduler(
      activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq
    failedStage.latestInfo.completionTime = Some(clock.getTimeMillis())
    for (job <- dependentJobs) {
-      failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
+      failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason", exception)
    }
    if (dependentJobs.isEmpty) {
      logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
    }
  }
 
  /** Fails a job and all stages that are only used by that job, and cleans up relevant state. */
-  private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) {
-    val error = new SparkException(failureReason)
+  private def failJobAndIndependentStages(
+      job: ActiveJob,
+      failureReason: String,
+      exception: Option[Throwable] = None): Unit = {
+    val error = new SparkException(failureReason, exception.getOrElse(null))
    var ableToCancelStages = true
 
    val shouldInterruptThread =
@@ -1462,8 +1472,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
    case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
      dagScheduler.handleTaskCompletion(completion)
 
-    case TaskSetFailed(taskSet, reason) =>
-      dagScheduler.handleTaskSetFailed(taskSet, reason)
+    case TaskSetFailed(taskSet, reason, exception) =>
+      dagScheduler.handleTaskSetFailed(taskSet, reason, exception)
 
    case ResubmitFailedStages =>
      dagScheduler.resubmitFailedStages()
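
One detail worth noting in failJobAndIndependentStages: Throwable accepts a null cause, so exception.getOrElse(null) produces a SparkException with no cause when nothing was propagated, and with the task's exception as getCause otherwise. A small illustration, assuming only spark-core on the classpath:

    import org.apache.spark.SparkException

    object CauseOrNull {
      def main(args: Array[String]): Unit = {
        val propagated: Option[Throwable] = Some(new IllegalArgumentException("bad input"))
        val missing: Option[Throwable] = None

        // Mirrors: val error = new SparkException(failureReason, exception.getOrElse(null))
        val e1 = new SparkException("Job aborted due to stage failure", propagated.getOrElse(null))
        val e2 = new SparkException("Job aborted due to stage failure", missing.getOrElse(null))

        assert(e1.getCause.isInstanceOf[IllegalArgumentException])
        assert(e2.getCause == null)
      }
    }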

core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala

Lines changed: 2 additions & 1 deletion

@@ -73,6 +73,7 @@ private[scheduler] case class ExecutorAdded(execId: String, host: String) extend
 private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
 
 private[scheduler]
-case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
+case class TaskSetFailed(taskSet: TaskSet, reason: String, exception: Option[Throwable])
+  extends DAGSchedulerEvent
 
 private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 8 additions & 4 deletions

@@ -662,7 +662,7 @@ private[spark] class TaskSetManager(
 
    val failureReason = s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid, ${info.host}): " +
      reason.asInstanceOf[TaskFailedReason].toErrorString
-    reason match {
+    val failureException: Option[Throwable] = reason match {
      case fetchFailed: FetchFailed =>
        logWarning(failureReason)
        if (!successful(index)) {
@@ -671,6 +671,7 @@ private[spark] class TaskSetManager(
        }
        // Not adding to failed executors for FetchFailed.
        isZombie = true
+        None
 
      case ef: ExceptionFailure =>
        taskMetrics = ef.metrics.orNull
@@ -706,12 +707,15 @@ private[spark] class TaskSetManager(
            s"Lost task ${info.id} in stage ${taskSet.id} (TID $tid) on executor ${info.host}: " +
            s"${ef.className} (${ef.description}) [duplicate $dupCount]")
        }
+        ef.exception
 
      case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others
        logWarning(failureReason)
+        None
 
      case e: TaskEndReason =>
        logError("Unknown TaskEndReason: " + e)
+        None
    }
    // always add to failed executors
    failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()).
@@ -728,16 +732,16 @@ private[spark] class TaskSetManager(
        logError("Task %d in stage %s failed %d times; aborting job".format(
          index, taskSet.id, maxTaskFailures))
        abort("Task %d in stage %s failed %d times, most recent failure: %s\nDriver stacktrace:"
-          .format(index, taskSet.id, maxTaskFailures, failureReason))
+          .format(index, taskSet.id, maxTaskFailures, failureReason), failureException)
        return
      }
    }
    maybeFinishTaskSet()
  }
 
-  def abort(message: String): Unit = sched.synchronized {
+  def abort(message: String, exception: Option[Throwable] = None): Unit = sched.synchronized {
    // TODO: Kill running tasks if we were not terminated due to a Mesos error
-    sched.dagScheduler.taskSetFailed(taskSet, message)
+    sched.dagScheduler.taskSetFailed(taskSet, message, exception)
    isZombie = true
    maybeFinishTaskSet()
  }

core/src/main/scala/org/apache/spark/util/JsonProtocol.scala

Lines changed: 1 addition & 1 deletion

@@ -790,7 +790,7 @@ private[spark] object JsonProtocol {
        val fullStackTrace = Utils.jsonOption(json \ "Full Stack Trace").
          map(_.extract[String]).orNull
        val metrics = Utils.jsonOption(json \ "Metrics").map(taskMetricsFromJson)
-        ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics)
+        ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics, None)
      case `taskResultLost` => TaskResultLost
      case `taskKilled` => TaskKilled
      case `executorLostFailure` =>

core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala

Lines changed: 1 addition & 1 deletion

@@ -800,7 +800,7 @@ class ExecutorAllocationManagerSuite
    assert(maxNumExecutorsNeeded(manager) === 1)
 
    // If the task is failed, we expect it to be resubmitted later.
-    val taskEndReason = ExceptionFailure(null, null, null, null, null)
+    val taskEndReason = ExceptionFailure(null, null, null, null, null, None)
    sc.listenerBus.postToAll(SparkListenerTaskEnd(0, 0, null, taskEndReason, taskInfo, null))
    assert(maxNumExecutorsNeeded(manager) === 1)
  }

core/src/test/scala/org/apache/spark/FailureSuite.scala

Lines changed: 65 additions & 1 deletion

@@ -19,7 +19,7 @@ package org.apache.spark
 
 import org.apache.spark.util.NonSerializable
 
-import java.io.NotSerializableException
+import java.io.{IOException, NotSerializableException, ObjectInputStream}
 
 // Common state shared by FailureSuite-launched tasks. We use a global object
 // for this because any local variables used in the task closures will rightfully
@@ -166,5 +166,69 @@ class FailureSuite extends SparkFunSuite with LocalSparkContext {
    assert(thrownDueToMemoryLeak.getMessage.contains("memory leak"))
  }
 
+  // Run a 3-task map job in which task 1 always fails with a exception message that
+  // depends on the failure number, and check that we get the last failure.
+  test("last failure cause is sent back to driver") {
+    sc = new SparkContext("local[1,2]", "test")
+    val data = sc.makeRDD(1 to 3, 3).map { x =>
+      FailureSuiteState.synchronized {
+        FailureSuiteState.tasksRun += 1
+        if (x == 3) {
+          FailureSuiteState.tasksFailed += 1
+          throw new UserException("oops",
+            new IllegalArgumentException("failed=" + FailureSuiteState.tasksFailed))
+        }
+      }
+      x * x
+    }
+    val thrown = intercept[SparkException] {
+      data.collect()
+    }
+    FailureSuiteState.synchronized {
+      assert(FailureSuiteState.tasksRun === 4)
+    }
+    assert(thrown.getClass === classOf[SparkException])
+    assert(thrown.getCause.getClass === classOf[UserException])
+    assert(thrown.getCause.getMessage === "oops")
+    assert(thrown.getCause.getCause.getClass === classOf[IllegalArgumentException])
+    assert(thrown.getCause.getCause.getMessage === "failed=2")
+    FailureSuiteState.clear()
+  }
+
+  test("failure cause stacktrace is sent back to driver if exception is not serializable") {
+    sc = new SparkContext("local", "test")
+    val thrown = intercept[SparkException] {
+      sc.makeRDD(1 to 3).foreach { _ => throw new NonSerializableUserException }
+    }
+    assert(thrown.getClass === classOf[SparkException])
+    assert(thrown.getCause === null)
+    assert(thrown.getMessage.contains("NonSerializableUserException"))
+    FailureSuiteState.clear()
+  }
+
+  test("failure cause stacktrace is sent back to driver if exception is not deserializable") {
+    sc = new SparkContext("local", "test")
+    val thrown = intercept[SparkException] {
+      sc.makeRDD(1 to 3).foreach { _ => throw new NonDeserializableUserException }
+    }
+    assert(thrown.getClass === classOf[SparkException])
+    assert(thrown.getCause === null)
+    assert(thrown.getMessage.contains("NonDeserializableUserException"))
+    FailureSuiteState.clear()
+  }
+
  // TODO: Need to add tests with shuffle fetch failures.
 }
+
+class UserException(message: String, cause: Throwable)
+  extends RuntimeException(message, cause)
+
+class NonSerializableUserException extends RuntimeException {
+  val nonSerializableInstanceVariable = new NonSerializable
+}
+
+class NonDeserializableUserException extends RuntimeException {
+  private def readObject(in: ObjectInputStream): Unit = {
+    throw new IOException("Intentional exception during deserialization.")
+  }
+}
