8 changes: 7 additions & 1 deletion core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -212,9 +212,15 @@ case object TaskResultLost extends TaskFailedReason {
* Task was killed intentionally and needs to be rescheduled.
*/
@DeveloperApi
case class TaskKilled(reason: String) extends TaskFailedReason {
case class TaskKilled(
reason: String,
accumUpdates: Seq[AccumulableInfo] = Seq.empty,
private[spark] val accums: Seq[AccumulatorV2[_, _]] = Nil)

Contributor:
Previously we used AccumulableInfo to expose accumulator information to end users. Now that AccumulatorV2 is already a public class, we don't need to do that anymore; I think we can just do

case class TaskKilled(reason: String, accums: Seq[AccumulatorV2[_, _]])

Contributor (author):
Yeah, I noticed that accumUpdates: Seq[AccumulableInfo] is only used in JsonProtocol. Is there a reason for that?

The current implementation is constructed to stay in sync with existing TaskEndReasons such as ExceptionFailure:

@DeveloperApi
case class ExceptionFailure(
    className: String,
    description: String,
    stackTrace: Array[StackTraceElement],
    fullStackTrace: String,
    private val exceptionWrapper: Option[ThrowableSerializationWrapper],
    accumUpdates: Seq[AccumulableInfo] = Seq.empty,
    private[spark] var accums: Seq[AccumulatorV2[_, _]] = Nil)

I'd prefer to keep them in sync, which leaves two options for cleanup:

  1. Leave it as it is, then clean it up together with ExceptionFailure later.
  2. Clean up ExceptionFailure first.

@cloud-fan what do you think?

Contributor:
Let's clean up ExceptionFailure at the same time and use only AccumulatorV2 in this PR.

Contributor (author):
@cloud-fan After a second look, I don't think we can clean up ExceptionFailure unless we are allowed to break JsonProtocol.

Contributor:
Now the question is: shall we keep the unnecessary Seq[AccumulableInfo] in the new APIs to make them consistent? I'd prefer not to keep the Seq[AccumulableInfo]; we may deprecate it in the existing APIs in the near future.

Contributor (author):
I'm OK with not keeping Seq[AccumulableInfo], but it means inconsistent logic and APIs, and it may make future refactoring a bit difficult.

Let's see what I can do.

> I'd prefer not to keep the Seq[AccumulableInfo]; we may deprecate it in the existing APIs in the near future.

BTW, I think we have already deprecated AccumulableInfo. Unless we are planning to remove it in Spark 3.0, and Spark 3.0 is the next release, AccumulableInfo will be around for a long time.

Contributor (author):
Hi @cloud-fan, I looked at how to remove Seq[AccumulableInfo] tonight.
It turns out that we cannot, because JsonProtocol calls taskEndReasonFromJson to reconstruct TaskEndReasons, and since AccumulatorV2 is an abstract class, we cannot simply construct AccumulatorV2 instances from JSON.

Even though we are promoting AccumulatorV2, we still need AccumulableInfo when (de)serializing JSON.
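
To make the constraint concrete, here is a minimal sketch of the asymmetry (illustrative only, not the real JsonProtocol code; it assumes the snippet lives inside the org.apache.spark package, since AccumulatorV2.toInfo is private[spark], and the names AccumJsonSketch, writeSide and readSide are made up):

// Conceptual sketch, not part of this PR.
package org.apache.spark

import org.json4s.JsonAST.JValue

import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.util.AccumulatorV2

object AccumJsonSketch {

  // Writing is easy: any concrete AccumulatorV2 can be flattened into the plain-data
  // AccumulableInfo view, which JsonProtocol already knows how to render as JSON.
  def writeSide(acc: AccumulatorV2[_, _]): AccumulableInfo =
    acc.toInfo(Some(acc.value), None)

  // Reading back is the problem: AccumulatorV2[IN, OUT] is abstract, and the JSON does not
  // record which concrete subclass (LongAccumulator, a user-defined accumulator, ...)
  // produced the update, so only the AccumulableInfo view can be reconstructed.
  def readSide(json: JValue): AccumulatorV2[_, _] =
    throw new UnsupportedOperationException(
      "cannot instantiate the abstract AccumulatorV2 from JSON alone")
}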

Contributor:
I see, that makes sense, let's keep AccumulableInfo.
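
As a side note, here is an illustrative listener (not part of this PR; the class name KilledTaskAccumLogger is made up) showing how an end user can read the accumulator updates that TaskKilled now carries:

import org.apache.spark.TaskKilled
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Logs the accumulator updates attached to killed tasks, which is the
// end-user-visible effect of this change.
class KilledTaskAccumLogger extends SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = taskEnd.reason match {
    case tk: TaskKilled =>
      tk.accumUpdates.foreach { info =>
        println(s"killed task ${taskEnd.taskInfo.taskId}: " +
          s"accumulator ${info.name.getOrElse("<unnamed>")} update = ${info.update.orNull}")
      }
    case _ => // other end reasons are not of interest here
  }
}

It could be registered with SparkContext.addSparkListener like any other listener.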

Contributor (author):
@cloud-fan So, could you trigger the tests and have a look?

And it looks like I am not in the whitelist again...

extends TaskFailedReason {

override def toErrorString: String = s"TaskKilled ($reason)"
override def countTowardsTaskFailures: Boolean = false

}

/**
55 changes: 35 additions & 20 deletions core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -287,6 +287,28 @@ private[spark] class Executor(
notifyAll()
}

/**
* Utility function to:
* 1. Report executor runtime and JVM gc time if possible
* 2. Collect accumulator updates
* 3. Set the finished flag to true and clear current thread's interrupt status
*/
private def collectAccumulatorsAndResetStatusOnFailure(taskStartTime: Long) = {
// Report executor runtime and JVM gc time
Option(task).foreach(t => {
t.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStartTime)
t.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
})

// Collect latest accumulator values to report back to the driver
val accums: Seq[AccumulatorV2[_, _]] =
Option(task).map(_.collectAccumulatorUpdates(taskFailed = true)).getOrElse(Seq.empty)
val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None))

setTaskFinishedAndClearInterruptStatus()
(accums, accUpdates)
}

override def run(): Unit = {
threadId = Thread.currentThread.getId
Thread.currentThread.setName(threadName)
@@ -300,7 +322,7 @@ private[spark] class Executor(
val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var taskStart: Long = 0
var taskStartTime: Long = 0
var taskStartCpu: Long = 0
startGCTime = computeTotalGcTime()

@@ -336,7 +358,7 @@
}

// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
taskStartTime = System.currentTimeMillis()
taskStartCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime
} else 0L
@@ -396,11 +418,11 @@
// Deserialization happens in two parts: first, we deserialize a Task object, which
// includes the Partition. Second, Task.run() deserializes the RDD and function to be run.
task.metrics.setExecutorDeserializeTime(
(taskStart - deserializeStartTime) + task.executorDeserializeTime)
(taskStartTime - deserializeStartTime) + task.executorDeserializeTime)
task.metrics.setExecutorDeserializeCpuTime(
(taskStartCpu - deserializeStartCpuTime) + task.executorDeserializeCpuTime)
// We need to subtract Task.run()'s deserialization time to avoid double-counting
task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime)
task.metrics.setExecutorRunTime((taskFinish - taskStartTime) - task.executorDeserializeTime)
task.metrics.setExecutorCpuTime(
(taskFinishCpu - taskStartCpu) - task.executorDeserializeCpuTime)
task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
@@ -482,16 +504,19 @@
} catch {
case t: TaskKilledException =>
logInfo(s"Executor killed $taskName (TID $taskId), reason: ${t.reason}")
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason)))

val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime)
val serializedTK = ser.serialize(TaskKilled(t.reason, accUpdates, accums))
execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)

case _: InterruptedException | NonFatal(_) if
task != null && task.reasonIfKilled.isDefined =>
val killReason = task.reasonIfKilled.getOrElse("unknown reason")
logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason")
setTaskFinishedAndClearInterruptStatus()
execBackend.statusUpdate(
taskId, TaskState.KILLED, ser.serialize(TaskKilled(killReason)))

val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime)
val serializedTK = ser.serialize(TaskKilled(killReason, accUpdates, accums))
execBackend.statusUpdate(taskId, TaskState.KILLED, serializedTK)

case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) =>
val reason = task.context.fetchFailed.get.toTaskFailedReason
@@ -524,17 +549,7 @@
// the task failure would not be ignored if the shutdown happened because of preemption,
// instead of an app issue).
if (!ShutdownHookManager.inShutdown()) {
// Collect latest accumulator values to report back to the driver
val accums: Seq[AccumulatorV2[_, _]] =
if (task != null) {
task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart)
task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime)
task.collectAccumulatorUpdates(taskFailed = true)
} else {
Seq.empty
}

val accUpdates = accums.map(acc => acc.toInfo(Some(acc.value), None))
val (accums, accUpdates) = collectAccumulatorsAndResetStatusOnFailure(taskStartTime)

val serializedTaskEndReason = {
try {
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1210,7 +1210,7 @@ class DAGScheduler(
case _ =>
updateAccumulators(event)
}
case _: ExceptionFailure => updateAccumulators(event)
case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event)
case _ =>
}
postTaskEnd(event)
@@ -1414,13 +1414,13 @@
case commitDenied: TaskCommitDenied =>
// Do nothing here, left up to the TaskScheduler to decide how to handle denied commits

case exceptionFailure: ExceptionFailure =>
case _: ExceptionFailure | _: TaskKilled =>
// Nothing left to do, already handled above for accumulator updates.

case TaskResultLost =>
// Do nothing here; the TaskScheduler handles these failures and resubmits the task.

case _: ExecutorLostFailure | _: TaskKilled | UnknownReason =>
case _: ExecutorLostFailure | UnknownReason =>
// Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler
// will abort the job.
}
core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -833,13 +833,19 @@
}
ef.exception

case tk: TaskKilled =>
// TaskKilled might have accumulator updates
accumUpdates = tk.accums
logWarning(failureReason)
None

case e: ExecutorLostFailure if !e.exitCausedByApp =>
logInfo(s"Task $tid failed because while it was being computed, its executor " +
"exited for a reason unrelated to the task. Not counting this failure towards the " +
"maximum number of failures for the task.")
None

case e: TaskFailedReason => // TaskResultLost, TaskKilled, and others
case e: TaskFailedReason => // TaskResultLost and others
logWarning(failureReason)
None
}
9 changes: 7 additions & 2 deletions core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -407,7 +407,9 @@ private[spark] object JsonProtocol {
("Exit Caused By App" -> exitCausedByApp) ~
("Loss Reason" -> reason.map(_.toString))
case taskKilled: TaskKilled =>
("Kill Reason" -> taskKilled.reason)
val accumUpdates = JArray(taskKilled.accumUpdates.map(accumulableInfoToJson).toList)
("Kill Reason" -> taskKilled.reason) ~
("Accumulator Updates" -> accumUpdates)
case _ => emptyJson
}
("Reason" -> reason) ~ json
@@ -917,7 +919,10 @@ private[spark] object JsonProtocol {
case `taskKilled` =>
val killReason = jsonOption(json \ "Kill Reason")
.map(_.extract[String]).getOrElse("unknown reason")
TaskKilled(killReason)
val accumUpdates = jsonOption(json \ "Accumulator Updates")
.map(_.extract[List[JValue]].map(accumulableInfoFromJson))
.getOrElse(Seq[AccumulableInfo]())
TaskKilled(killReason, accumUpdates)
case `taskCommitDenied` =>
// Unfortunately, the `TaskCommitDenied` message was introduced in 1.3.0 but the JSON
// de/serialization logic was not added until 1.5.1. To provide backward compatibility
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -1852,7 +1852,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
assertDataStructuresEmpty()
}

test("accumulators are updated on exception failures") {
test("accumulators are updated on exception failures and task killed") {
val acc1 = AccumulatorSuite.createLongAccum("ingenieur")
val acc2 = AccumulatorSuite.createLongAccum("boulanger")
val acc3 = AccumulatorSuite.createLongAccum("agriculteur")
@@ -1868,15 +1868,24 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
val accUpdate3 = new LongAccumulator
accUpdate3.metadata = acc3.metadata
accUpdate3.setValue(18)
val accumUpdates = Seq(accUpdate1, accUpdate2, accUpdate3)
val accumInfo = accumUpdates.map(AccumulatorSuite.makeInfo)

val accumUpdates1 = Seq(accUpdate1, accUpdate2)
val accumInfo1 = accumUpdates1.map(AccumulatorSuite.makeInfo)
val exceptionFailure = new ExceptionFailure(
new SparkException("fondue?"),
accumInfo).copy(accums = accumUpdates)
accumInfo1).copy(accums = accumUpdates1)

Contributor:
Not caused by you, but why do we do a copy instead of passing accumUpdates1 to the constructor directly?

Contributor (author):
We can avoid the copy call.

Contributor (author):
Ah, this copy call cannot be avoided, as only the two-argument constructor
private[spark] def this(e: Throwable, accumUpdates: Seq[AccumulableInfo]) is defined.
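
A toy model of that constraint (the Failure class below is made up, with String and Int standing in for AccumulableInfo and AccumulatorV2, purely to show why copy() is needed):

// The auxiliary constructor only fills in the AccumulableInfo-like view, so the
// strongly-typed accums field keeps its default value and has to be patched in via copy().
case class Failure(
    description: String,
    accumUpdates: Seq[String] = Seq.empty, // stand-in for Seq[AccumulableInfo]
    accums: Seq[Int] = Nil) {              // stand-in for Seq[AccumulatorV2[_, _]]
  def this(e: Throwable, accumUpdates: Seq[String]) = this(e.getMessage, accumUpdates)
}

// accums cannot be supplied through the two-argument constructor, hence the copy().
val failure = new Failure(new RuntimeException("fondue?"), Seq("info")).copy(accums = Seq(42))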

submit(new MyRDD(sc, 1, Nil), Array(0))
runEvent(makeCompletionEvent(taskSets.head.tasks.head, exceptionFailure, "result"))

assert(AccumulatorContext.get(acc1.id).get.value === 15L)
assert(AccumulatorContext.get(acc2.id).get.value === 13L)

val accumUpdates2 = Seq(accUpdate3)
val accumInfo2 = accumUpdates2.map(AccumulatorSuite.makeInfo)

val taskKilled = new TaskKilled("test", accumInfo2, accums = accumUpdates2)
runEvent(makeCompletionEvent(taskSets.head.tasks.head, taskKilled, "result"))

assert(AccumulatorContext.get(acc3.id).get.value === 18L)
}

@@ -2497,6 +2506,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
val accumUpdates = reason match {
case Success => task.metrics.accumulators()
case ef: ExceptionFailure => ef.accums
case tk: TaskKilled => tk.accums
case _ => Seq.empty
}
CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, taskInfo)
5 changes: 5 additions & 0 deletions project/MimaExcludes.scala
@@ -36,6 +36,11 @@ object MimaExcludes {

// Exclude rules for 2.4.x
lazy val v24excludes = v23excludes ++ Seq(
// [SPARK-20087][CORE] Attach accumulators / metrics to 'TaskKilled' end reason
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.apply"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.copy"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.TaskKilled.this"),

// [SPARK-22941][core] Do not exit JVM when submit fails with in-process launcher.
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.printWarning"),
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.SparkSubmit.parseSparkConfProperty"),