10 changes: 10 additions & 0 deletions core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2620,4 +2620,14 @@ package object config {
.stringConf
.toSequence
.createWithDefault("org.apache.spark.sql.connect.client" :: Nil)

private[spark] val DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION =
ConfigBuilder("spark.scheduler.dropTaskInfoAccumulablesOnTaskCompletion.enabled")
.internal()
.doc("If true, the task info accumulables will be cleared upon task completion in " +
"TaskSetManager. This reduces the heap usage of the driver by only referencing the " +
"task info accumulables for the active tasks and not for completed tasks.")
.version("4.0.0")
.booleanConf
.createWithDefault(false)
}
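For context, here is a minimal, illustrative sketch (not part of this diff) of how the new internal flag could be enabled by its string key on the driver. The key name comes from the ConfigBuilder above; the surrounding application code is assumed.

import org.apache.spark.{SparkConf, SparkContext}

object DropAccumulablesSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative driver setup: turn on the internal flag by its string key so that
    // TaskSetManager drops TaskInfo accumulables once tasks complete.
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("drop-taskinfo-accumulables-sketch")
      .set("spark.scheduler.dropTaskInfoAccumulablesOnTaskCompletion.enabled", "true")
    val sc = new SparkContext(conf)
    try {
      // Run a small job; accumulables of completed tasks are no longer retained on the driver.
      sc.parallelize(1 to 100, 4).map(_ * 2).count()
    } finally {
      sc.stop()
    }
  }
}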
10 changes: 9 additions & 1 deletion core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -45,7 +45,7 @@ class TaskInfo(
val executorId: String,
val host: String,
val taskLocality: TaskLocality.TaskLocality,
val speculative: Boolean) {
val speculative: Boolean) extends Cloneable {

/**
* This API doesn't contain partitionId; please use the new API.
@@ -83,6 +83,14 @@ class TaskInfo(
_accumulables = newAccumulables
}

override def clone(): TaskInfo = super.clone().asInstanceOf[TaskInfo]

private[scheduler] def cloneWithEmptyAccumulables(): TaskInfo = {
val cloned = clone()
cloned.setAccumulables(Nil)
cloned
}

/**
* The time when the task has completed successfully (including the time to remotely fetch
* results, if necessary).
core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -258,6 +258,9 @@ private[spark] class TaskSetManager(

private[scheduler] var emittedTaskSizeWarning = false

private[scheduler] val dropTaskInfoAccumulablesOnTaskCompletion =
conf.get(DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION)

/** Add a task to all the pending-task lists that it should be on. */
private[spark] def addPendingTask(
index: Int,
@@ -787,6 +790,11 @@
// SPARK-37300: if the task is already in the finished state, just ignore it, so that
// we do not end up with wrong `successful` and `tasksSuccessful` values.
Review comment (Contributor): Reading this comment, the partition is already completed, probably by another TaskSetManager, and we just need to reset the task info here?

Review comment (JoshRosen, Dec 16, 2023): I think this branch is handling a rare corner case where the same TaskSetManager can mark the same task as both succeeded and failed. There is some detailed prior discussion of this in https://issues.apache.org/jira/browse/SPARK-37300

if (info.finished) {
if (dropTaskInfoAccumulablesOnTaskCompletion) {
// SPARK-46383: Clear out the accumulables for a completed task to reduce accumulable
// lifetime.
info.setAccumulables(Nil)
}
return
}
val index = info.index
@@ -804,6 +812,8 @@
// Handle this task as a killed task
handleFailedTask(tid, TaskState.KILLED,
TaskKilled("Finish but did not commit due to another attempt succeeded"))
// SPARK-46383: Not clearing the accumulables here because they are already cleared in
// handleFailedTask.
return
}

@@ -846,11 +856,50 @@
// "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here.
// Note: "result.value()" only deserializes the value when it's called at the first time, so
// here "result.value()" just returns the value and won't block other threads.
sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates,
result.metricPeaks, info)

emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), Success, result.value(),
result.accumUpdates, result.metricPeaks)
maybeFinishTaskSet()
}

/**
* A wrapper around [[DAGScheduler.taskEnded()]] that empties out the accumulables of the
* completed task's TaskInfo object referenced by this class.
*
* SPARK-46383: For a completed task, we ship the original TaskInfo (with its accumulables
* intact) to the DAGScheduler and retain in this class only a clone whose accumulables have
* been set to Nil. This releases the references held through `TaskInfo.accumulables`, which
* matters because the TaskInfo objects held by this class are long-lived and have a heavy
* memory footprint on the driver.
*
* This is safe because the accumulables are no longer needed here once they have been
* shipped to the DAGScheduler, where they are aggregated. Additionally, the original
* TaskInfo, and not a clone, must be sent to the DAGScheduler: the same TaskInfo instance
* is sent to the DAGScheduler in several events over a task's lifetime, and users can
* install SparkListeners that compare TaskInfo objects across those events by reference,
* so every event for a given task attempt must carry the same TaskInfo object.
*/
private def emptyTaskInfoAccumulablesAndNotifyDagScheduler(
taskId: Long,
task: Task[_],
reason: TaskEndReason,
result: Any,
accumUpdates: Seq[AccumulatorV2[_, _]],
metricPeaks: Array[Long]): Unit = {
val taskInfoWithAccumulables = taskInfos(taskId)
if (dropTaskInfoAccumulablesOnTaskCompletion) {
val index = taskInfoWithAccumulables.index
val clonedTaskInfo = taskInfoWithAccumulables.cloneWithEmptyAccumulables()
// Update this task's taskInfo while preserving its position in the list
taskAttempts(index) =
taskAttempts(index).map { i => if (i eq taskInfoWithAccumulables) clonedTaskInfo else i }
taskInfos(taskId) = clonedTaskInfo
}
sched.dagScheduler.taskEnded(task, reason, result, accumUpdates, metricPeaks,
taskInfoWithAccumulables)
}
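To illustrate why the original instance must be the one shipped, here is a hedged sketch of the kind of listener the scaladoc above has in mind. It mirrors the SaveActiveTaskInfos helper added to SparkListenerSuite below; the class name and the registration line are assumptions, not part of this PR.

import java.util.{Collections, IdentityHashMap}

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd, SparkListenerTaskStart, TaskInfo}

// Hypothetical listener that tracks TaskInfo objects by reference. If a clone, rather
// than the original TaskInfo, were delivered on task end, the instance added in
// onTaskStart would never be removed and the set would keep growing.
class TaskInfoIdentityListener extends SparkListener {
  // Identity-based set, so membership is decided by `eq` rather than `equals`.
  private val activeTaskInfos = Collections.newSetFromMap[TaskInfo](new IdentityHashMap)

  override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
    activeTaskInfos.add(taskStart.taskInfo)
  }

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    // Only removes the entry if this is the exact instance seen in onTaskStart.
    activeTaskInfos.remove(taskEnd.taskInfo)
  }
}

// Registration on an existing SparkContext would look like:
//   sc.addSparkListener(new TaskInfoIdentityListener)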

private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
partitionToIndex.get(partitionId).foreach { index =>
if (!successful(index)) {
@@ -874,6 +923,11 @@
// SPARK-37300: if the task is already in the finished state, just ignore it, so that
// `copiesRunning` does not end up with a wrong value.
if (info.finished) {
if (dropTaskInfoAccumulablesOnTaskCompletion) {
// SPARK-46383: Clear out the accumulables for a completed task to reduce accumulable
// lifetime.
info.setAccumulables(Nil)
}
return
}
removeRunningTask(tid)
@@ -908,7 +962,8 @@
if (ef.className == classOf[NotSerializableException].getName) {
// If the task result wasn't serializable, there's no point in trying to re-execute it.
logError(s"$task had a not serializable result: ${ef.description}; not retrying")
sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, metricPeaks, info)
emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, null,
accumUpdates, metricPeaks)
abort(s"$task had a not serializable result: ${ef.description}")
return
}
@@ -917,7 +972,8 @@
// re-execute it.
logError("Task %s in stage %s (TID %d) can not write to output file: %s; not retrying"
.format(info.id, taskSet.id, tid, ef.description))
sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, metricPeaks, info)
emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, null,
accumUpdates, metricPeaks)
abort("Task %s in stage %s (TID %d) can not write to output file: %s".format(
info.id, taskSet.id, tid, ef.description))
return
@@ -970,7 +1026,8 @@
isZombie = true
}

sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, metricPeaks, info)
emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, null,
accumUpdates, metricPeaks)

if (!isZombie && reason.countTowardsTaskFailures) {
assert (null != failureReason)
@@ -1086,8 +1143,8 @@
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
sched.dagScheduler.taskEnded(
tasks(index), Resubmitted, null, Seq.empty, Array.empty, info)
emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid,
tasks(index), Resubmitted, null, Seq.empty, Array.empty)
}
}
}
core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler

import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.util.{Collections, IdentityHashMap}
import java.util.concurrent.Semaphore

import scala.collection.mutable
@@ -289,6 +290,19 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
stageInfo.rddInfos.forall(_.numPartitions == 4) should be {true}
}

test("SPARK-46383: Track TaskInfo objects") {
// Test that the same TaskInfo object is sent to the `DAGScheduler` in the `onTaskStart` and
// `onTaskEnd` events.
val conf = new SparkConf().set(DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION, true)
sc = new SparkContext("local", "SparkListenerSuite", conf)
val listener = new SaveActiveTaskInfos
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
sc.runJob(rdd1, (items: Iterator[Int]) => items.size, Seq(0, 1))
sc.listenerBus.waitUntilEmpty()
listener.taskInfos.size should be { 0 }
Review comment (Contributor): I am not sure I follow this test; what is it trying to do? This test will be successful even with DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION = true, right? (Since it is simply checking for instance equality in the fired event?)

Review comment (Contributor, author): This test asserts that the same TaskInfo object is sent in the onTaskStart and onTaskEnd events. It asserts the design in this PR that we send the original TaskInfo object, not a clone, to the DAGScheduler upon task completion.

Review comment (mridulm, Jan 5, 2024): Isn't that simply an implementation detail? (For example, the resubmission case would break it.) I am not sure what behavior we are testing for here, or how this test would help with some future change and its validation. I don't see harm in keeping it, but I want to make sure I am not missing something here.

Review comment (Contributor, author): I don't mind dropping it. I was just trying to assert one of the ways SparkListeners could be used. The test is more of a general test to ensure that we preserve the behavior of SparkListeners.

Review comment (Contributor): Functionally, that the right task info is in the event should already be covered (for example, by the use of SaveStageAndTaskInfo). Do let me know if that is not the case.

Review comment (Contributor, author): SaveActiveTaskInfos is caching TaskInfos, but there are no tests on TaskInfo objects and none asserting that the TaskInfo objects are expected to remain the same across listener events.
}

test("local metrics") {
sc = new SparkContext("local", "SparkListenerSuite")
val listener = new SaveStageAndTaskInfo
@@ -643,6 +657,27 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
}
}

/**
* A simple listener that tracks task infos for all active tasks.
*/
private class SaveActiveTaskInfos extends SparkListener {
// Use a set based on IdentityHashMap instead of a HashSet to track unique references of
// TaskInfo objects.
val taskInfos = Collections.newSetFromMap[TaskInfo](new IdentityHashMap)

override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
val info = taskStart.taskInfo
if (info != null) {
taskInfos.add(info)
}
}

override def onTaskEnd(task: SparkListenerTaskEnd): Unit = {
val info = task.taskInfo
taskInfos.remove(info)
}
}

/**
* A simple listener that saves the task indices for all task events.
*/
core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -61,6 +61,12 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
accumUpdates: Seq[AccumulatorV2[_, _]],
metricPeaks: Array[Long],
taskInfo: TaskInfo): Unit = {
// Set task accumulables emulating DAGScheduler behavior to enable tests related to
// `TaskInfo.accumulables`.
accumUpdates.foreach(acc =>
taskInfo.setAccumulables(
acc.toInfo(Some(acc.value), Some(acc.value)) +: taskInfo.accumulables)
)
taskScheduler.endedTasks(taskInfo.index) = reason
}

@@ -229,6 +235,51 @@ class TaskSetManagerSuite
super.afterEach()
}

test("SPARK-46383: TaskInfo accumulables are cleared upon task completion") {
val conf = new SparkConf().
set(config.DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION, true)
sc = new SparkContext("local", "test", conf)
sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
val taskSet = FakeTask.createTaskSet(2)
val clock = new ManualClock
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
val accumUpdates = taskSet.tasks.head.metrics.internalAccums

// Offer a host. This will launch the first task.
val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)._1
assert(taskOption.isDefined)

clock.advance(1)
// Tell it the first task has finished successfully
manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdates))
assert(sched.endedTasks(0) === Success)

// Only one task was launched and it completed successfully, thus the TaskInfo accumulables
// should be empty.
assert(!manager.taskInfos.exists(t => !t._2.accumulables.isEmpty))
assert(manager.taskAttempts.flatMap(t => t.filter(!_.accumulables.isEmpty)).isEmpty)

// Fail the second task (MAX_TASK_FAILURES - 1) times.
(1 to manager.maxTaskFailures - 1).foreach { index =>
val offerResult = manager.resourceOffer("exec1", "host1", ANY)._1
assert(offerResult.isDefined,
"Expect resource offer on iteration %s to return a task".format(index))
assert(offerResult.get.index === 1)
manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost)
}

clock.advance(1)
// Successfully finish the second task.
val taskOption1 = manager.resourceOffer("exec1", "host1", ANY)._1
manager.handleSuccessfulTask(taskOption1.get.taskId, createTaskResult(1, accumUpdates))
assert(sched.endedTasks(1) === Success)
// The TaskInfo accumulables should be empty as the second task has now completed successfully.
assert(!manager.taskInfos.exists(t => !t._2.accumulables.isEmpty))
assert(manager.taskAttempts.flatMap(t => t.filter(!_.accumulables.isEmpty)).isEmpty)

assert(sched.finishedManagers.contains(manager))
}

test("TaskSet with no preferences") {
sc = new SparkContext("local", "test")
sched = new FakeTaskScheduler(sc, ("exec1", "host1"))