Skip to content

Commit f132194

Browse files
committed
add stageIdToFinishedPartitions
1 parent 927081d commit f132194

File tree

2 files changed

+88
-3
lines changed

2 files changed

+88
-3
lines changed

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,9 @@ private[spark] class TaskSchedulerImpl(
9696
private[scheduler] val taskIdToTaskSetManager = new ConcurrentHashMap[Long, TaskSetManager]
9797
val taskIdToExecutorId = new HashMap[Long, String]
9898

99+
// Protected by `this`
100+
private[scheduler] val stageIdToFinishedPartitions = new HashMap[Int, HashSet[Int]]
101+
99102
@volatile private var hasReceivedTask = false
100103
@volatile private var hasLaunchedTask = false
101104
private val starvationTimer = new Timer(true)
@@ -208,7 +211,7 @@ private[spark] class TaskSchedulerImpl(
208211
taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
209212
stageTaskSets(taskSet.stageAttemptId) = manager
210213
val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
211-
ts.taskSet != taskSet && !ts.isZombie
214+
ts.taskSet != manager.taskSet && !ts.isZombie
212215
}
213216
if (conflictingTaskSet) {
214217
throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
@@ -238,7 +241,13 @@ private[spark] class TaskSchedulerImpl(
238241
private[scheduler] def createTaskSetManager(
239242
taskSet: TaskSet,
240243
maxTaskFailures: Int): TaskSetManager = {
241-
new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt)
244+
val finishedPartitions =
245+
stageIdToFinishedPartitions.getOrElseUpdate(taskSet.stageId, new HashSet[Int])
246+
// filter the task which has been finished by previous attempts
247+
val tasks = taskSet.tasks.filterNot{ t => finishedPartitions(t.partitionId) }
248+
val ts = new TaskSet(
249+
tasks, taskSet.stageId, taskSet.stageAttemptId, taskSet.priority, taskSet.properties)
250+
new TaskSetManager(this, ts, maxTaskFailures, blacklistTrackerOpt)
242251
}
243252

244253
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
@@ -297,6 +306,7 @@ private[spark] class TaskSchedulerImpl(
297306
taskSetsForStage -= manager.taskSet.stageAttemptId
298307
if (taskSetsForStage.isEmpty) {
299308
taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
309+
stageIdToFinishedPartitions -= manager.taskSet.stageId
300310
}
301311
}
302312
manager.parent.removeSchedulable(manager)
@@ -837,17 +847,31 @@ private[spark] class TaskSchedulerImpl(
837847
}
838848

839849
/**
840-
* Marks the task has completed in all TaskSetManagers for the given stage.
850+
* Marks the task has completed in all TaskSetManagers(active / zombie) for the given stage.
841851
*
842852
* After stage failure and retry, there may be multiple TaskSetManagers for the stage.
843853
* If an earlier attempt of a stage completes a task, we should ensure that the later attempts
844854
* do not also submit those same tasks. That also means that a task completion from an earlier
845855
* attempt can lead to the entire stage getting marked as successful.
856+
* And there's a situation that the active TaskSetManager corresponding to the stage may
857+
* haven't been created at the time we call this method. And it is possible since the behaviour
858+
* of calling on this method and creating active TaskSetManager is from two different threads,
859+
* which are "task-result-getter" and "dag-scheduler-event-loop" separately. Consequently, under
860+
* this situation, the active TaskSetManager which is created later could not learn about the
861+
* finished partitions and keep on launching duplicate tasks, which may lead to job fail for some
862+
* severe cases, see SPARK-25250 for details. So, to avoid the problem, we record the finished
863+
* partitions for that stage here and exclude the already finished tasks when we creating active
864+
* TaskSetManagers later by looking into stageIdToFinishedPartitions. Thus, active TaskSetManager
865+
* could be always notified about the finished partitions whether it has been created or not at
866+
* the time we call this method.
846867
*/
847868
private[scheduler] def markPartitionCompletedInAllTaskSets(
848869
stageId: Int,
849870
partitionId: Int,
850871
taskInfo: TaskInfo) = {
872+
val finishedPartitions =
873+
stageIdToFinishedPartitions.getOrElseUpdate(stageId, new HashSet[Int])
874+
finishedPartitions += partitionId
851875
taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm =>
852876
tsm.markPartitionCompleted(partitionId, taskInfo)
853877
}

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1206,6 +1206,67 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
12061206
}
12071207
}
12081208

