diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 9491bc7a0497e..b766e4148e496 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -79,7 +79,7 @@ private[spark] class TaskSetManager( var minShare = 0 var priority = taskSet.priority var stageId = taskSet.stageId - var name = "TaskSet_" + taskSet.stageId.toString + val name = "TaskSet_" + taskSet.id var parent: Pool = null var totalResultSize = 0L var calculatedTasks = 0 diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index 87600fe504b98..f395fe9804c91 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -22,7 +22,7 @@ import org.apache.spark.TaskContext class FakeTask( stageId: Int, partitionId: Int, - prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, partitionId) { + prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, stageAttemptId = 0, partitionId) { override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs } @@ -33,16 +33,21 @@ object FakeTask { * locations for each task (given as varargs) if this sequence is not empty. */ def createTaskSet(numTasks: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { - createTaskSet(numTasks, 0, prefLocs: _*) + createTaskSet(numTasks, stageAttemptId = 0, prefLocs: _*) } def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = { + createTaskSet(numTasks, stageId = 0, stageAttemptId, prefLocs: _*) + } + + def createTaskSet(numTasks: Int, stageId: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): + TaskSet = { if (prefLocs.size != 0 && prefLocs.size != numTasks) { throw new IllegalArgumentException("Wrong number of task locations") } val tasks = Array.tabulate[Task[_]](numTasks) { i => - new FakeTask(0, i, if (prefLocs.size != 0) prefLocs(i) else Nil) + new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil) } - new TaskSet(tasks, 0, stageAttemptId, 0, null) + new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 69edcf3347243..43e1f1a0e6ea8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -904,7 +904,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg task.index == index && !sched.endedTasks.contains(task.taskId) }.getOrElse { throw new RuntimeException(s"couldn't find index $index in " + - s"tasks: ${tasks.map{t => t.index -> t.taskId}} with endedTasks:" + + s"tasks: ${tasks.map { t => t.index -> t.taskId }} with endedTasks:" + s" ${sched.endedTasks.keys}") } } @@ -974,6 +974,24 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(manager.isZombie) } + test("SPARK-17894: Verify TaskSetManagers for different stage attempts have unique names") { + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc, ("exec1", "host1")) + val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, new ManualClock) + assert(manager.name === "TaskSet_0.0") + + // Make sure a task set with the same stage ID but different attempt ID also has a unique name + val taskSet2 = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 1) + val manager2 = new TaskSetManager(sched, taskSet2, MAX_TASK_FAILURES, new ManualClock) + assert(manager2.name === "TaskSet_0.1") + + // Make sure a task set with the same attempt ID but different stage ID also has a unique name + val taskSet3 = FakeTask.createTaskSet(numTasks = 1, stageId = 1, stageAttemptId = 1) + val manager3 = new TaskSetManager(sched, taskSet3, MAX_TASK_FAILURES, new ManualClock) + assert(manager3.name === "TaskSet_1.1") + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = {