
Commit d124ce9

cloud-fan authored and squito committed
[SPARK-27590][CORE] do not consider skipped tasks when scheduling speculative tasks
## What changes were proposed in this pull request?

This is a followup of #24375.

When `TaskSetManager` skips a task because its corresponding partition is already completed by other `TaskSetManager`s, we should not consider the duration of the task that is finished by other `TaskSetManager`s to schedule the speculative tasks of this `TaskSetManager`.

## How was this patch tested?

updated test case

Closes #24485 from cloud-fan/minor.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Imran Rashid <[email protected]>
1 parent d5308cd commit d124ce9
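
In essence, after this change the speculation threshold is computed only from durations this `TaskSetManager` measured itself, so partitions completed by other attempts can no longer skew the median. A minimal, self-contained sketch of that invariant (hypothetical class and names, condensed from but not reproducing Spark's `TaskSetManager`):

```scala
import scala.collection.mutable.ArrayBuffer

class SpeculationGate(numTasks: Int, quantile: Double, multiplier: Double) {
  // Durations of tasks *this* manager ran to completion. Partitions finished
  // by another TaskSetManager are marked successful but add nothing here.
  private val ownDurations = ArrayBuffer.empty[Long]

  // As in the patch: require at least one locally finished task.
  val minFinishedForSpeculation: Int =
    math.max((quantile * numTasks).floor.toInt, 1)

  def recordOwnSuccess(durationMs: Long): Unit = ownDurations += durationMs

  // Threshold above which a running task becomes speculatable, or None while
  // too few of our own tasks have finished to estimate a median.
  def speculationThreshold(minTimeToSpeculation: Long): Option[Double] = {
    if (ownDurations.size < minFinishedForSpeculation) None
    else {
      val median = ownDurations.sorted.apply(ownDurations.size / 2).toDouble
      Some(math.max(multiplier * median, minTimeToSpeculation.toDouble))
    }
  }
}
```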

Showing 8 changed files with 25 additions and 34 deletions.

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 1 addition & 2 deletions

@@ -1394,8 +1394,7 @@ private[spark] class DAGScheduler(
       // finished. Here we notify the task scheduler to skip running tasks for the same partition,
       // to save resource.
       if (task.stageAttemptId < stage.latestInfo.attemptNumber()) {
-        taskScheduler.notifyPartitionCompletion(
-          stageId, task.partitionId, event.taskInfo.duration)
+        taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
       }

       task match {

core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala

Lines changed: 2 additions & 3 deletions

@@ -158,10 +158,9 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
   // This method calls `TaskSchedulerImpl.handlePartitionCompleted` asynchronously. We do not want
   // DAGScheduler to call `TaskSchedulerImpl.handlePartitionCompleted` directly, as it's
   // synchronized and may hurt the throughput of the scheduler.
-  def enqueuePartitionCompletionNotification(
-      stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
+  def enqueuePartitionCompletionNotification(stageId: Int, partitionId: Int): Unit = {
     getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions {
-      scheduler.handlePartitionCompleted(stageId, partitionId)
+      scheduler.handlePartitionCompleted(stageId, partitionId)
     })
   }
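
The hunk above keeps the existing hand-off pattern: the notification is queued on the task-result thread pool so the DAGScheduler's event loop never blocks on `TaskSchedulerImpl`'s lock. A minimal sketch of that pattern, with illustrative names rather than Spark's API:

```scala
import java.util.concurrent.Executors

object PartitionNotifier {
  private val pool = Executors.newFixedThreadPool(4)

  // Stand-in for the synchronized handler: calling it inline from a busy
  // event loop would serialize the caller on this object's lock.
  private def handlePartitionCompleted(stageId: Int, partitionId: Int): Unit =
    synchronized {
      println(s"stage $stageId: partition $partitionId already completed")
    }

  // Stand-in for enqueuePartitionCompletionNotification: queue the locked
  // work on the pool and return immediately.
  def enqueue(stageId: Int, partitionId: Int): Unit =
    pool.execute(() => handlePartitionCompleted(stageId, partitionId))
}
```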

core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala

Lines changed: 1 addition & 1 deletion

@@ -70,7 +70,7 @@ private[spark] trait TaskScheduler {

   // Notify the corresponding `TaskSetManager`s of the stage, that a partition has already completed
   // and they can skip running tasks for it.
-  def notifyPartitionCompletion(stageId: Int, partitionId: Int, taskDuration: Long)
+  def notifyPartitionCompletion(stageId: Int, partitionId: Int)

   // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
   def setDAGScheduler(dagScheduler: DAGScheduler): Unit

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 4 additions & 8 deletions

@@ -301,9 +301,8 @@ private[spark] class TaskSchedulerImpl(
     }
   }

-  override def notifyPartitionCompletion(
-      stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
-    taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId, taskDuration)
+  override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+    taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId)
   }

   /**
@@ -651,12 +650,9 @@ private[spark] class TaskSchedulerImpl(
    * means that a task completion from an earlier zombie attempt can lead to the entire stage
    * getting marked as successful.
    */
-  private[scheduler] def handlePartitionCompleted(
-      stageId: Int,
-      partitionId: Int,
-      taskDuration: Long) = synchronized {
+  private[scheduler] def handlePartitionCompleted(stageId: Int, partitionId: Int) = synchronized {
     taskSetsByStageIdAndAttempt.get(stageId).foreach(_.values.filter(!_.isZombie).foreach { tsm =>
-      tsm.markPartitionCompleted(partitionId, taskDuration)
+      tsm.markPartitionCompleted(partitionId)
     })
   }

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 12 additions & 12 deletions

@@ -62,14 +62,8 @@ private[spark] class TaskSetManager(
   private val addedJars = HashMap[String, Long](sched.sc.addedJars.toSeq: _*)
   private val addedFiles = HashMap[String, Long](sched.sc.addedFiles.toSeq: _*)

-  // Quantile of tasks at which to start speculation
-  val speculationQuantile = conf.get(SPECULATION_QUANTILE)
-  val speculationMultiplier = conf.get(SPECULATION_MULTIPLIER)
-
   val maxResultSize = conf.get(config.MAX_RESULT_SIZE)

-  val speculationEnabled = conf.get(SPECULATION_ENABLED)
-
   // Serializer for closures and tasks.
   val env = SparkEnv.get
   val ser = env.closureSerializer.newInstance()
@@ -80,6 +74,12 @@ private[spark] class TaskSetManager(
   val numTasks = tasks.length
   val copiesRunning = new Array[Int](numTasks)

+  val speculationEnabled = conf.get(SPECULATION_ENABLED)
+  // Quantile of tasks at which to start speculation
+  val speculationQuantile = conf.get(SPECULATION_QUANTILE)
+  val speculationMultiplier = conf.get(SPECULATION_MULTIPLIER)
+  val minFinishedForSpeculation = math.max((speculationQuantile * numTasks).floor.toInt, 1)
+
   // For each task, tracks whether a copy of the task has succeeded. A task will also be
   // marked as "succeeded" if it failed with a fetch failure, in which case it should not
   // be re-run because the missing map data needs to be regenerated first.
@@ -816,12 +816,9 @@ private[spark] class TaskSetManager(
     maybeFinishTaskSet()
   }

-  private[scheduler] def markPartitionCompleted(partitionId: Int, taskDuration: Long): Unit = {
+  private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
     partitionToIndex.get(partitionId).foreach { index =>
       if (!successful(index)) {
-        if (speculationEnabled && !isZombie) {
-          successfulTaskDurations.insert(taskDuration)
-        }
         tasksSuccessful += 1
         successful(index) = true
         if (tasksSuccessful == numTasks) {
@@ -1035,10 +1032,13 @@ private[spark] class TaskSetManager(
       return false
     }
     var foundTasks = false
-    val minFinishedForSpeculation = (speculationQuantile * numTasks).floor.toInt
     logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)

-    if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
+    // It's possible that a task is marked as completed by the scheduler, then the size of
+    // `successfulTaskDurations` may not equal to `tasksSuccessful`. Here we should only count the
+    // tasks that are submitted by this `TaskSetManager` and are completed successfully.
+    val numSuccessfulTasks = successfulTaskDurations.size()
+    if (numSuccessfulTasks >= minFinishedForSpeculation) {
       val time = clock.getTimeMillis()
       val medianDuration = successfulTaskDurations.median
       val threshold = max(speculationMultiplier * medianDuration, minTimeToSpeculation)
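
Taken together, the hunks above mean `tasksSuccessful` can now exceed `successfulTaskDurations.size()`. A toy demonstration of the resulting invariant (hypothetical standalone code, not Spark's):

```scala
object SkippedPartitionDemo extends App {
  var tasksSuccessful = 0
  val successfulTaskDurations = scala.collection.mutable.ArrayBuffer.empty[Long]

  // Two tasks this manager actually ran to completion record durations:
  Seq(100L, 120L).foreach { d => tasksSuccessful += 1; successfulTaskDurations += d }

  // One partition completed by another TaskSetManager: marked successful,
  // but after this patch no duration is inserted.
  tasksSuccessful += 1

  assert(tasksSuccessful == 3 && successfulTaskDurations.size == 2)
  // Gating speculation on tasksSuccessful would fire with a median built from
  // fewer samples than the count implies; gating on size() cannot.
}
```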

core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala

Lines changed: 2 additions & 4 deletions

@@ -157,8 +157,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
     override def killAllTaskAttempts(
       stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
-    override def notifyPartitionCompletion(
-        stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
+    override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
       taskSets.filter(_.stageId == stageId).lastOption.foreach { ts =>
         val tasks = ts.tasks.filter(_.partitionId == partitionId)
         assert(tasks.length == 1)
@@ -668,8 +667,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
         stageId: Int, interruptThread: Boolean, reason: String): Unit = {
       throw new UnsupportedOperationException
     }
-    override def notifyPartitionCompletion(
-        stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
+    override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
       throw new UnsupportedOperationException
     }
     override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}

core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala

Lines changed: 1 addition & 2 deletions

@@ -84,8 +84,7 @@ private class DummyTaskScheduler extends TaskScheduler {
     taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
   override def killAllTaskAttempts(
     stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
-  override def notifyPartitionCompletion(
-    stageId: Int, partitionId: Int, taskDuration: Long): Unit = {}
+  override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {}
   override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
   override def defaultParallelism(): Int = 2
   override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}

core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala

Lines changed: 2 additions & 2 deletions

@@ -1394,8 +1394,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg

     val taskSetManager = sched.taskSetManagerForAttempt(0, 0).get
     assert(taskSetManager.runningTasks === 8)
-    taskSetManager.markPartitionCompleted(8, 0)
-    assert(!taskSetManager.successfulTaskDurations.isEmpty())
+    taskSetManager.markPartitionCompleted(8)
+    assert(taskSetManager.successfulTaskDurations.isEmpty())
     taskSetManager.checkSpeculatableTasks(0)
   }
