
Commit 28da1d8

utkarsh39 authored and cloud-fan committed
[SPARK-46383] Reduce Driver Heap Usage by Reducing the Lifespan of TaskInfo.accumulables()
### What changes were proposed in this pull request?

`AccumulableInfo` is one of the top heap consumers in the driver's heap dumps for stages with many tasks. For a stage with a large number of tasks (**_O(100k)_**), we saw **30%** of the heap usage stemming from `TaskInfo.accumulables()`.

![image](https://github.com/apache/spark/assets/10495099/13ef5d07-abfc-47fd-81b6-705f599db011)

The `TaskSetManager` today keeps the `TaskInfo` objects around ([ref1](https://github.com/apache/spark/blob/c1ba963e64a22dea28e17b1ed954e6d03d38da1e/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala#L134), [ref2](https://github.com/apache/spark/blob/c1ba963e64a22dea28e17b1ed954e6d03d38da1e/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala#L192)), and in turn the task metrics (`AccumulableInfo`), for every task attempt until the stage completes. This means that for stages with a large number of tasks, we keep the metrics for all tasks around even after a task has completed and its metrics have been aggregated. Since each task carries a large number of metrics, stages with many tasks end up with a large heap footprint in the form of task metrics.

This PR is an opt-in change (disabled by default) that reduces the driver's heap usage for stages with many tasks by no longer referencing the task metrics of completed tasks. Once a task completes in `TaskSetManager`, we no longer keep its metrics around: upon task completion, we clone the `TaskInfo` object and empty out the metrics of the clone. The cloned `TaskInfo` is retained by the `TaskSetManager`, while the original `TaskInfo` object with the metrics is sent to the `DAGScheduler`, where the task metrics are aggregated. Thus, for a completed task, `TaskSetManager` holds a `TaskInfo` object with empty metrics. This reduces the memory footprint by ensuring that the number of task metric objects is proportional to the number of active tasks rather than to the total number of tasks in the stage.

### Config to gate changes

The changes in this PR are guarded by the Spark conf `spark.scheduler.dropTaskInfoAccumulablesOnTaskCompletion.enabled`, which can be used for rollback or staged rollouts.

### Why are the changes disabled by default?

The PR introduces a breaking change wherein `TaskInfo.accumulables()` is empty for `Resubmitted` tasks upon the loss of an executor. Read #44321 (review) for details.

### Why are the changes needed?

Reduce the driver's heap usage, especially for stages with many tasks.

## Benchmarking

On a cluster running a scan stage with 100k tasks, the TaskSetManager's heap usage dropped from 1.1 GB to 37 MB. This **reduced the total driver heap usage by 38%**, down to 2 GB from 3.5 GB.

**BEFORE**

![image](https://github.com/databricks/runtime/assets/10495099/7c1599f3-3587-48a1-b019-84115b1bb90d)

**WITH FIX**

<img width="1386" alt="image" src="https://github.com/databricks/runtime/assets/10495099/b85129c8-dc10-4ee2-898d-61c8e7449616">

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added new tests and did benchmarking on a cluster.

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Github Copilot

Closes #44321 from utkarsh39/SPARK-46383.

Authored-by: Utkarsh <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
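As a quick reference, the opt-in flag described above is set like any other Spark conf. The sketch below is illustrative only: the config key is the one added by this PR, while the master, app name, and job are placeholder values.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Opt in to dropping TaskInfo accumulables for completed tasks (the flag defaults to false).
val conf = new SparkConf()
  .setMaster("local[*]")          // placeholder master for this sketch
  .setAppName("spark-46383-demo") // placeholder app name
  .set("spark.scheduler.dropTaskInfoAccumulablesOnTaskCompletion.enabled", "true")

val sc = new SparkContext(conf)
// Run jobs as usual; the TaskSetManager now retains only empty accumulables for completed tasks.
sc.parallelize(1 to 1000, 100).map(_ * 2).count()
sc.stop()
```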
1 parent 4828f49 commit 28da1d8

File tree: 5 files changed (+169, -8 lines)


core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 10 additions & 0 deletions
@@ -2620,4 +2620,14 @@ package object config {
       .stringConf
       .toSequence
       .createWithDefault("org.apache.spark.sql.connect.client" :: Nil)
+
+  private[spark] val DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION =
+    ConfigBuilder("spark.scheduler.dropTaskInfoAccumulablesOnTaskCompletion.enabled")
+      .internal()
+      .doc("If true, the task info accumulables will be cleared upon task completion in " +
+        "TaskSetManager. This reduces the heap usage of the driver by only referencing the " +
+        "task info accumulables for the active tasks and not for completed tasks.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
 }
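For context on how such an entry behaves at runtime, here is a small sketch that is not part of the patch: the entry and `SparkConf.get(ConfigEntry)` are `private[spark]`, so this only compiles from a package under `org.apache.spark` (the package and object names below are hypothetical).

```scala
package org.apache.spark.example // hypothetical Spark-internal package

import org.apache.spark.SparkConf
import org.apache.spark.internal.config.DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION

object ConfigReadSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    // Unset: the typed read falls back to createWithDefault(false).
    assert(!conf.get(DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION))

    // Set via the string key; the typed entry then reads back true.
    conf.set("spark.scheduler.dropTaskInfoAccumulablesOnTaskCompletion.enabled", "true")
    assert(conf.get(DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION))
  }
}
```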

core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala

Lines changed: 9 additions & 1 deletion
@@ -45,7 +45,7 @@ class TaskInfo(
     val executorId: String,
     val host: String,
     val taskLocality: TaskLocality.TaskLocality,
-    val speculative: Boolean) {
+    val speculative: Boolean) extends Cloneable {
 
   /**
    * This api doesn't contains partitionId, please use the new api.
@@ -83,6 +83,14 @@ class TaskInfo(
     _accumulables = newAccumulables
   }
 
+  override def clone(): TaskInfo = super.clone().asInstanceOf[TaskInfo]
+
+  private[scheduler] def cloneWithEmptyAccumulables(): TaskInfo = {
+    val cloned = clone()
+    cloned.setAccumulables(Nil)
+    cloned
+  }
+
   /**
    * The time when the task has completed successfully (including the time to remotely fetch
    * results, if necessary).
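The change above relies on Java's shallow clone: copy the object, then drop the heavy collection reference on the copy while the original keeps its accumulables. A self-contained sketch of the same idiom, using a hypothetical class rather than Spark's `TaskInfo`:

```scala
// Hypothetical class illustrating the clone-then-reset idiom used by TaskInfo above.
class RecordInfo(val id: Long) extends Cloneable {
  private var _accumulables: Seq[String] = Nil
  def accumulables: Seq[String] = _accumulables
  def setAccumulables(newAccumulables: Seq[String]): Unit = {
    _accumulables = newAccumulables
  }

  // super.clone() performs a shallow, field-by-field copy.
  override def clone(): RecordInfo = super.clone().asInstanceOf[RecordInfo]

  def cloneWithEmptyAccumulables(): RecordInfo = {
    val cloned = clone()
    cloned.setAccumulables(Nil) // the copy no longer references the heavy collection
    cloned
  }
}
```

Only the copy drops the reference; the original is untouched, which is what lets the scheduler keep a lightweight clone while shipping the full object onward.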

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 64 additions & 7 deletions
@@ -256,6 +256,9 @@ private[spark] class TaskSetManager(
 
   private[scheduler] var emittedTaskSizeWarning = false
 
+  private[scheduler] val dropTaskInfoAccumulablesOnTaskCompletion =
+    conf.get(DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION)
+
   /** Add a task to all the pending-task lists that it should be on. */
   private[spark] def addPendingTask(
       index: Int,
@@ -785,6 +788,11 @@
     // SPARK-37300: when the task was already finished state, just ignore it,
     // so that there won't cause successful and tasksSuccessful wrong result.
     if (info.finished) {
+      if (dropTaskInfoAccumulablesOnTaskCompletion) {
+        // SPARK-46383: Clear out the accumulables for a completed task to reduce accumulable
+        // lifetime.
+        info.setAccumulables(Nil)
+      }
       return
     }
     val index = info.index
@@ -802,6 +810,8 @@
       // Handle this task as a killed task
       handleFailedTask(tid, TaskState.KILLED,
         TaskKilled("Finish but did not commit due to another attempt succeeded"))
+      // SPARK-46383: Not clearing the accumulables here because they are already cleared in
+      // handleFailedTask.
       return
     }
 
@@ -844,11 +854,50 @@
     // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here.
     // Note: "result.value()" only deserializes the value when it's called at the first time, so
     // here "result.value()" just returns the value and won't block other threads.
-    sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates,
-      result.metricPeaks, info)
+
+    emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), Success, result.value(),
+      result.accumUpdates, result.metricPeaks)
     maybeFinishTaskSet()
   }
 
+  /**
+   * A wrapper around [[DAGScheduler.taskEnded()]] that empties out the accumulables for the
+   * TaskInfo object, corresponding to the completed task, referenced by this class.
+   *
+   * SPARK-46383: For the completed task, we ship the original TaskInfo to the DAGScheduler and
+   * only retain a cloned TaskInfo in this class. We then set the accumulables to Nil for the
+   * TaskInfo object that corresponds to the completed task.
+   * We do this to release references to `TaskInfo.accumulables()` as the TaskInfo
+   * objects held by this class are long-lived and have a heavy memory footprint on the driver.
+   *
+   * This is safe as the TaskInfo accumulables are not needed once they are shipped to the
+   * DAGScheduler where they are aggregated. Additionally, the original TaskInfo, and not a
+   * clone, must be sent to the DAGScheduler as this TaskInfo object is sent to the
+   * DAGScheduler on multiple events during the task's lifetime. Users can install
+   * SparkListeners that compare the TaskInfo objects across these SparkListener events and
+   * thus the TaskInfo object sent to the DAGScheduler must always reference the same TaskInfo
+   * object.
+   */
+  private def emptyTaskInfoAccumulablesAndNotifyDagScheduler(
+      taskId: Long,
+      task: Task[_],
+      reason: TaskEndReason,
+      result: Any,
+      accumUpdates: Seq[AccumulatorV2[_, _]],
+      metricPeaks: Array[Long]): Unit = {
+    val taskInfoWithAccumulables = taskInfos(taskId)
+    if (dropTaskInfoAccumulablesOnTaskCompletion) {
+      val index = taskInfoWithAccumulables.index
+      val clonedTaskInfo = taskInfoWithAccumulables.cloneWithEmptyAccumulables()
+      // Update this task's taskInfo while preserving its position in the list
+      taskAttempts(index) =
+        taskAttempts(index).map { i => if (i eq taskInfoWithAccumulables) clonedTaskInfo else i }
+      taskInfos(taskId) = clonedTaskInfo
+    }
+    sched.dagScheduler.taskEnded(task, reason, result, accumUpdates, metricPeaks,
+      taskInfoWithAccumulables)
+  }
+
   private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
     partitionToIndex.get(partitionId).foreach { index =>
       if (!successful(index)) {
@@ -872,6 +921,11 @@
       // SPARK-37300: when the task was already finished state, just ignore it,
       // so that there won't cause copiesRunning wrong result.
       if (info.finished) {
+        if (dropTaskInfoAccumulablesOnTaskCompletion) {
+          // SPARK-46383: Clear out the accumulables for a completed task to reduce accumulable
+          // lifetime.
+          info.setAccumulables(Nil)
+        }
         return
       }
       removeRunningTask(tid)
@@ -906,7 +960,8 @@
         if (ef.className == classOf[NotSerializableException].getName) {
           // If the task result wasn't serializable, there's no point in trying to re-execute it.
           logError(s"$task had a not serializable result: ${ef.description}; not retrying")
-          sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, metricPeaks, info)
+          emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, null,
+            accumUpdates, metricPeaks)
           abort(s"$task had a not serializable result: ${ef.description}")
           return
         }
@@ -915,7 +970,8 @@
         // re-execute it.
         logError("Task %s in stage %s (TID %d) can not write to output file: %s; not retrying"
           .format(info.id, taskSet.id, tid, ef.description))
-        sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, metricPeaks, info)
+        emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, null,
+          accumUpdates, metricPeaks)
         abort("Task %s in stage %s (TID %d) can not write to output file: %s".format(
           info.id, taskSet.id, tid, ef.description))
         return
@@ -968,7 +1024,8 @@
       isZombie = true
     }
 
-    sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, metricPeaks, info)
+    emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, null,
+      accumUpdates, metricPeaks)
 
     if (!isZombie && reason.countTowardsTaskFailures) {
       assert (null != failureReason)
@@ -1084,8 +1141,8 @@
         addPendingTask(index)
         // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
        // stage finishes when a total of tasks.size tasks finish.
-        sched.dagScheduler.taskEnded(
-          tasks(index), Resubmitted, null, Seq.empty, Array.empty, info)
+        emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid,
+          tasks(index), Resubmitted, null, Seq.empty, Array.empty)
       }
     }
   }
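One detail worth noting in `emptyTaskInfoAccumulablesAndNotifyDagScheduler` above: the swap in `taskAttempts` matches by reference identity (`eq`) rather than by `==`, so exactly the completed attempt is replaced and its position in the list is preserved. A small generic sketch of that idiom (names are illustrative, not from the patch):

```scala
// Replace one element of a list by reference identity, preserving order and all other elements.
def replaceByIdentity[A <: AnyRef](attempts: List[A], old: A, replacement: A): List[A] =
  attempts.map { a => if (a eq old) replacement else a }

// Two distinct but equal strings: only the identical reference is swapped out.
val first = new String("attempt")
val second = new String("attempt")
val updated = replaceByIdentity(List(first, second), first, "cloned")
assert(updated == List("cloned", "attempt"))
assert(updated(1) eq second) // the other attempt keeps its original reference
```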

core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala

Lines changed: 35 additions & 0 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler
 
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.util.{Collections, IdentityHashMap}
 import java.util.concurrent.Semaphore
 
 import scala.collection.mutable
@@ -289,6 +290,19 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
     stageInfo.rddInfos.forall(_.numPartitions == 4) should be {true}
   }
 
+  test("SPARK-46383: Track TaskInfo objects") {
+    // Test that the same TaskInfo object is sent to the `DAGScheduler` in the `onTaskStart` and
+    // `onTaskEnd` events.
+    val conf = new SparkConf().set(DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION, true)
+    sc = new SparkContext("local", "SparkListenerSuite", conf)
+    val listener = new SaveActiveTaskInfos
+    sc.addSparkListener(listener)
+    val rdd1 = sc.parallelize(1 to 100, 4)
+    sc.runJob(rdd1, (items: Iterator[Int]) => items.size, Seq(0, 1))
+    sc.listenerBus.waitUntilEmpty()
+    listener.taskInfos.size should be { 0 }
+  }
+
   test("local metrics") {
     sc = new SparkContext("local", "SparkListenerSuite")
     val listener = new SaveStageAndTaskInfo
@@ -643,6 +657,27 @@
     }
   }
 
+  /**
+   * A simple listener that tracks task infos for all active tasks.
+   */
+  private class SaveActiveTaskInfos extends SparkListener {
+    // Use a set based on IdentityHashMap instead of a HashSet to track unique references of
+    // TaskInfo objects.
+    val taskInfos = Collections.newSetFromMap[TaskInfo](new IdentityHashMap)
+
+    override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+      val info = taskStart.taskInfo
+      if (info != null) {
+        taskInfos.add(info)
+      }
+    }
+
+    override def onTaskEnd(task: SparkListenerTaskEnd): Unit = {
+      val info = task.taskInfo
+      taskInfos.remove(info)
+    }
+  }
+
   /**
    * A simple listener that saves the task indices for all task events.
    */
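`SaveActiveTaskInfos` deliberately uses an identity-based set so that two `TaskInfo` objects that merely compare equal cannot collapse into one entry; the test only passes if `onTaskEnd` delivers the very same reference seen in `onTaskStart`. A standalone sketch of that distinction, with strings as stand-ins:

```scala
import java.util.{Collections, IdentityHashMap}

// An identity-based set: membership is decided by `eq`, not by equals/hashCode.
val seen = Collections.newSetFromMap[String](new IdentityHashMap[String, java.lang.Boolean])

val started = new String("task-0")
val ended = new String("task-0") // equal to `started`, but a different reference

seen.add(started)
seen.remove(ended)   // no-op: different identity, so the entry survives
assert(seen.size == 1)

seen.remove(started) // same reference, so the entry is removed
assert(seen.isEmpty)
```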

core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala

Lines changed: 51 additions & 0 deletions
@@ -62,6 +62,12 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
       accumUpdates: Seq[AccumulatorV2[_, _]],
       metricPeaks: Array[Long],
       taskInfo: TaskInfo): Unit = {
+    // Set task accumulables emulating DAGScheduler behavior to enable tests related to
+    // `TaskInfo.accumulables`.
+    accumUpdates.foreach(acc =>
+      taskInfo.setAccumulables(
+        acc.toInfo(Some(acc.value), Some(acc.value)) +: taskInfo.accumulables)
+    )
     taskScheduler.endedTasks(taskInfo.index) = reason
   }
 
@@ -230,6 +236,51 @@
     super.afterEach()
   }
 
+  test("SPARK-46383: TaskInfo accumulables are cleared upon task completion") {
+    val conf = new SparkConf().
+      set(config.DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION, true)
+    sc = new SparkContext("local", "test", conf)
+    sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
+    val taskSet = FakeTask.createTaskSet(2)
+    val clock = new ManualClock
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
+    val accumUpdates = taskSet.tasks.head.metrics.internalAccums
+
+    // Offer a host. This will launch the first task.
+    val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)._1
+    assert(taskOption.isDefined)
+
+    clock.advance(1)
+    // Tell it the first task has finished successfully
+    manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdates))
+    assert(sched.endedTasks(0) === Success)
+
+    // Only one task was launched and it completed successfully, thus the TaskInfo accumulables
+    // should be empty.
+    assert(!manager.taskInfos.exists(t => !t._2.accumulables.isEmpty))
+    assert(manager.taskAttempts.flatMap(t => t.filter(!_.accumulables.isEmpty)).isEmpty)
+
+    // Fail the second task (MAX_TASK_FAILURES - 1) times.
+    (1 to manager.maxTaskFailures - 1).foreach { index =>
+      val offerResult = manager.resourceOffer("exec1", "host1", ANY)._1
+      assert(offerResult.isDefined,
+        "Expect resource offer on iteration %s to return a task".format(index))
+      assert(offerResult.get.index === 1)
+      manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost)
+    }
+
+    clock.advance(1)
+    // Successfully finish the second task.
+    val taskOption1 = manager.resourceOffer("exec1", "host1", ANY)._1
+    manager.handleSuccessfulTask(taskOption1.get.taskId, createTaskResult(1, accumUpdates))
+    assert(sched.endedTasks(1) === Success)
+    // The TaskInfo accumulables should be empty as the second task has now completed successfully.
+    assert(!manager.taskInfos.exists(t => !t._2.accumulables.isEmpty))
+    assert(manager.taskAttempts.flatMap(t => t.filter(!_.accumulables.isEmpty)).isEmpty)
+
+    assert(sched.finishedManagers.contains(manager))
+  }
+
   test("TaskSet with no preferences") {
     sc = new SparkContext("local", "test")
     sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
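The double-negative assertions in the new test read as: every `TaskInfo` retained by the manager has empty accumulables. An equivalent, arguably more direct formulation is sketched below; it assumes the test's `manager` is in scope and is not part of the patch:

```scala
// Equivalent checks over the TaskSetManager's retained state once the tasks have completed.
assert(manager.taskInfos.values.forall(_.accumulables.isEmpty))
assert(manager.taskAttempts.forall(_.forall(_.accumulables.isEmpty)))
```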
