diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 17292b4c15b8b..5ed2803d76afc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -210,11 +210,14 @@ private[spark] class TaskSchedulerImpl(
     SparkEnv.set(sc.env)
 
     // Mark each slave as alive and remember its hostname
+    // Also track if new executor is added
+    var newExecAvail = false
     for (o <- offers) {
       executorIdToHost(o.executorId) = o.host
       if (!executorsByHost.contains(o.host)) {
         executorsByHost(o.host) = new HashSet[String]()
         executorAdded(o.executorId, o.host)
+        newExecAvail = true
       }
     }
 
@@ -227,12 +230,15 @@ private[spark] class TaskSchedulerImpl(
     for (taskSet <- sortedTaskSets) {
       logDebug("parentName: %s, name: %s, runningTasks: %s".format(
         taskSet.parent.name, taskSet.name, taskSet.runningTasks))
+      if (newExecAvail) {
+        taskSet.executorAdded()
+      }
     }
 
     // Take each TaskSet in our scheduling order, and then offer it each node in increasing order
     // of locality levels so that it gets a chance to launch local tasks on all of them.
     var launchedTask = false
-    for (taskSet <- sortedTaskSets; maxLocality <- TaskLocality.values) {
+    for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) {
       do {
         launchedTask = false
         for (i <- 0 until shuffledOffers.size) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index f3bd0797aa035..b5bcdd7e99c58 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -118,7 +118,7 @@ private[spark] class TaskSetManager(
   private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]
 
   // Set containing pending tasks with no locality preferences.
-  val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
+  var pendingTasksWithNoPrefs = new ArrayBuffer[Int]
 
   // Set containing all pending tasks (also used as a stack, as above).
   val allPendingTasks = new ArrayBuffer[Int]
@@ -153,8 +153,8 @@ private[spark] class TaskSetManager(
   }
 
   // Figure out which locality levels we have in our TaskSet, so we can do delay scheduling
-  val myLocalityLevels = computeValidLocalityLevels()
-  val localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
+  var myLocalityLevels = computeValidLocalityLevels()
+  var localityWaits = myLocalityLevels.map(getLocalityWait) // Time to wait at each level
 
   // Delay scheduling variables: we keep track of our current locality level and the time we
   // last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
@@ -181,16 +181,14 @@ private[spark] class TaskSetManager(
     var hadAliveLocations = false
     for (loc <- tasks(index).preferredLocations) {
       for (execId <- loc.executorId) {
-        if (sched.isExecutorAlive(execId)) {
-          addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
-          hadAliveLocations = true
-        }
+        addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
       }
       if (sched.hasExecutorsAliveOnHost(loc.host)) {
-        addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
-        for (rack <- sched.getRackForHost(loc.host)) {
-          addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
-        }
+        hadAliveLocations = true
+      }
+      addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
+      for (rack <- sched.getRackForHost(loc.host)) {
+        addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
         hadAliveLocations = true
       }
     }
@@ -725,10 +723,12 @@ private[spark] class TaskSetManager(
   private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
     import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY}
     val levels = new ArrayBuffer[TaskLocality.TaskLocality]
-    if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0) {
+    if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0 &&
+        pendingTasksForExecutor.keySet.exists(sched.isExecutorAlive(_))) {
       levels += PROCESS_LOCAL
     }
-    if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0) {
+    if (!pendingTasksForHost.isEmpty && getLocalityWait(NODE_LOCAL) != 0 &&
+        pendingTasksForHost.keySet.exists(sched.hasExecutorsAliveOnHost(_))) {
       levels += NODE_LOCAL
     }
     if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0) {
@@ -738,4 +738,21 @@ private[spark] class TaskSetManager(
     logDebug("Valid locality levels for " + taskSet + ": " + levels.mkString(", "))
     levels.toArray
   }
+
+  // Re-compute pendingTasksWithNoPrefs since new preferred locations may become available
+  def executorAdded() {
+    def newLocAvail(index: Int): Boolean = {
+      for (loc <- tasks(index).preferredLocations) {
+        if (sched.hasExecutorsAliveOnHost(loc.host) ||
+            sched.getRackForHost(loc.host).isDefined) {
+          return true
+        }
+      }
+      false
+    }
+    logInfo("Re-computing pending task lists.")
+    pendingTasksWithNoPrefs = pendingTasksWithNoPrefs.filter(!newLocAvail(_))
+    myLocalityLevels = computeValidLocalityLevels()
+    localityWaits = myLocalityLevels.map(getLocalityWait)
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index c92b6dc96c8eb..1cabcbe89f592 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -77,6 +77,10 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
   override def isExecutorAlive(execId: String): Boolean = executors.contains(execId)
 
   override def hasExecutorsAliveOnHost(host: String): Boolean = executors.values.exists(_ == host)
+
+  def addExecutor(execId: String, host: String) {
+    executors.put(execId, host)
+  }
 }
 
 class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
@@ -384,6 +388,36 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
     assert(sched.taskSetsFailed.contains(taskSet.id))
   }
 
+  test("new executors get added") {
+    sc = new SparkContext("local", "test")
+    val sched = new FakeTaskScheduler(sc)
+    val taskSet = FakeTask.createTaskSet(4,
+      Seq(TaskLocation("host1", "execA")),
+      Seq(TaskLocation("host1", "execB")),
+      Seq(TaskLocation("host2", "execC")),
+      Seq())
+    val clock = new FakeClock
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
+    // All tasks added to no-pref list since no preferred location is available
+    assert(manager.pendingTasksWithNoPrefs.size === 4)
+    // Only ANY is valid
+    assert(manager.myLocalityLevels.sameElements(Array(ANY)))
+    // Add a new executor
+    sched.addExecutor("execD", "host1")
+    manager.executorAdded()
+    // Task 0 and 1 should be removed from no-pref list
+    assert(manager.pendingTasksWithNoPrefs.size === 2)
+    // Valid locality should contain NODE_LOCAL and ANY
+    assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, ANY)))
+    // Add another executor
+    sched.addExecutor("execC", "host2")
+    manager.executorAdded()
+    // No-pref list now only contains task 3
+    assert(manager.pendingTasksWithNoPrefs.size === 1)
+    // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL and ANY
+    assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+  }
+
   def createTaskResult(id: Int): DirectTaskResult[Int] = {
     val valueSer = SparkEnv.get.serializer.newInstance()
     new DirectTaskResult[Int](valueSer.serialize(id), mutable.Map.empty, new TaskMetrics