diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 598b62f85a1f..56c0bf6c0935 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -697,9 +697,12 @@ private[spark] class TaskSchedulerImpl(
    * do not also submit those same tasks. That also means that a task completion from an earlier
    * attempt can lead to the entire stage getting marked as successful.
    */
-  private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = {
+  private[scheduler] def markPartitionCompletedInAllTaskSets(
+      stageId: Int,
+      partitionId: Int,
+      taskInfo: TaskInfo) = {
     taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm =>
-      tsm.markPartitionCompleted(partitionId)
+      tsm.markPartitionCompleted(partitionId, taskInfo)
     }
   }
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index a18c66596852..6071605ad7f9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -758,7 +758,7 @@ private[spark] class TaskSetManager(
     }
     // There may be multiple tasksets for this stage -- we let all of them know that the partition
     // was completed. This may result in some of the tasksets getting completed.
-    sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId)
+    sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId, info)
     // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
     // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
     // "deserialize" the value when holding a lock to avoid blocking other threads. So we call
@@ -769,9 +769,12 @@ private[spark] class TaskSetManager(
     maybeFinishTaskSet()
   }
 
-  private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
+  private[scheduler] def markPartitionCompleted(partitionId: Int, taskInfo: TaskInfo): Unit = {
     partitionToIndex.get(partitionId).foreach { index =>
       if (!successful(index)) {
+        if (speculationEnabled && !isZombie) {
+          successfulTaskDurations.insert(taskInfo.duration)
+        }
         tasksSuccessful += 1
         successful(index) = true
         if (tasksSuccessful == numTasks) {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index ca6a7e5db3b1..ae571e5a3583 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -1365,6 +1365,55 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     assert(taskOption4.get.addedJars === addedJarsMidTaskSet)
   }
 
+  test("[SPARK-24677] Avoid NoSuchElementException from MedianHeap") {
+    val conf = new SparkConf().set("spark.speculation", "true")
+    sc = new SparkContext("local", "test", conf)
+    // Set the speculation multiplier to be 0 so speculative tasks are launched immediately
+    sc.conf.set("spark.speculation.multiplier", "0.0")
+    sc.conf.set("spark.speculation.quantile", "0.1")
+    sc.conf.set("spark.speculation", "true")
+
+    sched = new FakeTaskScheduler(sc)
+    sched.initialize(new FakeSchedulerBackend())
+
+    val dagScheduler = new FakeDAGScheduler(sc, sched)
+    sched.setDAGScheduler(dagScheduler)
+
+    val taskSet1 = FakeTask.createTaskSet(10)
+    val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet1.tasks.map { task =>
+      task.metrics.internalAccums
+    }
+
+    sched.submitTasks(taskSet1)
+    sched.resourceOffers(
+      (0 until 10).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) })
+
+    val taskSetManager1 = sched.taskSetManagerForAttempt(0, 0).get
+
+    // fail fetch
+    taskSetManager1.handleFailedTask(
+      taskSetManager1.taskAttempts.head.head.taskId, TaskState.FAILED,
+      FetchFailed(null, 0, 0, 0, "fetch failed"))
+
+    assert(taskSetManager1.isZombie)
+    assert(taskSetManager1.runningTasks === 9)
+
+    val taskSet2 = FakeTask.createTaskSet(10, stageAttemptId = 1)
+    sched.submitTasks(taskSet2)
+    sched.resourceOffers(
+      (11 until 20).map { idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) })
+
+    // Complete the 2 tasks and leave 8 task in running
+    for (id <- Set(0, 1)) {
+      taskSetManager1.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id)))
+      assert(sched.endedTasks(id) === Success)
+    }
+
+    val taskSetManager2 = sched.taskSetManagerForAttempt(0, 1).get
+    assert(!taskSetManager2.successfulTaskDurations.isEmpty())
+    taskSetManager2.checkSpeculatableTasks(0)
+  }
+
   private def createTaskResult(
       id: Int,
       accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = {
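
A note on the failure mode this change addresses: TaskSetManager.checkSpeculatableTasks reads successfulTaskDurations.median once enough tasks of a task set have succeeded, but before this patch a partition completed by an earlier (zombie) attempt bumped tasksSuccessful in the newer attempt without recording any duration, so the newer attempt could query an empty MedianHeap and hit the NoSuchElementException from SPARK-24677. The following is a minimal sketch, not Spark code: SimpleMedianHeap is a hypothetical stand-in for org.apache.spark.util.collection.MedianHeap (which is private[spark]) and only illustrates why markPartitionCompleted now records taskInfo.duration.

import scala.collection.mutable

// Hypothetical stand-in for Spark's MedianHeap: lower half in a max-heap,
// upper half in a min-heap, median read from the heap heads.
class SimpleMedianHeap {
  private val smaller = mutable.PriorityQueue.empty[Double]                          // max-heap
  private val larger = mutable.PriorityQueue.empty[Double](Ordering[Double].reverse) // min-heap

  def isEmpty: Boolean = smaller.isEmpty && larger.isEmpty
  def size: Int = smaller.size + larger.size

  def insert(x: Double): Unit = {
    if (isEmpty || x < smaller.head) smaller.enqueue(x) else larger.enqueue(x)
    rebalance()
  }

  private def rebalance(): Unit = {
    if (smaller.size - larger.size > 1) larger.enqueue(smaller.dequeue())
    else if (larger.size - smaller.size > 1) smaller.enqueue(larger.dequeue())
  }

  def median: Double = {
    if (isEmpty) {
      // The exception class seen in SPARK-24677 when no duration was ever inserted.
      throw new NoSuchElementException("MedianHeap is empty.")
    }
    if (size % 2 == 1) {
      if (smaller.size > larger.size) smaller.head else larger.head
    } else {
      (smaller.head + larger.head) / 2.0
    }
  }
}

object Spark24677Sketch {
  def main(args: Array[String]): Unit = {
    val successfulTaskDurations = new SimpleMedianHeap

    // Before the patch: successes credited from the zombie attempt left the newer
    // attempt's heap empty, so a later speculation check blew up here.
    try {
      successfulTaskDurations.median
    } catch {
      case e: NoSuchElementException => println(s"speculation check would fail: $e")
    }

    // After the patch: markPartitionCompleted also inserts taskInfo.duration, so the
    // heap is non-empty whenever tasksSuccessful has been incremented.
    Seq(120.0, 95.0, 180.0).foreach(successfulTaskDurations.insert)
    println(s"median duration = ${successfulTaskDurations.median}")
  }
}

The new test exercises exactly this path: attempt 0 goes zombie on a fetch failure, its remaining successes are propagated to attempt 1 via markPartitionCompletedInAllTaskSets, and the assertion on successfulTaskDurations plus the checkSpeculatableTasks(0) call verify that attempt 1 no longer computes a median over an empty heap.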