
Commit 664e06d

fix
1 parent c1ba963 commit 664e06d

File tree

5 files changed: +205, -8 lines changed


core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 10 additions & 0 deletions
@@ -2620,4 +2620,14 @@ package object config {
       .stringConf
       .toSequence
       .createWithDefault("org.apache.spark.sql.connect.client" :: Nil)
+
+  private[spark] val DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION =
+    ConfigBuilder("spark.scheduler.dropTaskInfoAccumulablesOnTaskCompletion.enabled")
+      .internal()
+      .doc("If true, the task info accumulables will be cleared upon task completion in " +
+        "TaskSetManager. This reduces the heap usage of the driver by only referencing the " +
+        "task info accumulables for the active tasks and not for completed tasks.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(true)
 }
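
Usage note (not part of the commit): the flag is internal and defaults to true. A minimal sketch of how one might switch it off, for example while debugging a listener that still reads accumulables after task completion; only the property name comes from the diff above, the master and app name are illustrative:

import org.apache.spark.{SparkConf, SparkContext}

// Keep accumulables on completed TaskInfos by disabling the new behaviour.
val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("taskinfo-accumulables-demo")
  .set("spark.scheduler.dropTaskInfoAccumulablesOnTaskCompletion.enabled", "false")
val sc = new SparkContext(conf)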

core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala

Lines changed: 35 additions & 2 deletions
@@ -45,7 +45,7 @@ class TaskInfo(
     val executorId: String,
     val host: String,
     val taskLocality: TaskLocality.TaskLocality,
-    val speculative: Boolean) {
+    val speculative: Boolean) extends Cloneable {

   /**
    * This api doesn't contains partitionId, please use the new api.
@@ -75,14 +75,47 @@ class TaskInfo(
    * accumulable to be updated multiple times in a single task or for two accumulables with the
    * same name but different IDs to exist in a task.
    */
-  def accumulables: Seq[AccumulableInfo] = _accumulables
+  def accumulables: Seq[AccumulableInfo] = {
+    if (throwOnAccumulablesCall) {
+      throw new IllegalStateException("Accumulables for the TaskInfo have been cleared")
+    } else {
+      _accumulables
+    }
+  }

   private[this] var _accumulables: Seq[AccumulableInfo] = Nil

   private[spark] def setAccumulables(newAccumulables: Seq[AccumulableInfo]): Unit = {
     _accumulables = newAccumulables
   }

+  /**
+   * If true, a call to TaskInfo.accumulables() will throw an exception.
+   */
+  private var throwOnAccumulablesCall: Boolean = false
+
+  override def clone(): TaskInfo = super.clone().asInstanceOf[TaskInfo]
+
+  /**
+   * For testing only. Allows probing accumulables without triggering the exception when
+   * `throwOnAccumulablesCall` is set.
+   */
+  private[scheduler] def isAccumulablesEmpty(): Boolean = {
+    _accumulables.isEmpty
+  }
+
+  private[scheduler] def resetAccumulables(): Unit = {
+    setAccumulables(Nil)
+    throwOnAccumulablesCall = true
+  }
+
+  private[scheduler] def cloneWithEmptyAccumulables(): TaskInfo = {
+    val cloned = clone()
+    cloned.setAccumulables(Nil)
+    cloned.throwOnAccumulablesCall = true
+    cloned
+  }
+
   /**
    * The time when the task has completed successfully (including the time to remotely fetch
    * results, if necessary).
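
For intuition, here is a self-contained sketch of the same clone-and-invalidate idea outside Spark; the class and member names are invented for illustration and are not Spark API:

// The long-lived holder keeps a poisoned copy (payload dropped, further reads disallowed)
// and hands the original object, payload intact, to downstream consumers.
class Record(private var payload: Seq[String]) extends Cloneable {
  private var cleared = false

  def read(): Seq[String] = {
    if (cleared) throw new IllegalStateException("payload has been cleared")
    payload
  }

  override def clone(): Record = super.clone().asInstanceOf[Record]

  def cloneWithEmptyPayload(): Record = {
    val copy = clone()
    copy.payload = Nil
    copy.cleared = true
    copy // retain this copy locally; the original Record is shipped on unchanged
  }
}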

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 53 additions & 6 deletions
@@ -787,6 +787,8 @@ private[spark] class TaskSetManager(
     // SPARK-37300: when the task was already finished state, just ignore it,
     // so that there won't cause successful and tasksSuccessful wrong result.
     if (info.finished) {
+      // SPARK-46383: Clear out the accumulables for a completed task to reduce accumulable lifetime.
+      info.resetAccumulables()
       return
     }
     val index = info.index
@@ -804,6 +806,8 @@ private[spark] class TaskSetManager(
       // Handle this task as a killed task
       handleFailedTask(tid, TaskState.KILLED,
         TaskKilled("Finish but did not commit due to another attempt succeeded"))
+      // SPARK-46383: Not clearing the accumulables here because they are already cleared in
+      // handleFailedTask.
       return
     }
@@ -846,11 +850,49 @@ private[spark] class TaskSetManager(
     // "result.value()" in "TaskResultGetter.enqueueSuccessfulTask" before reaching here.
     // Note: "result.value()" only deserializes the value when it's called at the first time, so
     // here "result.value()" just returns the value and won't block other threads.
-    sched.dagScheduler.taskEnded(tasks(index), Success, result.value(), result.accumUpdates,
-      result.metricPeaks, info)
+
+    emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), Success, result.value(),
+      result.accumUpdates, result.metricPeaks, info)
     maybeFinishTaskSet()
   }

+  /**
+   * A wrapper around [[DAGScheduler.taskEnded()]] that empties out the accumulables for the
+   * TaskInfo object, corresponding to the completed task, referenced by this class.
+   *
+   * SPARK-46383: For the completed task, we ship the original TaskInfo to the DAGScheduler and
+   * only retain a cloned TaskInfo in this class. We then set the accumulables to Nil for the
+   * TaskInfo object that corresponds to the completed task.
+   * We do this to release references to `TaskInfo.accumulables()` as the TaskInfo
+   * objects held by this class are long-lived and have a heavy memory footprint on the driver.
+   *
+   * This is safe as the TaskInfo accumulables are not needed once they are shipped to the
+   * DAGScheduler where they are aggregated. Additionally, the original TaskInfo, and not a
+   * clone, must be sent to the DAGScheduler as this TaskInfo object is sent to the
+   * DAGScheduler on multiple events during the task's lifetime. Users can install
+   * SparkListeners that compare the TaskInfo objects across these SparkListener events and
+   * thus the TaskInfo object sent to the DAGScheduler must always reference the same TaskInfo
+   * object.
+   */
+  private def emptyTaskInfoAccumulablesAndNotifyDagScheduler(
+      taskId: Long,
+      task: Task[_],
+      reason: TaskEndReason,
+      result: Any,
+      accumUpdates: Seq[AccumulatorV2[_, _]],
+      metricPeaks: Array[Long],
+      taskInfo: TaskInfo): Unit = {
+    val index = taskInfo.index
+    if (conf.get(DROP_TASK_INFO_ACCUMULABLES_ON_TASK_COMPLETION)) {
+      val clonedTaskInfo = taskInfo.cloneWithEmptyAccumulables()
+      // Update this task's taskInfo while preserving its position in the list
+      taskAttempts(index) =
+        taskAttempts(index).map { i => if (i eq taskInfo) clonedTaskInfo else i }
+      taskInfos(taskId) = clonedTaskInfo
+    }
+    sched.dagScheduler.taskEnded(task, reason, result, accumUpdates, metricPeaks, taskInfo)
+  }
+
   private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
     partitionToIndex.get(partitionId).foreach { index =>
       if (!successful(index)) {
@@ -874,6 +916,8 @@ private[spark] class TaskSetManager(
     // SPARK-37300: when the task was already finished state, just ignore it,
     // so that there won't cause copiesRunning wrong result.
     if (info.finished) {
+      // SPARK-46383: Clear out the accumulables for a completed task to reduce accumulable lifetime.
+      info.resetAccumulables()
       return
     }
     removeRunningTask(tid)
@@ -908,7 +952,8 @@ private[spark] class TaskSetManager(
       if (ef.className == classOf[NotSerializableException].getName) {
         // If the task result wasn't serializable, there's no point in trying to re-execute it.
         logError(s"$task had a not serializable result: ${ef.description}; not retrying")
-        sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, metricPeaks, info)
+        emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, null,
+          accumUpdates, metricPeaks, info)
         abort(s"$task had a not serializable result: ${ef.description}")
         return
       }
@@ -917,7 +962,8 @@ private[spark] class TaskSetManager(
         // re-execute it.
         logError("Task %s in stage %s (TID %d) can not write to output file: %s; not retrying"
           .format(info.id, taskSet.id, tid, ef.description))
-        sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, metricPeaks, info)
+        emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, null,
+          accumUpdates, metricPeaks, info)
         abort("Task %s in stage %s (TID %d) can not write to output file: %s".format(
           info.id, taskSet.id, tid, ef.description))
         return
@@ -970,7 +1016,8 @@ private[spark] class TaskSetManager(
       isZombie = true
     }

-    sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, metricPeaks, info)
+    emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid, tasks(index), reason, null,
+      accumUpdates, metricPeaks, info)

     if (!isZombie && reason.countTowardsTaskFailures) {
       assert (null != failureReason)
@@ -1086,7 +1133,7 @@ private[spark] class TaskSetManager(
       addPendingTask(index)
       // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
       // stage finishes when a total of tasks.size tasks finish.
-      sched.dagScheduler.taskEnded(
+      emptyTaskInfoAccumulablesAndNotifyDagScheduler(tid,
         tasks(index), Resubmitted, null, Seq.empty, Array.empty, info)
     }
   }
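
The wrapper swaps the completed attempt's TaskInfo using reference equality (eq), so the clone takes the exact slot of the original in taskAttempts while value-equal neighbours are left alone. A standalone sketch of that swap pattern, with generic types and illustrative names rather than Spark code:

// Replace exactly the element that is the same object reference, preserving list order;
// objects that are merely equal in value are not touched.
def replaceByReference[A <: AnyRef](xs: List[A], before: A, after: A): List[A] =
  xs.map(x => if (x eq before) after else x)

val a = new String("attempt")   // two value-equal but distinct references
val b = new String("attempt")
val swapped = replaceByReference(List(a, b), a, new String("replaced"))
// swapped == List("replaced", "attempt"): only the slot holding the `a` reference changed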

core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala

Lines changed: 34 additions & 0 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.scheduler

 import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.util.{Collections, IdentityHashMap}
 import java.util.concurrent.Semaphore

 import scala.collection.mutable
@@ -289,6 +290,16 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
     stageInfo.rddInfos.forall(_.numPartitions == 4) should be {true}
   }

+  test("SPARK-46383: Track TaskInfo objects") {
+    sc = new SparkContext("local", "SparkListenerSuite")
+    val listener = new SaveActiveTaskInfos
+    sc.addSparkListener(listener)
+    val rdd1 = sc.parallelize(1 to 100, 4)
+    sc.runJob(rdd1, (items: Iterator[Int]) => items.size, Seq(0, 1))
+    sc.listenerBus.waitUntilEmpty()
+    listener.taskInfos.size should be { 0 }
+  }
+
   test("local metrics") {
     sc = new SparkContext("local", "SparkListenerSuite")
     val listener = new SaveStageAndTaskInfo
@@ -643,6 +654,29 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match
     }
   }

+  /**
+   * A simple listener that tracks task infos for all active tasks.
+   */
+  private class SaveActiveTaskInfos extends SparkListener {
+    // Use a set based on IdentityHashMap instead of a HashSet to track unique references of
+    // TaskInfo objects.
+    val taskInfos = Collections.newSetFromMap[TaskInfo](new IdentityHashMap)
+
+    override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+      val info = taskStart.taskInfo
+      if (info != null) {
+        taskInfos.add(info)
+      }
+    }
+
+    override def onTaskEnd(task: SparkListenerTaskEnd): Unit = {
+      val info = task.taskInfo
+      if (info != null && taskInfos.contains(info)) {
+        taskInfos.remove(info)
+      }
+    }
+  }
+
   /**
    * A simple listener that saves the task indices for all task events.
   */
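
The set in SaveActiveTaskInfos is identity-based on purpose: onTaskEnd only removes an entry if the very same TaskInfo reference seen at onTaskStart arrives again, so the test's final size-0 assertion shows the scheduler keeps shipping one stable reference per task. A short sketch of the distinction, using plain JDK collections and illustrative values:

import java.util.{Collections, IdentityHashMap}

val identitySet = Collections.newSetFromMap[String](new IdentityHashMap)
val s1 = new String("info")
val s2 = new String("info") // equal to s1, but a different object

identitySet.add(s1)
identitySet.contains(s1)    // true: same reference
identitySet.contains(s2)    // false: equal value, different reference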

core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala

Lines changed: 73 additions & 0 deletions
@@ -61,6 +61,11 @@ class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
       accumUpdates: Seq[AccumulatorV2[_, _]],
       metricPeaks: Array[Long],
       taskInfo: TaskInfo): Unit = {
+    accumUpdates.foreach(acc =>
+      taskInfo.setAccumulables(
+        acc.toInfo(Some(acc.value), Some(acc.value)) +: taskInfo.accumulables)
+    )
+    taskScheduler.endedTasks(taskInfo.index) = reason
     taskScheduler.endedTasks(taskInfo.index) = reason
   }

@@ -229,6 +234,74 @@ class TaskSetManagerSuite
     super.afterEach()
   }

+  test("SPARK-46383: Accumulables of TaskInfo objects held by TaskSetManager must not be " +
+    "accessed once the task has completed") { conf =>
+    sc = new SparkContext("local", "test", conf)
+    sched = FakeTaskScheduler(sc, ("exec1", "host1"))
+    val taskSet = FakeTask.createTaskSet(1)
+    val clock = new ManualClock
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
+    val accumUpdates = taskSet.tasks.head.metrics.internalAccums
+
+    // Offer a host. This will launch the first task.
+    val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)._1
+    assert(taskOption.isDefined)
+
+    clock.advance(1)
+    // Tell it the first task has finished successfully
+    manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdates))
+    assert(sched.endedTasks(0) === Success)
+
+    val e = intercept[IllegalStateException] {
+      manager.taskInfos.head._2.accumulables
+    }
+    assert(e.getMessage.contains("Accumulables for the TaskInfo have been cleared"))
+  }
+
+  test("SPARK-46383: TaskInfo accumulables are cleared upon task completion") { conf =>
+    sc = new SparkContext("local", "test", conf)
+    sched = FakeTaskScheduler(sc, ("exec1", "host1"))
+    val taskSet = FakeTask.createTaskSet(2)
+    val clock = new ManualClock
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
+    val accumUpdates = taskSet.tasks.head.metrics.internalAccums
+
+    // Offer a host. This will launch the first task.
+    val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)._1
+    assert(taskOption.isDefined)
+
+    clock.advance(1)
+    // Tell it the first task has finished successfully
+    manager.handleSuccessfulTask(0, createTaskResult(0, accumUpdates))
+    assert(sched.endedTasks(0) === Success)
+
+    // Only one task was launched and it completed successfully, thus the TaskInfo accumulables
+    // should be empty.
+    assert(!manager.taskInfos.exists(l => !l._2.isAccumulablesEmpty))
+    assert(manager.taskAttempts.flatMap(l => l.filter(!_.isAccumulablesEmpty)).isEmpty)
+
+    // Fail the second task (MAX_TASK_FAILURES - 1) times.
+    (1 to manager.maxTaskFailures - 1).foreach { index =>
+      val offerResult = manager.resourceOffer("exec1", "host1", ANY)._1
+      assert(offerResult.isDefined,
+        "Expect resource offer on iteration %s to return a task".format(index))
+      assert(offerResult.get.index === 1)
+      manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost)
+      assert(!sched.taskSetsFailed.contains(FakeTaskFailure(taskSet.id)))
+    }
+
+    clock.advance(1)
+    // Successfully finish the second task.
+    val taskOption1 = manager.resourceOffer("exec1", "host1", ANY)._1
+    manager.handleSuccessfulTask(taskOption1.get.taskId, createTaskResult(1, accumUpdates))
+    assert(sched.endedTasks(1) === Success)
+    // The TaskInfo accumulables should be empty as the second task has now completed successfully.
+    assert(!manager.taskInfos.exists(l => !l._2.isAccumulablesEmpty))
+    assert(manager.taskAttempts.flatMap(l => l.filter(!_.isAccumulablesEmpty)).isEmpty)
+
+    assert(sched.finishedManagers.contains(manager))
+  }
+
   test("TaskSet with no preferences") {
     sc = new SparkContext("local", "test")
     sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
