core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

@@ -1394,8 +1394,7 @@ private[spark] class DAGScheduler(
     // finished. Here we notify the task scheduler to skip running tasks for the same partition,
     // to save resource.
     if (task.stageAttemptId < stage.latestInfo.attemptNumber()) {
-      taskScheduler.notifyPartitionCompletion(
-        stageId, task.partitionId, event.taskInfo.duration)
+      taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
     }

     task match {
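Taken together, the diff drops the task-duration argument along the whole partition-completion notification chain shown below: taskScheduler.notifyPartitionCompletion forwards through TaskResultGetter.enqueuePartitionCompletionNotification to TaskSchedulerImpl.handlePartitionCompleted, which finally calls TaskSetManager.markPartitionCompleted.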
core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala

@@ -158,10 +158,9 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedulerImpl)
   // This method calls `TaskSchedulerImpl.handlePartitionCompleted` asynchronously. We do not want
   // DAGScheduler to call `TaskSchedulerImpl.handlePartitionCompleted` directly, as it's
   // synchronized and may hurt the throughput of the scheduler.
-  def enqueuePartitionCompletionNotification(
-      stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
+  def enqueuePartitionCompletionNotification(stageId: Int, partitionId: Int): Unit = {
     getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions {
-      scheduler.handlePartitionCompleted(stageId, partitionId, taskDuration)
+      scheduler.handlePartitionCompleted(stageId, partitionId)
     })
   }

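The comment above describes a hand-off pattern: the synchronized method runs on TaskResultGetter's executor, so the caller never blocks on the scheduler's lock. A minimal self-contained sketch of that pattern (the executor setup and the stand-in method body are assumptions, not Spark's actual code):

    import java.util.concurrent.Executors

    object AsyncNotifySketch {
      // Stand-in for getTaskResultExecutor; Spark's real pool is configured differently.
      private val taskResultExecutor = Executors.newSingleThreadExecutor()

      // Stand-in for TaskSchedulerImpl.handlePartitionCompleted: takes this object's monitor.
      private def handlePartitionCompleted(stageId: Int, partitionId: Int): Unit = synchronized {
        println(s"partition $partitionId of stage $stageId marked completed")
      }

      // Returns immediately; the lock is only ever taken on the executor's thread,
      // so the caller is never stalled by scheduler contention.
      def enqueuePartitionCompletionNotification(stageId: Int, partitionId: Int): Unit =
        taskResultExecutor.execute(() => handlePartitionCompleted(stageId, partitionId))
    }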
core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala

@@ -70,7 +70,7 @@ private[spark] trait TaskScheduler {

   // Notify the corresponding `TaskSetManager`s of the stage, that a partition has already completed
   // and they can skip running tasks for it.
-  def notifyPartitionCompletion(stageId: Int, partitionId: Int, taskDuration: Long)
+  def notifyPartitionCompletion(stageId: Int, partitionId: Int)

   // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
   def setDAGScheduler(dagScheduler: DAGScheduler): Unit
core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

@@ -301,9 +301,8 @@ private[spark] class TaskSchedulerImpl(
     }
   }

-  override def notifyPartitionCompletion(
-      stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
-    taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId, taskDuration)
+  override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+    taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId)
   }

   /**
@@ -651,12 +650,9 @@ private[spark] class TaskSchedulerImpl(
    * means that a task completion from an earlier zombie attempt can lead to the entire stage
    * getting marked as successful.
    */
-  private[scheduler] def handlePartitionCompleted(
-      stageId: Int,
-      partitionId: Int,
-      taskDuration: Long) = synchronized {
+  private[scheduler] def handlePartitionCompleted(stageId: Int, partitionId: Int) = synchronized {
     taskSetsByStageIdAndAttempt.get(stageId).foreach(_.values.filter(!_.isZombie).foreach { tsm =>
-      tsm.markPartitionCompleted(partitionId, taskDuration)
+      tsm.markPartitionCompleted(partitionId)
     })
   }

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala (12 additions & 12 deletions)

@@ -62,14 +62,8 @@ private[spark] class TaskSetManager(
   private val addedJars = HashMap[String, Long](sched.sc.addedJars.toSeq: _*)
   private val addedFiles = HashMap[String, Long](sched.sc.addedFiles.toSeq: _*)

-  // Quantile of tasks at which to start speculation
-  val speculationQuantile = conf.get(SPECULATION_QUANTILE)
-  val speculationMultiplier = conf.get(SPECULATION_MULTIPLIER)
-
   val maxResultSize = conf.get(config.MAX_RESULT_SIZE)

-  val speculationEnabled = conf.get(SPECULATION_ENABLED)
-
   // Serializer for closures and tasks.
   val env = SparkEnv.get
   val ser = env.closureSerializer.newInstance()

@@ -80,6 +74,12 @@ private[spark] class TaskSetManager(
   val numTasks = tasks.length
   val copiesRunning = new Array[Int](numTasks)

+  val speculationEnabled = conf.get(SPECULATION_ENABLED)
+  // Quantile of tasks at which to start speculation
+  val speculationQuantile = conf.get(SPECULATION_QUANTILE)
+  val speculationMultiplier = conf.get(SPECULATION_MULTIPLIER)
+  val minFinishedForSpeculation = math.max((speculationQuantile * numTasks).floor.toInt, 1)
+
   // For each task, tracks whether a copy of the task has succeeded. A task will also be
   // marked as "succeeded" if it failed with a fetch failure, in which case it should not
   // be re-run because the missing map data needs to be regenerated first.
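A quick worked example of the new minFinishedForSpeculation field (the helper below is illustrative; 0.75 is the documented default of spark.speculation.quantile):

    val speculationQuantile = 0.75  // assumed: default spark.speculation.quantile
    def minFinishedForSpeculation(numTasks: Int): Int =
      math.max((speculationQuantile * numTasks).floor.toInt, 1)

    minFinishedForSpeculation(10)  // 7: wait for 7 of 10 successes before speculating
    minFinishedForSpeculation(1)   // 1: the old expression below, a plain floor without
                                   // math.max, evaluated to 0 here and would open
                                   // speculation with no finished sample at all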
@@ -816,12 +816,9 @@ private[spark] class TaskSetManager(
     maybeFinishTaskSet()
   }

-  private[scheduler] def markPartitionCompleted(partitionId: Int, taskDuration: Long): Unit = {
+  private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
     partitionToIndex.get(partitionId).foreach { index =>
       if (!successful(index)) {
-        if (speculationEnabled && !isZombie) {
-          successfulTaskDurations.insert(taskDuration)
-        }
         tasksSuccessful += 1
         successful(index) = true
         if (tasksSuccessful == numTasks) {
@@ -1035,10 +1032,13 @@ private[spark] class TaskSetManager(
       return false
     }
     var foundTasks = false
-    val minFinishedForSpeculation = (speculationQuantile * numTasks).floor.toInt
     logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)

-    if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
+    // It's possible that a task is marked as completed by the scheduler, then the size of
+    // `successfulTaskDurations` may not equal to `tasksSuccessful`. Here we should only count the
+    // tasks that are submitted by this `TaskSetManager` and are completed successfully.
+    val numSuccessfulTasks = successfulTaskDurations.size()
+    if (numSuccessfulTasks >= minFinishedForSpeculation) {
       val time = clock.getTimeMillis()
       val medianDuration = successfulTaskDurations.median
       val threshold = max(speculationMultiplier * medianDuration, minTimeToSpeculation)
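To see why the gate now counts successfulTaskDurations.size() rather than tasksSuccessful, consider a hypothetical 4-task stage (numbers assumed for illustration):

    // With the default quantile of 0.75, minFinishedForSpeculation = floor(0.75 * 4) = 3.
    // - 2 tasks succeed inside this TaskSetManager: durations recorded, tasksSuccessful = 2
    // - 1 partition is finished by another attempt and arrives via markPartitionCompleted:
    //   tasksSuccessful = 3, but successfulTaskDurations still holds only 2 entries
    // Gating on tasksSuccessful would open speculation now and take a median over a
    // 2-element sample; gating on successfulTaskDurations.size() waits until this
    // attempt has 3 completed tasks of its own to measure against.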
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala

@@ -158,8 +158,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLimits {
         taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
       override def killAllTaskAttempts(
         stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
-      override def notifyPartitionCompletion(
-        stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
+      override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
         taskSets.filter(_.stageId == stageId).lastOption.foreach { ts =>
           val tasks = ts.tasks.filter(_.partitionId == partitionId)
           assert(tasks.length == 1)

@@ -669,8 +668,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLimits {
       stageId: Int, interruptThread: Boolean, reason: String): Unit = {
       throw new UnsupportedOperationException
     }
-    override def notifyPartitionCompletion(
-      stageId: Int, partitionId: Int, taskDuration: Long): Unit = {
+    override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
       throw new UnsupportedOperationException
     }
     override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala

@@ -84,8 +84,7 @@ private class DummyTaskScheduler extends TaskScheduler {
     taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
   override def killAllTaskAttempts(
     stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
-  override def notifyPartitionCompletion(
-    stageId: Int, partitionId: Int, taskDuration: Long): Unit = {}
+  override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {}
   override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
   override def defaultParallelism(): Int = 2
   override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala

@@ -1394,8 +1394,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logging {

     val taskSetManager = sched.taskSetManagerForAttempt(0, 0).get
     assert(taskSetManager.runningTasks === 8)
-    taskSetManager.markPartitionCompleted(8, 0)
-    assert(!taskSetManager.successfulTaskDurations.isEmpty())
+    taskSetManager.markPartitionCompleted(8)
+    assert(taskSetManager.successfulTaskDurations.isEmpty())
     taskSetManager.checkSpeculatableTasks(0)
   }
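The flipped assertion pins down the behavior change end to end: markPartitionCompleted no longer records a duration, so a partition completed by another attempt leaves successfulTaskDurations empty and never skews the speculation median.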