Commits (38)
a9bf31f  wip (squito, Jun 9, 2015)
28d70aa  wip on getting a better test case ... (squito, Jun 9, 2015)
c443def  better fix and simpler test case (squito, Jun 10, 2015)
06a0af6  ignore for jenkins (squito, Jun 10, 2015)
6e14683  unit test just to make sure we fail fast on concurrent attempts (squito, Jun 10, 2015)
883fe49  Unit tests for concurrent stages issue (kayousterhout, May 26, 2015)
7021d28  update test since listenerBus.waitUntilEmpty now throws an exception … (squito, Jun 10, 2015)
55f4a94  get rid of more random test case since kays tests are clearer (squito, Jun 10, 2015)
b6bc248  style (squito, Jun 10, 2015)
ecb4e7d  debugging printlns (squito, Jun 11, 2015)
9601b47  more debug printlns (squito, Jun 11, 2015)
89a59b6  more printlns ... (squito, Jun 11, 2015)
8c29707  Merge branch 'master' into SPARK-8103 (squito, Jun 11, 2015)
cb245da  finally found the issue ... clean up debug stuff (squito, Jun 11, 2015)
46bc26a  more cleanup of debug garbage (squito, Jun 11, 2015)
d8eb202  Merge branch 'master' into SPARK-8103 (squito, Jun 29, 2015)
ada7726  reviewer feedback (squito, Jul 1, 2015)
6542b42  remove extra stageAttemptId (squito, Jul 2, 2015)
b2faef5  faster check for conflicting task sets (squito, Jul 7, 2015)
517b6e5  get rid of SparkIllegalStateException (squito, Jul 7, 2015)
227b40d  style (squito, Jul 7, 2015)
a5f7c8c  remove comment for reviewers (squito, Jul 7, 2015)
19685bb  switch to using latestInfo.attemptId, and add comments (squito, Jul 7, 2015)
baf46e1  Index active task sets by stage Id rather than by task set id (kayousterhout, Jul 8, 2015)
f025154  Merge pull request #2 from kayousterhout/imran_SPARK-8103 (squito, Jul 14, 2015)
c0d4d90  Revert "Index active task sets by stage Id rather than by task set id" (squito, Jul 14, 2015)
109900e  Merge branch 'master' into SPARK-8103 (squito, Jul 14, 2015)
906d626  fix merge (squito, Jul 14, 2015)
a21c8b5  Merge branch 'master' into SPARK-8103 (squito, Jul 14, 2015)
d7f1ef2  get rid of activeTaskSets (squito, Jul 15, 2015)
88b61cc  add tests to make sure that TaskSchedulerImpl schedules correctly wit… (squito, Jul 15, 2015)
c04707e  style (squito, Jul 15, 2015)
4470fa1  rename (squito, Jul 15, 2015)
6bc23af  update log msg (squito, Jul 17, 2015)
e43ac25  Merge branch 'master' into SPARK-8103 (squito, Jul 17, 2015)
584acd4  simplify going from taskId to taskSetMgr (squito, Jul 17, 2015)
e01b7aa  fix some comments, style (squito, Jul 17, 2015)
fb3acfc  fix log msg (squito, Jul 17, 2015)
78 changes: 44 additions & 34 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -853,7 +853,6 @@ class DAGScheduler(
     // Get our pending tasks and remember them in our pendingTasks entry
     stage.pendingTasks.clear()
 
-
     // First figure out the indexes of partition ids to compute.
     val partitionsToCompute: Seq[Int] = {
       stage match {
@@ -914,7 +913,7 @@
         partitionsToCompute.map { id =>
           val locs = getPreferredLocs(stage.rdd, id)
           val part = stage.rdd.partitions(id)
-          new ShuffleMapTask(stage.id, taskBinary, part, locs)
+          new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs)
         }
 
       case stage: ResultStage =>
@@ -923,7 +922,7 @@
           val p: Int = job.partitions(id)
           val part = stage.rdd.partitions(p)
           val locs = getPreferredLocs(stage.rdd, p)
-          new ResultTask(stage.id, taskBinary, part, locs, id)
+          new ResultTask(stage.id, stage.latestInfo.attemptId, taskBinary, part, locs, id)
         }
       }
     } catch {
@@ -1065,10 +1064,11 @@
             val execId = status.location.executorId
             logDebug("ShuffleMapTask finished on " + execId)
             if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
-              logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
+              logInfo(s"Ignoring possibly bogus $smt completion from executor $execId")
             } else {
               shuffleStage.addOutputLoc(smt.partitionId, status)
             }
+
             if (runningStages.contains(shuffleStage) && shuffleStage.pendingTasks.isEmpty) {
               markStageAsFinished(shuffleStage)
               logInfo("looking for newly runnable stages")
@@ -1128,38 +1128,48 @@
           val failedStage = stageIdToStage(task.stageId)
           val mapStage = shuffleToMapStage(shuffleId)
 
-          // It is likely that we receive multiple FetchFailed for a single stage (because we have
-          // multiple tasks running concurrently on different executors). In that case, it is possible
-          // the fetch failure has already been handled by the scheduler.
-          if (runningStages.contains(failedStage)) {
-            logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
-              s"due to a fetch failure from $mapStage (${mapStage.name})")
-            markStageAsFinished(failedStage, Some(failureMessage))
-          }
+          if (failedStage.latestInfo.attemptId != task.stageAttemptId) {
+            logInfo(s"Ignoring fetch failure from $task as it's from $failedStage attempt" +
+              s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
+              s"(attempt ID ${failedStage.latestInfo.attemptId}) running")
+          } else {
 
-          if (disallowStageRetryForTest) {
-            abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
-          } else if (failedStages.isEmpty) {
-            // Don't schedule an event to resubmit failed stages if failed isn't empty, because
-            // in that case the event will already have been scheduled.
-            // TODO: Cancel running tasks in the stage
-            logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
-              s"$failedStage (${failedStage.name}) due to fetch failure")
-            messageScheduler.schedule(new Runnable {
-              override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
-            }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
-          }
-          failedStages += failedStage
-          failedStages += mapStage
-          // Mark the map whose fetch failed as broken in the map stage
-          if (mapId != -1) {
-            mapStage.removeOutputLoc(mapId, bmAddress)
-            mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
-          }
+            // It is likely that we receive multiple FetchFailed for a single stage (because we have
+            // multiple tasks running concurrently on different executors). In that case, it is
+            // possible the fetch failure has already been handled by the scheduler.
+            if (runningStages.contains(failedStage)) {
+              logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
+                s"due to a fetch failure from $mapStage (${mapStage.name})")
+              markStageAsFinished(failedStage, Some(failureMessage))
+            } else {
+              logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " +
+                s"longer running")
+            }
+
+            if (disallowStageRetryForTest) {
+              abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+            } else if (failedStages.isEmpty) {
+              // Don't schedule an event to resubmit failed stages if failed isn't empty, because
+              // in that case the event will already have been scheduled.
+              // TODO: Cancel running tasks in the stage
+              logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
+                s"$failedStage (${failedStage.name}) due to fetch failure")
+              messageScheduler.schedule(new Runnable {
+                override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+              }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
+            }
+            failedStages += failedStage
+            failedStages += mapStage
+            // Mark the map whose fetch failed as broken in the map stage
+            if (mapId != -1) {
+              mapStage.removeOutputLoc(mapId, bmAddress)
+              mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
+            }
 
-          // TODO: mark the executor as failed only if there were lots of fetch failures on it
-          if (bmAddress != null) {
-            handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+            // TODO: mark the executor as failed only if there were lots of fetch failures on it
+            if (bmAddress != null) {
+              handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
+            }
           }
 
         case commitDenied: TaskCommitDenied =>
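For reference, here is a minimal, self-contained sketch of the attempt-id guard the hunk above introduces. Stage, StageInfo, and TaskRef are simplified stand-ins for illustration, not the real Spark classes:

object AttemptGuardSketch {
  case class StageInfo(attemptId: Int)
  case class Stage(id: Int, latestInfo: StageInfo)
  case class TaskRef(stageId: Int, stageAttemptId: Int)

  // A fetch failure only triggers resubmission when it comes from the stage's
  // newest attempt; stragglers from superseded attempts are logged and ignored.
  def shouldHandleFetchFailure(failedStage: Stage, task: TaskRef): Boolean =
    failedStage.latestInfo.attemptId == task.stageAttemptId

  def main(args: Array[String]): Unit = {
    val stage = Stage(id = 1, latestInfo = StageInfo(attemptId = 2))
    println(shouldHandleFetchFailure(stage, TaskRef(1, 2))) // true: current attempt
    println(shouldHandleFetchFailure(stage, TaskRef(1, 1))) // false: stale attempt
  }
}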
core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -41,11 +41,12 @@ import org.apache.spark.rdd.RDD
  */
 private[spark] class ResultTask[T, U](
     stageId: Int,
+    stageAttemptId: Int,
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient locs: Seq[TaskLocation],
     val outputId: Int)
-  extends Task[U](stageId, partition.index) with Serializable {
+  extends Task[U](stageId, stageAttemptId, partition.index) with Serializable {
 
   @transient private[this] val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -40,14 +40,15 @@ import org.apache.spark.shuffle.ShuffleWriter
  */
 private[spark] class ShuffleMapTask(
     stageId: Int,
+    stageAttemptId: Int,
     taskBinary: Broadcast[Array[Byte]],
     partition: Partition,
     @transient private var locs: Seq[TaskLocation])
-  extends Task[MapStatus](stageId, partition.index) with Logging {
+  extends Task[MapStatus](stageId, stageAttemptId, partition.index) with Logging {
 
   /** A constructor used only in test suites. This does not require passing in an RDD. */
   def this(partitionId: Int) {
-    this(0, null, new Partition { override def index: Int = 0 }, null)
+    this(0, 0, null, new Partition { override def index: Int = 0 }, null)
   }
 
   @transient private val preferredLocs: Seq[TaskLocation] = {
5 changes: 4 additions & 1 deletion core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -43,7 +43,10 @@ import org.apache.spark.util.Utils
  * @param stageId id of the stage this task belongs to
  * @param partitionId index of the number in the RDD
  */
-private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
+private[spark] abstract class Task[T](
+    val stageId: Int,
+    val stageAttemptId: Int,
+    var partitionId: Int) extends Serializable {
 
   /**
    * Called by [[Executor]] to run this task.
core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -75,9 +75,9 @@ private[spark] class TaskSchedulerImpl(

   // TaskSetManagers are not thread safe, so any access to one should be synchronized
   // on this class.
-  val activeTaskSets = new HashMap[String, TaskSetManager]
+  private val taskSetsByStageIdAndAttempt = new HashMap[Int, HashMap[Int, TaskSetManager]]
 
-  val taskIdToTaskSetId = new HashMap[Long, String]
+  private[scheduler] val taskIdToTaskSetManager = new HashMap[Long, TaskSetManager]
   val taskIdToExecutorId = new HashMap[Long, String]
 
   @volatile private var hasReceivedTask = false
@@ -162,7 +162,17 @@
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = createTaskSetManager(taskSet, maxTaskFailures)
activeTaskSets(taskSet.id) = manager
val stage = taskSet.stageId
val stageTaskSets =
taskSetsByStageIdAndAttempt.getOrElseUpdate(stage, new HashMap[Int, TaskSetManager])
stageTaskSets(taskSet.stageAttemptId) = manager
val conflictingTaskSet = stageTaskSets.exists { case (_, ts) =>
ts.taskSet != taskSet && !ts.isZombie
}
if (conflictingTaskSet) {
throw new IllegalStateException(s"more than one active taskSet for stage $stage:" +
s" ${stageTaskSets.toSeq.map{_._2.taskSet.id}.mkString(",")}")
}
Review thread on the conflicting-task-set check:

Contributor: @kayousterhout How much of a concern should the extra overhead be here? Just wondering whether this (let's hope rare) condition might better be handled only in a non-production environment and behind an if(debug) kind of flag.

Contributor: Perhaps the code could just look for an existing task set that matches the stage ID of the task set being added? That should be a little better than the filter / groupBy. Something like:

    activeTaskSets.exists { case (_, ts) => ts.stageId == taskSet.stageId }

Contributor (Author): Restoring the comments from the old diff b/c they are still relevant.

From Mark: "How much of a concern should the extra overhead be here? Just wondering whether this (let's hope rare) condition might better be handled only in a non-production environment and behind an if(debug) kind of flag."

From Marcelo: "Perhaps the code could just look for an existing task set that matches the stage ID of the task set being added? That should be a little better than the filter / groupBy."

Good point, there isn't any need to do the groupBy, so I've made it simpler. I'd really rather leave the check in place; in fact, I think this fail-fast behavior is especially important in a production environment -- that's much better than an infinite loop of failures.

Contributor: Yes, I totally agree that it is valuable to catch the failure in production. I'm only suggesting that at some point the check becomes a big enough performance hit that it makes sense to compromise on the fail-fast desideratum in order to maintain production performance, while trying to ensure in development that the failure can never occur. I doubt that this check is that costly, but my expectation is that Kay has a better sense of how much more we can afford to do within this synchronized block.

Contributor: What if you change activeTaskSets to be keyed on the stageId? That seems pretty easy to do and would make this check O(1) rather than O(N).

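The diff above adopts what this thread converges on: indexing active task sets by stage id and attempt. A small, self-contained sketch of that check, using toy TaskSet and TaskSetManager stand-ins rather than the real Spark classes, shows why the scan is cheap -- only the one stage's attempts are examined when a new attempt registers:

import scala.collection.mutable

object ConflictCheckSketch {
  case class TaskSet(stageId: Int, stageAttemptId: Int)
  class TaskSetManager(val taskSet: TaskSet, var isZombie: Boolean = false)

  private val taskSetsByStageIdAndAttempt =
    new mutable.HashMap[Int, mutable.HashMap[Int, TaskSetManager]]

  def submit(taskSet: TaskSet): TaskSetManager = {
    val manager = new TaskSetManager(taskSet)
    val stageTaskSets = taskSetsByStageIdAndAttempt
      .getOrElseUpdate(taskSet.stageId, new mutable.HashMap[Int, TaskSetManager])
    stageTaskSets(taskSet.stageAttemptId) = manager
    // Only this stage's attempts are scanned; the check no longer walks every
    // active task set in the scheduler.
    val conflicting = stageTaskSets.values.exists(ts => ts.taskSet != taskSet && !ts.isZombie)
    if (conflicting) {
      throw new IllegalStateException(
        s"more than one active taskSet for stage ${taskSet.stageId}")
    }
    manager
  }

  def main(args: Array[String]): Unit = {
    submit(TaskSet(0, 0))     // first attempt registers fine
    try submit(TaskSet(0, 1)) // a second non-zombie attempt fails fast
    catch { case e: IllegalStateException => println(e.getMessage) }
  }
}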
       schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
 
       if (!isLocal && !hasReceivedTask) {
@@ -192,19 +202,21 @@

   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
-    activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
-      // There are two possible cases here:
-      // 1. The task set manager has been created and some tasks have been scheduled.
-      //    In this case, send a kill signal to the executors to kill the task and then abort
-      //    the stage.
-      // 2. The task set manager has been created but no tasks has been scheduled. In this case,
-      //    simply abort the stage.
-      tsm.runningTasksSet.foreach { tid =>
-        val execId = taskIdToExecutorId(tid)
-        backend.killTask(tid, execId, interruptThread)
+    taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
+      attempts.foreach { case (_, tsm) =>
+        // There are two possible cases here:
+        // 1. The task set manager has been created and some tasks have been scheduled.
+        //    In this case, send a kill signal to the executors to kill the task and then abort
+        //    the stage.
+        // 2. The task set manager has been created but no tasks has been scheduled. In this case,
+        //    simply abort the stage.
+        tsm.runningTasksSet.foreach { tid =>
+          val execId = taskIdToExecutorId(tid)
+          backend.killTask(tid, execId, interruptThread)
+        }
+        tsm.abort("Stage %s cancelled".format(stageId))
+        logInfo("Stage %d was cancelled".format(stageId))
       }
-      tsm.abort("Stage %s cancelled".format(stageId))
-      logInfo("Stage %d was cancelled".format(stageId))
     }
   }
 
@@ -214,7 +226,12 @@
    * cleaned up.
    */
   def taskSetFinished(manager: TaskSetManager): Unit = synchronized {
-    activeTaskSets -= manager.taskSet.id
+    taskSetsByStageIdAndAttempt.get(manager.taskSet.stageId).foreach { taskSetsForStage =>
+      taskSetsForStage -= manager.taskSet.stageAttemptId
+      if (taskSetsForStage.isEmpty) {
+        taskSetsByStageIdAndAttempt -= manager.taskSet.stageId
+      }
+    }
     manager.parent.removeSchedulable(manager)
     logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s"
       .format(manager.taskSet.id, manager.parent.name))
@@ -235,7 +252,7 @@
         for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
           tasks(i) += task
           val tid = task.taskId
-          taskIdToTaskSetId(tid) = taskSet.taskSet.id
+          taskIdToTaskSetManager(tid) = taskSet
           taskIdToExecutorId(tid) = execId
           executorsByHost(host) += execId
           availableCpus(i) -= CPUS_PER_TASK
@@ -319,26 +336,24 @@
           failedExecutor = Some(execId)
         }
       }
-      taskIdToTaskSetId.get(tid) match {
-        case Some(taskSetId) =>
+      taskIdToTaskSetManager.get(tid) match {
+        case Some(taskSet) =>
           if (TaskState.isFinished(state)) {
-            taskIdToTaskSetId.remove(tid)
+            taskIdToTaskSetManager.remove(tid)
             taskIdToExecutorId.remove(tid)
           }
-          activeTaskSets.get(taskSetId).foreach { taskSet =>
-            if (state == TaskState.FINISHED) {
-              taskSet.removeRunningTask(tid)
-              taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
-            } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
-              taskSet.removeRunningTask(tid)
-              taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
-            }
+          if (state == TaskState.FINISHED) {
+            taskSet.removeRunningTask(tid)
+            taskResultGetter.enqueueSuccessfulTask(taskSet, tid, serializedData)
+          } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+            taskSet.removeRunningTask(tid)
+            taskResultGetter.enqueueFailedTask(taskSet, tid, state, serializedData)
           }
         case None =>
           logError(
Review thread on the error-logging change:

Contributor (Author): There is a small change in logic here. Before, we'd hit this case only if taskIdToTaskSetId didn't contain tid. If activeTaskSets didn't contain taskSetId, we wouldn't log any error (just a foreach on None). But now, we hit the error case if either taskIdToTaskSetId doesn't contain tid, or if taskSetsByStageIdAndAttempt doesn't contain the stage & attempt.

This wasn't exactly intentional; it just fell out from an attempt at refactoring a helper method. But it does seem to more accurately correspond to what is in the error msg. Still, I'm not certain about it. Perhaps if nobody else has high confidence in this change I should just go back to the old behavior.

Contributor: This is fine!

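A toy illustration of the logic change discussed above, with plain maps of strings standing in for the real scheduler state: the old two-step lookup could miss on the second map and fail silently inside a foreach, while the new single map either finds the manager or reaches the error branch.

object StatusUpdateLookupSketch {
  // Old shape: tid -> taskSetId, then taskSetId -> manager. A miss on the
  // second map was swallowed by foreach, so nothing was logged.
  def oldLookup(tid: Long,
      taskIdToTaskSetId: Map[Long, String],
      activeTaskSets: Map[String, String]): Unit = {
    taskIdToTaskSetId.get(tid) match {
      case Some(taskSetId) =>
        activeTaskSets.get(taskSetId).foreach(mgr => println(s"updating $mgr"))
      case None => println(s"ERROR: task set for TID $tid is gone")
    }
  }

  // New shape: a single tid -> manager map, so every miss reaches the error branch.
  def newLookup(tid: Long, taskIdToTaskSetManager: Map[Long, String]): Unit = {
    taskIdToTaskSetManager.get(tid) match {
      case Some(mgr) => println(s"updating $mgr")
      case None => println(s"ERROR: task set for TID $tid is gone")
    }
  }

  def main(args: Array[String]): Unit = {
    oldLookup(1L, Map(1L -> "0.0"), Map.empty) // silent: no error logged
    newLookup(1L, Map.empty)                   // logs the error
  }
}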
("Ignoring update with state %s for TID %s because its task set is gone (this is " +
"likely the result of receiving duplicate task finished status updates)")
.format(state, tid))
"likely the result of receiving duplicate task finished status updates)")
.format(state, tid))
}
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
Expand All @@ -363,9 +378,9 @@ private[spark] class TaskSchedulerImpl(

     val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
       taskMetrics.flatMap { case (id, metrics) =>
-        taskIdToTaskSetId.get(id)
-          .flatMap(activeTaskSets.get)
-          .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
+        taskIdToTaskSetManager.get(id).map { taskSetMgr =>
+          (id, taskSetMgr.stageId, taskSetMgr.taskSet.stageAttemptId, metrics)
+        }
       }
     }
     dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
@@ -397,9 +412,12 @@

   def error(message: String) {
     synchronized {
-      if (activeTaskSets.nonEmpty) {
+      if (taskSetsByStageIdAndAttempt.nonEmpty) {
         // Have each task set throw a SparkException with the error
-        for ((taskSetId, manager) <- activeTaskSets) {
+        for {
+          attempts <- taskSetsByStageIdAndAttempt.values
+          manager <- attempts.values
+        } {
           try {
             manager.abort(message)
           } catch {
@@ -520,6 +538,17 @@

   override def applicationAttemptId(): Option[String] = backend.applicationAttemptId()
 
+  private[scheduler] def taskSetManagerForAttempt(
+      stageId: Int,
+      stageAttemptId: Int): Option[TaskSetManager] = {
+    for {
+      attempts <- taskSetsByStageIdAndAttempt.get(stageId)
+      manager <- attempts.get(stageAttemptId)
+    } yield {
+      manager
+    }
+  }
+
 }


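An aside on the for-comprehension in taskSetManagerForAttempt: over nested Options it desugars to flatMap/map, so the result is Some only when both lookups succeed. A quick sketch with hypothetical values:

val managers = Map(1 -> Map(0 -> "tsm-1.0"))

val hit = for {
  attempts <- managers.get(1)
  manager <- attempts.get(0)
} yield manager // Some("tsm-1.0")

val missStage = managers.get(2).flatMap(_.get(0))   // None: unknown stage
val missAttempt = managers.get(1).flatMap(_.get(3)) // None: unknown attempt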
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala
@@ -26,10 +26,10 @@ import java.util.Properties
 private[spark] class TaskSet(
     val tasks: Array[Task[_]],
     val stageId: Int,
-    val attempt: Int,
+    val stageAttemptId: Int,
     val priority: Int,
     val properties: Properties) {
-  val id: String = stageId + "." + attempt
+  val id: String = stageId + "." + stageAttemptId
 
   override def toString: String = "TaskSet " + id
 }
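With the rename, TaskSet.id reads as the stage id and the attempt id joined by a dot, so two attempts of one stage are distinguishable in logs and in the conflicting-task-set error above. Illustrative values, mirroring the id expression:

val stageId = 4
val firstAttempt = stageId + "." + 0 // "4.0"
val retry = stageId + "." + 1        // "4.1"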
core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -191,15 +191,14 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
       for (task <- tasks.flatten) {
         val serializedTask = ser.serialize(task)
         if (serializedTask.limit >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
-          val taskSetId = scheduler.taskIdToTaskSetId(task.taskId)
-          scheduler.activeTaskSets.get(taskSetId).foreach { taskSet =>
+          scheduler.taskIdToTaskSetManager.get(task.taskId).foreach { taskSetMgr =>
             try {
               var msg = "Serialized task %s:%d was %d bytes, which exceeds max allowed: " +
                 "spark.akka.frameSize (%d bytes) - reserved (%d bytes). Consider increasing " +
                 "spark.akka.frameSize or using broadcast variables for large values."
               msg = msg.format(task.taskId, task.index, serializedTask.limit, akkaFrameSize,
                 AkkaUtils.reservedSizeBytes)
-              taskSet.abort(msg)
+              taskSetMgr.abort(msg)
             } catch {
               case e: Exception => logError("Exception in error callback", e)
             }