14 changes: 12 additions & 2 deletions core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -212,9 +212,19 @@ case object TaskResultLost extends TaskFailedReason {
* Task was killed intentionally and needs to be rescheduled.
*/
@DeveloperApi
case class TaskKilled(reason: String) extends TaskFailedReason {
override def toErrorString: String = s"TaskKilled ($reason)"
case class TaskKilled(
reason: String,
accumUpdates: Seq[AccumulableInfo] = Seq.empty,
private[spark] var accums: Seq[AccumulatorV2[_, _]] = Nil)
extends TaskFailedReason {

override def toErrorString: String = "TaskKilled ($reason)"
override def countTowardsTaskFailures: Boolean = false

private[spark] def withAccums(accums: Seq[AccumulatorV2[_, _]]): TaskKilled = {
Contributor: I don't think this method is really necessary at all; you could just pass the accumulators in the constructor at the places it's used. (A sketch of that alternative follows this hunk.)

this.accums = accums
this
}
}
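
A minimal sketch of that alternative, assuming a call site like the Executor change below, where accUpdates and accums are already in scope; this is illustrative, not part of the diff:

// Hypothetical call site: pass the accumulators through the case-class
// constructor instead of mutating the instance afterwards via withAccums.
val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates, accums))
execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)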

33 changes: 30 additions & 3 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -429,15 +429,42 @@ private[spark] class Executor(

case t: TaskKilledException =>
logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")

// Collect latest accumulator values to report back to the driver
val accums: Seq[AccumulatorV2[_, _]] =
if (task != null) {
task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart)
task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
task.collectAccumulatorUpdates(taskFailed = true)
} else {
Seq.empty
}
val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None))
Contributor: This should be refactored, not repeated three times. (A possible shared helper is sketched after this hunk.)


setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))

val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates).withAccums(accums))
execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)

case _: InterruptedException if task.reasonIfKilled.isDefined =>
val killReason = task.reasonIfKilled.getOrElse("unknown reason")
logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")

// Collect latest accumulator values to report back to the driver
val accums: Seq[AccumulatorV2[_, _]] =
if (task != null) {
task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart)
task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
task.collectAccumulatorUpdates(taskFailed = true)
} else {
Seq.empty
}
val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None))

setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(
taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))

val serializedTK = ser.serialize(TaskKilled(killReason, accUpdates).withAccums(accums))
execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)

case CausedBy(cDE: CommitDeniedException) =>
val reason = cDE.toTaskFailedReason
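The collect-accumulators block appears verbatim in both kill handlers above (and near-verbatim in the ExceptionFailure handler), so the reviewer's suggestion might look roughly like the following private helper on Executor; the name and exact shape are assumptions, not part of this diff:

// Hypothetical helper: compute the final metrics and collect the latest
// accumulator values plus their AccumulableInfo views in one place.
private def collectAccumulatorUpdatesOnFailure(
    taskStart: Long): (Seq[AccumulatorV2[_, _]], Seq[AccumulableInfo]) = {
  val accums: Seq[AccumulatorV2[_, _]] =
    if (task != null) {
      task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart)
      task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
      task.collectAccumulatorUpdates(taskFailed = true)
    } else {
      Seq.empty
    }
  val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None))
  (accums, accUpdates)
}

// Usage in the TaskKilledException handler:
val (accums, accUpdates) = collectAccumulatorUpdatesOnFailure(taskStart)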
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1355,14 +1355,18 @@ class DAGScheduler(
case commitDenied: TaskCommitDenied =>
// Do nothing here, left up to the TaskScheduler to decide how to handle denied commits

case exceptionFailure: ExceptionFailure =>
// Tasks failed with exceptions might still have accumulator updates.
case _: ExceptionFailure =>
// Tasks killed or failed with exceptions might still have accumulator updates.
updateAccumulators(event)

case _: TaskKilled =>
// Tasks killed or failed with exceptions might still have accumulator updates.
updateAccumulators(event)

case TaskResultLost =>
// Do nothing here; the TaskScheduler handles these failures and resubmits the task.

case _: ExecutorLostFailure | _: TaskKilled | UnknownReason =>
case _: ExecutorLostFailure | UnknownReason =>
// Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler
// will abort the job.
}
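Since the new TaskKilled branch is identical to the ExceptionFailure one, the two cases could arguably be merged into a single alternative pattern; a sketch of that variant (not what this diff does):

case _: ExceptionFailure | _: TaskKilled =>
  // Tasks killed or failed with exceptions might still have accumulator updates.
  updateAccumulators(event)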
core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -812,13 +812,19 @@ private[spark] class TaskSetManager(
}
ef.exception

case tk: TaskKilled =>
// TaskKilled might have accumulator updates
accumUpdates = tk.accums
logWarning(failureReason)
None

case e: ExecutorLostFailure if !e.exitCausedByApp =>
logInfo(s"Task $tid failed because while it was being computed, its executor " +
"exited for a reason unrelated to the task. Not counting this failure towards the " +
"maximum number of failures for the task.")
None

case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others
case e: TaskFailedReason => // TaskResultLost and others
logWarning(failureReason)
None
}
9 changes: 7 additions & 2 deletions core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -391,7 +391,9 @@ private[spark] object JsonProtocol {
("Exit Caused By App" -> exitCausedByApp) ~
("Loss Reason" -> reason.map(_.toString))
case taskKilled: TaskKilled =>
("Kill Reason" -> taskKilled.reason)
val accumUpdates = JArray(taskKilled.accumUpdates.map(accumulableInfoToJson).toList)
("Kill Reason" -> taskKilled.reason) ~
("Accumulator Updates" -> accumUpdates)
case _ => Utils.emptyJson
}
("Reason" -> reason) ~ json
@@ -882,7 +884,10 @@ private[spark] object JsonProtocol {
case `taskKilled` =>
val killReason = Utils.jsonOption(json \ "Kill Reason")
.map(_.extract[String]).getOrElse("unknown reason")
TaskKilled(killReason)
val accumUpdates = Utils.jsonOption(json \ "Accumulator Updates")
.map(_.extract[List[JValue]].map(accumulableInfoFromJson))
.getOrElse(Seq[AccumulableInfo]())
TaskKilled(killReason, accumUpdates)
case `taskCommitDenied` =>
// Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON
// de/serialization logic was not added until 1.5.1. To provide backward compatibility
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -1760,7 +1760,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
assertDataStructuresEmpty()
}

test("accumulators are updated on exception failures") {
test("accumulators are updated on exception failures and task killed") {
val acc1 = AccumulatorSuite.createLongAccum("ingenieur")
val acc2 = AccumulatorSuite.createLongAccum("boulanger")
val acc3 = AccumulatorSuite.createLongAccum("agriculteur")
@@ -1776,15 +1776,26 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
val accUpdate3 = new LongAccumulator
accUpdate3.metadata = acc3.metadata
accUpdate3.setValue(18)
val accumUpdates = Seq(accUpdate1, accUpdate2, accUpdate3)
val accumInfo = accumUpdates.map(AccumulatorSuite.makeInfo)

val accumUpdates1 = Seq(accUpdate1, accUpdate2)
val accumInfo1 = accumUpdates1.map(AccumulatorSuite.makeInfo)
val exceptionFailure = new ExceptionFailure(
new SparkException("fondue?"),
accumInfo).copy(accums = accumUpdates)
accumInfo1).copy(accums = accumUpdates1)
submit(new MyRDD(sc, 1, Nil), Array(0))
runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result"))

assert(AccumulatorContext.get(acc1.id).get.value === 15L)
assert(AccumulatorContext.get(acc2.id).get.value === 13L)

val accumUpdates2 = Seq(accUpdate3)
val accumInfo2 = accumUpdates2.map(AccumulatorSuite.makeInfo)

val taskKilled = new TaskKilled(
"test",
accumInfo2).copy(accums = accumUpdates2)
runEvent(makeCompletionEvent(taskSets.head.tasks.head, taskKilled, "result"))

assert(AccumulatorContext.get(acc3.id).get.value === 18L)
}

@@ -2323,6 +2334,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
val accumUpdates = reason match {
case Success => task.metrics.accumulators()
case ef: ExceptionFailure => ef.accums
case tk: TaskKilled => tk.accums
case _ => Seq.empty
}
CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo)