
Commit 9af3665

squito authored and markhamstra committed
[SPARK-8103][core] DAGScheduler should not submit multiple concurrent attempts for a stage
https://issues.apache.org/jira/browse/SPARK-8103

cc kayousterhout (thanks for the extra test case)

Author: Imran Rashid <[email protected]>
Author: Kay Ousterhout <[email protected]>
Author: Imran Rashid <[email protected]>

Closes apache#6750 from squito/SPARK-8103 and squashes the following commits:

fb3acfc [Imran Rashid] fix log msg
e01b7aa [Imran Rashid] fix some comments, style
584acd4 [Imran Rashid] simplify going from taskId to taskSetMgr
e43ac25 [Imran Rashid] Merge branch 'master' into SPARK-8103
6bc23af [Imran Rashid] update log msg
4470fa1 [Imran Rashid] rename
c04707e [Imran Rashid] style
88b61cc [Imran Rashid] add tests to make sure that TaskSchedulerImpl schedules correctly with zombie attempts
d7f1ef2 [Imran Rashid] get rid of activeTaskSets
a21c8b5 [Imran Rashid] Merge branch 'master' into SPARK-8103
906d626 [Imran Rashid] fix merge
109900e [Imran Rashid] Merge branch 'master' into SPARK-8103
c0d4d90 [Imran Rashid] Revert "Index active task sets by stage Id rather than by task set id"
f025154 [Imran Rashid] Merge pull request #2 from kayousterhout/imran_SPARK-8103
baf46e1 [Kay Ousterhout] Index active task sets by stage Id rather than by task set id
19685bb [Imran Rashid] switch to using latestInfo.attemptId, and add comments
a5f7c8c [Imran Rashid] remove comment for reviewers
227b40d [Imran Rashid] style
517b6e5 [Imran Rashid] get rid of SparkIllegalStateException
b2faef5 [Imran Rashid] faster check for conflicting task sets
6542b42 [Imran Rashid] remove extra stageAttemptId
ada7726 [Imran Rashid] reviewer feedback
d8eb202 [Imran Rashid] Merge branch 'master' into SPARK-8103
46bc26a [Imran Rashid] more cleanup of debug garbage
cb245da [Imran Rashid] finally found the issue ... clean up debug stuff
8c29707 [Imran Rashid] Merge branch 'master' into SPARK-8103
89a59b6 [Imran Rashid] more printlns ...
9601b47 [Imran Rashid] more debug printlns
ecb4e7d [Imran Rashid] debugging printlns
b6bc248 [Imran Rashid] style
55f4a94 [Imran Rashid] get rid of more random test case since kays tests are clearer
7021d28 [Imran Rashid] update test since listenerBus.waitUntilEmpty now throws an exception instead of returning a boolean
883fe49 [Kay Ousterhout] Unit tests for concurrent stages issue
6e14683 [Imran Rashid] unit test just to make sure we fail fast on concurrent attempts
06a0af6 [Imran Rashid] ignore for jenkins
c443def [Imran Rashid] better fix and simpler test case
28d70aa [Imran Rashid] wip on getting a better test case ...
a9bf31f [Imran Rashid] wip
1 parent fce746d commit 9af3665
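
The heart of the change is that TaskSchedulerImpl now indexes TaskSetManagers by stage ID and attempt ID, and fails fast if a stage ever has more than one non-zombie task set. A minimal, self-contained sketch of that invariant (FakeTaskSet and the surrounding object are simplified stand-ins, not Spark's real classes):

    // Scala sketch of the one-live-attempt-per-stage invariant enforced by this patch.
    import scala.collection.mutable.HashMap

    case class FakeTaskSet(stageId: Int, stageAttemptId: Int, var isZombie: Boolean = false)

    object OneActiveAttemptPerStage {
      private val byStageIdAndAttempt = new HashMap[Int, HashMap[Int, FakeTaskSet]]

      def submit(taskSet: FakeTaskSet): Unit = {
        val stageTaskSets =
          byStageIdAndAttempt.getOrElseUpdate(taskSet.stageId, new HashMap[Int, FakeTaskSet])
        stageTaskSets(taskSet.stageAttemptId) = taskSet
        // Mirrors the new check: any other non-zombie attempt for the stage is a bug.
        val conflicting = stageTaskSets.exists { case (_, ts) => (ts ne taskSet) && !ts.isZombie }
        if (conflicting) {
          throw new IllegalStateException(
            s"more than one active taskSet for stage ${taskSet.stageId}")
        }
      }

      def main(args: Array[String]): Unit = {
        val attempt0 = FakeTaskSet(stageId = 1, stageAttemptId = 0)
        submit(attempt0)                                      // ok: first live attempt
        attempt0.isZombie = true                              // a fetch failure zombifies attempt 0
        submit(FakeTaskSet(stageId = 1, stageAttemptId = 1))  // ok: attempt 0 is a zombie now
      }
    }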

File tree

13 files changed: +383, -86 lines

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 44 additions & 34 deletions
@@ -840,7 +840,6 @@ class DAGScheduler(
     // Get our pending tasks and remember them in our pendingTasks entry
     stage.pendingTasks.clear()
 
-
     // First figure out the indexes of partition ids to compute.
     val partitionsToCompute: Seq[Int] = {
       stage match {
@@ -901,7 +900,7 @@
         partitionsToCompute.map { id =>
           val locs = getPreferredLocs(stage.rdd, id)
           val part = stage.rdd.partitions(id)
-          new ShuffleMapTask(stage.id, taskBinary, part, locs)
+          new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
         }
 
       case stage: ResultStage =>
@@ -910,7 +909,7 @@
           val p: Int = job.partitions(id)
           val part = stage.rdd.partitions(p)
           val locs = getPreferredLocs(stage.rdd, p)
-          new ResultTask(stage.id, taskBinary, part, locs, id)
+          new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
         }
     }
   } catch {
@@ -1052,10 +1051,11 @@
         val execId = status.location.executorId
         logDebug("ShuffleMapTask finished on " + execId)
         if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
-          logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
+          logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
         } else {
           shuffleStage.addOutputLoc(smt.partitionId, status)
         }
+
         if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
           markStageAsFinished(shuffleStage)
           logInfo("looking for newly runnable stages")
@@ -1115,38 +1115,48 @@
         val failedStage = stageIdToStage(task.stageId)
         val mapStage = shuffleToMapStage(shuffleId)
 
-        // It is likely that we receive multiple FetchFailed for a single stage (because we have
-        // multiple tasks running concurrently on different executors). In that case, it is possible
-        // the fetch failure has already been handled by the scheduler.
-        if (runningStages.contains(failedStage)) {
-          logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
-            s"due to a fetch failure from $mapStage (${mapStage.name})")
-          markStageAsFinished(failedStage, Some(failureMessage))
-        }
+        if (failedStage.latestInfo.attemptId != task.stageAttemptId) {
+          logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
+            s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
+            s"(attempt ID ${failedStage.latestInfo.attemptId}) running")
+        } else {
 
-        if (disallowStageRetryForTest) {
-          abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
-        } else if (failedStages.isEmpty) {
-          // Don't schedule an event to resubmit failed stages if failed isn't empty, because
-          // in that case the event will already have been scheduled.
-          // TODO: Cancel running tasks in the stage
-          logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
-            s"$failedStage (${failedStage.name}) due to fetch failure")
-          messageScheduler.schedule(new Runnable {
-            override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
-          }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
-        }
-        failedStages += failedStage
-        failedStages += mapStage
-        // Mark the map whose fetch failed as broken in the map stage
-        if (mapId != -1) {
-          mapStage.removeOutputLoc(mapId, bmAddress)
-          mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
-        }
+          // It is likely that we receive multiple FetchFailed for a single stage (because we have
+          // multiple tasks running concurrently on different executors). In that case, it is
+          // possible the fetch failure has already been handled by the scheduler.
+          if (runningStages.contains(failedStage)) {
+            logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
+              s"due to a fetch failure from $mapStage (${mapStage.name})")
+            markStageAsFinished(failedStage, Some(failureMessage))
+          } else {
+            logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " +
+              s"longer running")
+          }
+
+          if (disallowStageRetryForTest) {
+            abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+          } else if (failedStages.isEmpty) {
+            // Don't schedule an event to resubmit failed stages if failed isn't empty, because
+            // in that case the event will already have been scheduled.
+            // TODO: Cancel running tasks in the stage
+            logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
+              s"$failedStage (${failedStage.name}) due to fetch failure")
+            messageScheduler.schedule(new Runnable {
+              override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+            }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
+          }
+          failedStages += failedStage
+          failedStages += mapStage
+          // Mark the map whose fetch failed as broken in the map stage
+          if (mapId != -1) {
+            mapStage.removeOutputLoc(mapId, bmAddress)
+            mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+          }
 
-        // TODO: mark the executor as failed only if there were lots of fetch failures on it
-        if (bmAddress != null) {
-          handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+          // TODO: mark the executor as failed only if there were lots of fetch failures on it
+          if (bmAddress != null) {
+            handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+          }
         }
 
       case commitDenied: TaskCommitDenied =>
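
The behavioral core of the hunk above is the new attempt-ID guard: a FetchFailed reported by an older attempt of a stage is logged and dropped instead of failing the stage again. A minimal sketch of just that predicate, using stand-in types (StageSnapshot and FetchFailed below are illustrative, not the real Spark classes):

    // Scala sketch, assuming simplified stand-in types.
    case class StageSnapshot(stageId: Int, latestAttemptId: Int)
    case class FetchFailed(stageId: Int, stageAttemptId: Int)

    // Only failures from the stage's most recent attempt are acted on; anything
    // older is assumed to have been handled (or superseded) already.
    def shouldHandleFetchFailure(stage: StageSnapshot, failure: FetchFailed): Boolean =
      stage.latestAttemptId == failure.stageAttemptId

    // shouldHandleFetchFailure(StageSnapshot(3, latestAttemptId = 2), FetchFailed(3, 1))  // false
    // shouldHandleFetchFailure(StageSnapshot(3, latestAttemptId = 2), FetchFailed(3, 2))  // true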

core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala

Lines changed: 2 additions & 1 deletion
@@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
  */
 private[spark] class ResultTask[T, U](
     stageId: Int,
+    stageAttemptId: Int,
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient locs: Seq[TaskLocation],
     val outputId: Int)
-  extends Task[U](stageId, partition.index) with Serializable {
+  extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {
 
   @transient private[this] val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq

core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala

Lines changed: 3 additions & 2 deletions
@@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter
  */
 private[spark] class ShuffleMapTask(
     stageId: Int,
+    stageAttemptId: Int,
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient private var locs: Seq[TaskLocation])
-  extends Task[MapStatus](stageId, partition.index) with Logging {
+  extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {
 
   /** A constructor used only in test suites. This does not require passing in an RDD. */
   def this(partitionId: Int) {
-    this(0, null, new Partition { override def index: Int = 0 }, null)
+    this(0, 0, null, new Partition { override def index: Int = 0 }, null)
   }
 
   @transient private val preferredLocs: Seq[TaskLocation] = {

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 4 additions & 1 deletion
@@ -43,7 +43,10 @@ import org.apache.spark.util.Utils
  * @param stageId id of the stage this task belongs to
  * @param partitionId index of the number in the RDD
  */
-private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
+private[spark] abstract class Task[T](
+    val stageId: Int,
+    val stageAttemptId: Int,
+    var partitionId: Int) extends Serializable {
 
   /**
    * Called by [[Executor]] to run this task.
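
Taken together, the three task-class hunks above only thread one new constructor parameter from the concrete task types into the Task base class. A compilable skeleton of the resulting shape (the Mini* names are illustrative; the real classes also carry task binaries, partitions, and locations):

    // Scala sketch of the constructor threading, with hypothetical Mini* names.
    abstract class MiniTask[T](
        val stageId: Int,
        val stageAttemptId: Int,  // new: which attempt of the stage created this task
        var partitionId: Int) extends Serializable

    class MiniShuffleMapTask(stageId: Int, stageAttemptId: Int, partitionId: Int)
      extends MiniTask[Unit](stageId, stageAttemptId, partitionId)

    class MiniResultTask[U](stageId: Int, stageAttemptId: Int, partitionId: Int, val outputId: Int)
      extends MiniTask[U](stageId, stageAttemptId, partitionId)

The DAGScheduler stamps each task with stage.latestInfo.attemptId at submission time, so a late status update can be traced back to the attempt that produced it.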

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 64 additions & 35 deletions
@@ -72,9 +72,9 @@ private[spark] class TaskSchedulerImpl(
 
   // TaskSetManagers are not thread safe, so any access to one should be synchronized
   // on this class.
-  val activeTaskSets = new HashMap[String, TaskSetManager]
+  private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
 
-  val taskIdToTaskSetId = new HashMap[Long, String]
+  private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
   val taskIdToExecutorId = new HashMap[Long, String]
 
   @volatile private var hasReceivedTask = false
@@ -158,7 +158,17 @@
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
     this.synchronized {
       val manager = createTaskSetManager(taskSet, maxTaskFailures)
-      activeTaskSets(taskSet.id) = manager
+      val stage = taskSet.stageId
+      val stageTaskSets =
+        taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
+      stageTaskSets(taskSet.stageAttemptId) = manager
+      val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
+        ts.taskSet != taskSet && !ts.isZombie
+      }
+      if (conflictingTaskSet) {
+        throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
+          s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
+      }
       schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
 
       if (!isLocal && !hasReceivedTask) {
@@ -188,19 +198,21 @@
 
   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
-    activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
-      // There are two possible cases here:
-      // 1. The task set manager has been created and some tasks have been scheduled.
-      //    In this case, send a kill signal to the executors to kill the task and then abort
-      //    the stage.
-      // 2. The task set manager has been created but no tasks has been scheduled. In this case,
-      //    simply abort the stage.
-      tsm.runningTasksSet.foreach { tid =>
-        val execId = taskIdToExecutorId(tid)
-        backend.killTask(tid, execId, interruptThread)
+    taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
+      attempts.foreach { case (_, tsm) =>
+        // There are two possible cases here:
+        // 1. The task set manager has been created and some tasks have been scheduled.
+        //    In this case, send a kill signal to the executors to kill the task and then abort
+        //    the stage.
+        // 2. The task set manager has been created but no tasks has been scheduled. In this case,
+        //    simply abort the stage.
+        tsm.runningTasksSet.foreach { tid =>
+          val execId = taskIdToExecutorId(tid)
+          backend.killTask(tid, execId, interruptThread)
+        }
+        tsm.abort("Stage %s cancelled".format(stageId))
+        logInfo("Stage %d was cancelled".format(stageId))
       }
-      tsm.abort("Stage %s cancelled".format(stageId))
-      logInfo("Stage %d was cancelled".format(stageId))
     }
   }
 
@@ -210,7 +222,12 @@
    * cleaned up.
    */
   def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
-    activeTaskSets -= manager.taskSet.id
+    taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
+      taskSetsForStage -= manager.taskSet.stageAttemptId
+      if (taskSetsForStage.isEmpty) {
+        taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
+      }
+    }
     manager.parent.removeSchedulable(manager)
     logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
       .format(manager.taskSet.id, manager.parent.name))
@@ -231,7 +248,7 @@
       for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
         tasks(i) += task
         val tid = task.taskId
-        taskIdToTaskSetId(tid) = taskSet.taskSet.id
+        taskIdToTaskSetManager(tid) = taskSet
         taskIdToExecutorId(tid) = execId
         executorsByHost(host) += execId
         availableCpus(i) -= CPUS_PER_TASK
@@ -315,26 +332,24 @@
             failedExecutor = Some(execId)
           }
         }
-        taskIdToTaskSetId.get(tid) match {
-          case Some(taskSetId) =>
+        taskIdToTaskSetManager.get(tid) match {
+          case Some(taskSet) =>
             if (TaskState.isFinished(state)) {
-              taskIdToTaskSetId.remove(tid)
+              taskIdToTaskSetManager.remove(tid)
               taskIdToExecutorId.remove(tid)
             }
-            activeTaskSets.get(taskSetId).foreach { taskSet =>
-              if (state == TaskState.FINISHED) {
-                taskSet.removeRunningTask(tid)
-                taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
-              } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
-                taskSet.removeRunningTask(tid)
-                taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
-              }
+            if (state == TaskState.FINISHED) {
+              taskSet.removeRunningTask(tid)
+              taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+            } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+              taskSet.removeRunningTask(tid)
+              taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
             }
           case None =>
             logError(
              ("Ignoring update with state %s for TID %s because its task set is gone (this is " +
-                "likely the result of receiving duplicate task finished status updates)")
-                .format(state, tid))
+              "likely the result of receiving duplicate task finished status updates)")
+              .format(state, tid))
         }
       } catch {
        case e: Exception => logError("Exception in statusUpdate", e)
@@ -359,9 +374,9 @@
 
     val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
       taskMetrics.flatMap { case (id, metrics) =>
-        taskIdToTaskSetId.get(id)
-          .flatMap(activeTaskSets.get)
-          .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
+        taskIdToTaskSetManager.get(id).map { taskSetMgr =>
+          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
+        }
       }
     }
     dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
@@ -393,9 +408,12 @@
 
   def error(message: String) {
     synchronized {
-      if (activeTaskSets.nonEmpty) {
+      if (taskSetsByStageIdAndAttempt.nonEmpty) {
         // Have each task set throw a SparkException with the error
-        for ((taskSetId, manager) <- activeTaskSets) {
+        for {
+          attempts <- taskSetsByStageIdAndAttempt.values
+          manager <- attempts.values
+        } {
           try {
             manager.abort(message)
           } catch {
@@ -515,6 +533,17 @@
 
   override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()
 
+  private[scheduler] def taskSetManagerForAttempt(
+      stageId: Int,
+      stageAttemptId: Int): Option[TaskSetManager] = {
+    for {
+      attempts <- taskSetsByStageIdAndAttempt.get(stageId)
+      manager <- attempts.get(stageAttemptId)
+    } yield {
+      manager
+    }
+  }
+
 }
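
With the flat activeTaskSets map gone, lookups go either through taskIdToTaskSetManager (one hop) or through the nested stage/attempt map (two hops). A small sketch of the two-hop lookup that the new taskSetManagerForAttempt performs, with an immutable Map standing in for the scheduler's mutable state:

    // Scala sketch; the generic M stands in for TaskSetManager.
    def managerForAttempt[M](
        byStageIdAndAttempt: Map[Int, Map[Int, M]],
        stageId: Int,
        stageAttemptId: Int): Option[M] =
      for {
        attempts <- byStageIdAndAttempt.get(stageId)   // hop 1: stage id
        manager <- attempts.get(stageAttemptId)        // hop 2: stage attempt id
      } yield manager

    // managerForAttempt(Map(0 -> Map(0 -> "tsm-0.0", 1 -> "tsm-0.1")), 0, 1) == Some("tsm-0.1")
    // managerForAttempt(Map(0 -> Map(0 -> "tsm-0.0")), 0, 1)                 == None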

core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala

Lines changed: 2 additions & 2 deletions
@@ -26,10 +26,10 @@ import java.util.Properties
 private[spark] class TaskSet(
     val tasks: Array[Task[_]],
     val stageId: Int,
-    val attempt: Int,
+    val stageAttemptId: Int,
     val priority: Int,
     val properties: Properties) {
-  val id: String = stageId + "." + attempt
+  val id: String = stageId + "." + stageAttemptId
 
   override def toString: String = "TaskSet " + id
 }
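
The attempt-to-stageAttemptId rename above is purely cosmetic; the id format is unchanged. For instance:

    val stageId = 4
    val stageAttemptId = 1
    val id: String = stageId + "." + stageAttemptId  // "4.1", same format as before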

core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala

Lines changed: 2 additions & 3 deletions
@@ -194,15 +194,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
       val ser = SparkEnv.get.closureSerializer.newInstance()
       val serializedTask = ser.serialize(task)
       if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
-        val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
-        scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
+        scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
           try {
             var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
               "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
               "spark.akka.frameSize or using broadcast variables for large values."
             msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
               AkkaUtils.reservedSizeBytes)
-            taskSet.abort(msg)
+            taskSetMgr.abort(msg)
           } catch {
             case e: Exception => logError("Exception in error callback", e)
           }