1209+
test("successful tasks from previous attempts could be learnt by later active taskset") {
1210+
val taskScheduler = setupSchedulerWithMockTaskSetBlacklist()
1211+
val valueSer = SparkEnv.get.serializer.newInstance()
1212+
val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq())
1213+
1214+
// submit a taskset with 10 tasks to taskScheduler
1215+
val attempt0 = FakeTask.createTaskSet(10, stageId = 0, stageAttemptId = 0)
1216+
taskScheduler.submitTasks(attempt0)
1217+
// get the current active tsm
1218+
val tsm0 = taskScheduler.taskSetManagerForAttempt(0, 0).get
1219+
// offer sufficient resources
1220+
val offers0 = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
1221+
taskScheduler.resourceOffers(offers0)
1222+
assert(tsm0.runningTasks === 10)
1223+
// fail task 0.0 and mark tsm0 as zombie
1224+
tsm0.handleFailedTask(tsm0.taskAttempts(0)(0).taskId, TaskState.FAILED,
1225+
FetchFailed(null, 0, 0, 0, "fetch failed"))
1226+
// the attempt0 is a zombie, but the tasks are still running (this could be true even if
1227+
// we actively killed those tasks, as killing is best-effort)
1228+
assert(tsm0.isZombie)
1229+
assert(tsm0.runningTasks === 9)
1230+
1231+
1232+
// success task 1.0 , finish partition 1. But now,
1233+
// no active tsm exists in TaskScheduler for stage0.
1234+
tsm0.handleSuccessfulTask(tsm0.taskAttempts(1)(0).taskId, result)
1235+
assert(tsm0.runningTasks === 8)
1236+
assert(taskScheduler.stageIdToFinishedPartitions(0).contains(1))
1237+
1238+
// submit a new taskset with 10 tasks after someone previous task attempt succeed
1239+
val attempt1 = FakeTask.createTaskSet(10, stageId = 0, stageAttemptId = 1)
1240+
taskScheduler.submitTasks(attempt1)
1241+
// get the current active tsm
1242+
val tsm1 = taskScheduler.taskSetManagerForAttempt(0, 1).get
1243+
// tsm1 learns about the finished partition 1 during constructing, so it only need
1244+
// to execute other 9 tasks
1245+
assert(tsm1.taskSet.tasks.length == 9)
1246+
// offer one resource
1247+
val offers1 = (10 until 11).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
1248+
taskScheduler.resourceOffers(offers1)
1249+
assert(tsm1.runningTasks === 1)
1250+
// success task 0.0 in tsm1 and finish partition 0
1251+
tsm1.handleSuccessfulTask(tsm1.taskAttempts(0)(0).taskId, result)
1252+
assert(taskScheduler.stageIdToFinishedPartitions(0).contains(0))
1253+
1254+
1255+
val runningTasks = tsm0.taskSet.tasks.filterNot{ t =>
1256+
taskScheduler.stageIdToFinishedPartitions(0).contains(t.partitionId)
1257+
}
1258+
// finish tsm1 by previous task attempts from tsm0, this remains same behavior with SPARK-23433
1259+
runningTasks.foreach{ t =>
1260+
val attempt = tsm0.taskAttempts(tsm0.partitionToIndex(t.partitionId)).head
1261+
tsm0.handleSuccessfulTask(attempt.taskId, result)
1262+
}
1263+
1264+
assert(taskScheduler.taskSetManagerForAttempt(0, 0).isEmpty)
1265+
assert(taskScheduler.taskSetManagerForAttempt(0, 1).isEmpty)
1266+
assert(taskScheduler.stageIdToFinishedPartitions.isEmpty)
1267+
1268+
}
1269+
12091270
test("don't schedule for a barrier taskSet if available slots are less than pending tasks") {
12101271
val taskCpus = 2
12111272
val taskScheduler = setupScheduler(config.CPUS_PER_TASK.key -> taskCpus.toString)

0 commit comments

Comments
 (0)