Skip to content

Commit 08526a8

Browse files
Ensure uniqueness of TaskSetManager name.
1 parent 9540357 commit 08526a8

File tree

3 files changed

+25
-4
lines changed

3 files changed

+25
-4
lines changed

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ private[spark] class TaskSetManager(
7979
var minShare = 0
8080
var priority = taskSet.priority
8181
var stageId = taskSet.stageId
82-
var name = "TaskSet_" + taskSet.stageId.toString
82+
val name = "TaskSet_" + taskSet.id
8383
var parent: Pool = null
8484
var totalResultSize = 0L
8585
var calculatedTasks = 0

core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,17 @@ object FakeTask {
3737
}
3838

3939
def createTaskSet(numTasks: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*): TaskSet = {
40+
createTaskSet(numTasks, 0, stageAttemptId, prefLocs: _*)
41+
}
42+
43+
def createTaskSet(numTasks: Int, stageId: Int, stageAttemptId: Int, prefLocs: Seq[TaskLocation]*):
44+
TaskSet = {
4045
if (prefLocs.size != 0 && prefLocs.size != numTasks) {
4146
throw new IllegalArgumentException("Wrong number of task locations")
4247
}
4348
val tasks = Array.tabulate[Task[_]](numTasks) { i =>
44-
new FakeTask(0, i, if (prefLocs.size != 0) prefLocs(i) else Nil)
49+
new FakeTask(stageId, i, if (prefLocs.size != 0) prefLocs(i) else Nil)
4550
}
46-
new TaskSet(tasks, 0, stageAttemptId, 0, null)
51+
new TaskSet(tasks, stageId, stageAttemptId, 0, null)
4752
}
4853
}

core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
904904
task.index == index && !sched.endedTasks.contains(task.taskId)
905905
}.getOrElse {
906906
throw new RuntimeException(s"couldn't find index $index in " +
907-
s"tasks: ${tasks.map{t => t.index -> t.taskId}} with endedTasks:" +
907+
s"tasks: ${tasks.map { t => t.index -> t.taskId }} with endedTasks:" +
908908
s" ${sched.endedTasks.keys}")
909909
}
910910
}
@@ -974,6 +974,22 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
974974
assert(manager.isZombie)
975975
}
976976

977+
test("check uniqueness of TaskSetManager name") {
978+
sc = new SparkContext("local", "test")
979+
sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
980+
val taskSet = FakeTask.createTaskSet(1, 0, 0)
981+
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, new ManualClock)
982+
assert(manager.name === "TaskSet_0.0")
983+
984+
val taskSet2 = FakeTask.createTaskSet(1, 0, 1)
985+
val manager2 = new TaskSetManager(sched, taskSet2, MAX_TASK_FAILURES, new ManualClock)
986+
assert(manager2.name === "TaskSet_0.1")
987+
988+
val taskSet3 = FakeTask.createTaskSet(1, 1, 1)
989+
val manager3 = new TaskSetManager(sched, taskSet3, MAX_TASK_FAILURES, new ManualClock)
990+
assert(manager3.name === "TaskSet_1.1")
991+
}
992+
977993
private def createTaskResult(
978994
id: Int,
979995
accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = {

0 commit comments

Comments
 (0)