From dbc478b0d96de0f1a48f93d90f1e648229b5790c Mon Sep 17 00:00:00 2001 From: Devesh Agrawal Date: Mon, 6 Jul 2020 10:45:33 -0700 Subject: [PATCH] Reduce job failures during decommissioning This PR reduces the prospect of a job loss during decommissioning. It fixes two holes in the current decommissioning framework: - (a) Loss of decommissioned executors is not treated as a job failure: We know that the decom'd executor would be dying soon, so its death is clearly not caused by the application. - (b) Shuffle files on the decommissioned host are cleared when the first fetch failure is detected from a decommissioned host: This is a bit tricky in terms of when to clear the shuffle state ? Ideally you want to clear it the millisecond before the shuffle service on the node dies (or the executor dies when there is no external shuffle service) -- too soon and it could lead to some wasteage and too late would lead to fetch failures. (These two things are a bit intertwined so it is easier to do them in a single commit) The approach here is to do this clearing when the very first fetch failure is observed on the decom'd block manager, without waiting for other blocks to also signal a failure. Added a new unit test `DecommissionWorkerSuite` to test the new behavior. --- .../apache/spark/scheduler/DAGScheduler.scala | 19 +- .../spark/scheduler/ExecutorLossReason.scala | 7 +- .../spark/scheduler/TaskScheduler.scala | 5 + .../spark/scheduler/TaskSchedulerImpl.scala | 37 +- .../spark/scheduler/TaskSetManager.scala | 1 + .../deploy/DecommissionWorkerSuite.scala | 424 ++++++++++++++++++ .../spark/scheduler/DAGSchedulerSuite.scala | 4 + .../ExternalClusterManagerSuite.scala | 2 + .../scheduler/TaskSchedulerImplSuite.scala | 47 ++ 9 files changed, 539 insertions(+), 7 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 2503ae0856dc..6b376cdadc66 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1821,10 +1821,19 @@ private[spark] class DAGScheduler( // TODO: mark the executor as failed only if there were lots of fetch failures on it if (bmAddress != null) { - val hostToUnregisterOutputs = if (env.blockManager.externalShuffleServiceEnabled && - unRegisterOutputOnHostOnFetchFailure) { - // We had a fetch failure with the external shuffle service, so we - // assume all shuffle data on the node is bad. + val externalShuffleServiceEnabled = env.blockManager.externalShuffleServiceEnabled + val isHostDecommissioned = taskScheduler + .getExecutorDecommissionInfo(bmAddress.executorId) + .exists(_.isHostDecommissioned) + + // Shuffle output of all executors on host `bmAddress.host` may be lost if: + // - External shuffle service is enabled, so we assume that all shuffle data on node is + // bad. + // - Host is decommissioned, thus all executors on that host will die. + val shuffleOutputOfEntireHostLost = externalShuffleServiceEnabled || + isHostDecommissioned + val hostToUnregisterOutputs = if (shuffleOutputOfEntireHostLost + && unRegisterOutputOnHostOnFetchFailure) { Some(bmAddress.host) } else { // Unregister shuffle data just for one executor (we don't have any @@ -2339,7 +2348,7 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case ExecutorLost(execId, reason) => val workerLost = reason match { - case ExecutorProcessLost(_, true) => true + case ExecutorProcessLost(_, true, _) => true case _ => false } dagScheduler.handleExecutorLost(execId, workerLost) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala index 4141ed799a4e..671dedaa5a6e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ExecutorLossReason.scala @@ -54,9 +54,14 @@ private [spark] object LossReasonPending extends ExecutorLossReason("Pending los /** * @param _message human readable loss reason * @param workerLost whether the worker is confirmed lost too (i.e. including shuffle service) + * @param causedByApp whether the loss of the executor is the fault of the running app. + * (assumed true by default unless known explicitly otherwise) */ private[spark] -case class ExecutorProcessLost(_message: String = "Worker lost", workerLost: Boolean = false) +case class ExecutorProcessLost( + _message: String = "Executor Process Lost", + workerLost: Boolean = false, + causedByApp: Boolean = true) extends ExecutorLossReason(_message) /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index b29458c48141..1101d0616d2b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -103,6 +103,11 @@ private[spark] trait TaskScheduler { */ def executorDecommission(executorId: String, decommissionInfo: ExecutorDecommissionInfo): Unit + /** + * If an executor is decommissioned, return its corresponding decommission info + */ + def getExecutorDecommissionInfo(executorId: String): Option[ExecutorDecommissionInfo] + /** * Process a lost executor */ diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 510318afcb8d..b734d9f72944 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -136,6 +136,8 @@ private[spark] class TaskSchedulerImpl( // IDs of the tasks running on each executor private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]] + private val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo] + def runningTasksByExecutors: Map[String, Int] = synchronized { executorIdToRunningTaskIds.toMap.mapValues(_.size).toMap } @@ -939,12 +941,43 @@ private[spark] class TaskSchedulerImpl( override def executorDecommission( executorId: String, decommissionInfo: ExecutorDecommissionInfo): Unit = { + synchronized { + // Don't bother noting decommissioning for executors that we don't know about + if (executorIdToHost.contains(executorId)) { + // The scheduler can get multiple decommission updates from multiple sources, + // and some of those can have isHostDecommissioned false. We merge them such that + // if we heard isHostDecommissioned ever true, then we keep that one since it is + // most likely coming from the cluster manager and thus authoritative + val oldDecomInfo = executorsPendingDecommission.get(executorId) + if (oldDecomInfo.isEmpty || !oldDecomInfo.get.isHostDecommissioned) { + executorsPendingDecommission(executorId) = decommissionInfo + } + } + } rootPool.executorDecommission(executorId) backend.reviveOffers() } - override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = { + override def getExecutorDecommissionInfo(executorId: String) + : Option[ExecutorDecommissionInfo] = synchronized { + executorsPendingDecommission.get(executorId) + } + + override def executorLost(executorId: String, givenReason: ExecutorLossReason): Unit = { var failedExecutor: Option[String] = None + val reason = givenReason match { + // Handle executor process loss due to decommissioning + case ExecutorProcessLost(message, origWorkerLost, origCausedByApp) => + val executorDecommissionInfo = getExecutorDecommissionInfo(executorId) + ExecutorProcessLost( + message, + // Also mark the worker lost if we know that the host was decommissioned + origWorkerLost || executorDecommissionInfo.exists(_.isHostDecommissioned), + // Executor loss is certainly not caused by app if we knew that this executor is being + // decommissioned + causedByApp = executorDecommissionInfo.isEmpty && origCausedByApp) + case e => e + } synchronized { if (executorIdToRunningTaskIds.contains(executorId)) { @@ -1033,6 +1066,8 @@ private[spark] class TaskSchedulerImpl( } } + executorsPendingDecommission -= executorId + if (reason != LossReasonPending) { executorIdToHost -= executorId rootPool.executorLost(executorId, host, reason) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 4b31ff0c790d..d69f358cd19d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -985,6 +985,7 @@ private[spark] class TaskSetManager( val exitCausedByApp: Boolean = reason match { case exited: ExecutorExited => exited.exitCausedByApp case ExecutorKilled => false + case ExecutorProcessLost(_, _, false) => false case _ => true } handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(info.executorId, exitCausedByApp, diff --git a/core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala new file mode 100644 index 000000000000..ee9a6be03868 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.concurrent.duration._ + +import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark._ +import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState, WorkerDecommission} +import org.apache.spark.deploy.master.{ApplicationInfo, Master, WorkerInfo} +import org.apache.spark.deploy.worker.Worker +import org.apache.spark.internal.{config, Logging} +import org.apache.spark.network.TransportContext +import org.apache.spark.network.netty.SparkTransportConf +import org.apache.spark.network.shuffle.ExternalBlockHandler +import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.scheduler._ +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils + +class DecommissionWorkerSuite + extends SparkFunSuite + with Logging + with LocalSparkContext + with BeforeAndAfterEach { + + private var masterAndWorkerConf: SparkConf = null + private var masterAndWorkerSecurityManager: SecurityManager = null + private var masterRpcEnv: RpcEnv = null + private var master: Master = null + private var workerIdToRpcEnvs: mutable.HashMap[String, RpcEnv] = null + private var workers: mutable.ArrayBuffer[Worker] = null + + override def beforeEach(): Unit = { + super.beforeEach() + masterAndWorkerConf = new SparkConf() + .set(config.Worker.WORKER_DECOMMISSION_ENABLED, true) + masterAndWorkerSecurityManager = new SecurityManager(masterAndWorkerConf) + masterRpcEnv = RpcEnv.create( + Master.SYSTEM_NAME, + "localhost", + 0, + masterAndWorkerConf, + masterAndWorkerSecurityManager) + master = makeMaster() + workerIdToRpcEnvs = mutable.HashMap.empty + workers = mutable.ArrayBuffer.empty + } + + override def afterEach(): Unit = { + try { + masterRpcEnv.shutdown() + workerIdToRpcEnvs.values.foreach(_.shutdown()) + workerIdToRpcEnvs.clear() + master.stop() + workers.foreach(_.stop()) + workers.clear() + masterRpcEnv = null + } finally { + super.afterEach() + } + } + + test("decommission workers should not result in job failure") { + val maxTaskFailures = 2 + val numTimesToKillWorkers = maxTaskFailures + 1 + val numWorkers = numTimesToKillWorkers + 1 + createWorkers(numWorkers) + + // Here we will have a single task job and we will keep decommissioning (and killing) the + // worker running that task K times. Where K is more than the maxTaskFailures. Since the worker + // is notified of the decommissioning, the task failures can be ignored and not fail + // the job. + + sc = createSparkContext(config.TASK_MAX_FAILURES.key -> maxTaskFailures.toString) + val executorIdToWorkerInfo = getExecutorToWorkerAssignments + val taskIdsKilled = new ConcurrentHashMap[Long, Boolean] + val listener = new RootStageAwareListener { + override def handleRootTaskStart(taskStart: SparkListenerTaskStart): Unit = { + val taskInfo = taskStart.taskInfo + if (taskIdsKilled.size() < numTimesToKillWorkers) { + val workerInfo = executorIdToWorkerInfo(taskInfo.executorId) + decommissionWorkerOnMaster(workerInfo, "partition 0 must die") + killWorkerAfterTimeout(workerInfo, 1) + taskIdsKilled.put(taskInfo.taskId, true) + } + } + } + TestUtils.withListener(sc, listener) { _ => + val jobResult = sc.parallelize(1 to 1, 1).map { _ => + Thread.sleep(5 * 1000L); 1 + }.count() + assert(jobResult === 1) + } + // single task job that gets to run numTimesToKillWorkers + 1 times. + assert(listener.getTasksFinished().size === numTimesToKillWorkers + 1) + listener.rootTasksStarted.asScala.foreach { taskInfo => + assert(taskInfo.index == 0, s"Unknown task index ${taskInfo.index}") + } + listener.rootTasksEnded.asScala.foreach { taskInfo => + assert(taskInfo.index === 0, s"Expected task index ${taskInfo.index} to be 0") + // If a task has been killed then it shouldn't be successful + val taskSuccessExpected = !taskIdsKilled.getOrDefault(taskInfo.taskId, false) + val taskSuccessActual = taskInfo.successful + assert(taskSuccessActual === taskSuccessExpected, + s"Expected task success $taskSuccessActual == $taskSuccessExpected") + } + } + + test("decommission workers ensure that shuffle output is regenerated even with shuffle service") { + createWorkers(2) + val ss = new ExternalShuffleServiceHolder() + + sc = createSparkContext( + config.Tests.TEST_NO_STAGE_RETRY.key -> "true", + config.SHUFFLE_MANAGER.key -> "sort", + config.SHUFFLE_SERVICE_ENABLED.key -> "true", + config.SHUFFLE_SERVICE_PORT.key -> ss.getPort.toString + ) + + // Here we will create a 2 stage job: The first stage will have two tasks and the second stage + // will have one task. The two tasks in the first stage will be long and short. We decommission + // and kill the worker after the short task is done. Eventually the driver should get the + // executor lost signal for the short task executor. This should trigger regenerating + // the shuffle output since we cleanly decommissioned the executor, despite running with an + // external shuffle service. + try { + val executorIdToWorkerInfo = getExecutorToWorkerAssignments + val workerForTask0Decommissioned = new AtomicBoolean(false) + // single task job + val listener = new RootStageAwareListener { + override def handleRootTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + val taskInfo = taskEnd.taskInfo + if (taskInfo.index == 0) { + if (workerForTask0Decommissioned.compareAndSet(false, true)) { + val workerInfo = executorIdToWorkerInfo(taskInfo.executorId) + decommissionWorkerOnMaster(workerInfo, "Kill early done map worker") + killWorkerAfterTimeout(workerInfo, 0) + logInfo(s"Killed the node ${workerInfo.hostPort} that was running the early task") + } + } + } + } + TestUtils.withListener(sc, listener) { _ => + val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((pid, _) => { + val sleepTimeSeconds = if (pid == 0) 1 else 10 + Thread.sleep(sleepTimeSeconds * 1000L) + List(1).iterator + }, preservesPartitioning = true).repartition(1).sum() + assert(jobResult === 2) + } + val tasksSeen = listener.getTasksFinished() + // 4 tasks: 2 from first stage, one retry due to decom, one more from the second stage. + assert(tasksSeen.size === 4, s"Expected 4 tasks but got $tasksSeen") + listener.rootTasksStarted.asScala.foreach { taskInfo => + assert(taskInfo.index <= 1, s"Expected ${taskInfo.index} <= 1") + assert(taskInfo.successful, s"Task ${taskInfo.index} should be successful") + } + val tasksEnded = listener.rootTasksEnded.asScala + tasksEnded.filter(_.index != 0).foreach { taskInfo => + assert(taskInfo.attemptNumber === 0, "2nd task should succeed on 1st attempt") + } + val firstTaskAttempts = tasksEnded.filter(_.index == 0) + assert(firstTaskAttempts.size > 1, s"Task 0 should have multiple attempts") + } finally { + ss.close() + } + } + + test("decommission workers ensure that fetch failures lead to rerun") { + createWorkers(2) + sc = createSparkContext( + config.Tests.TEST_NO_STAGE_RETRY.key -> "false", + config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE.key -> "true") + val executorIdToWorkerInfo = getExecutorToWorkerAssignments + val executorToDecom = executorIdToWorkerInfo.keysIterator.next + + // The task code below cannot call executorIdToWorkerInfo, so we need to pre-compute + // the worker to decom to force it to be serialized into the task. + val workerToDecom = executorIdToWorkerInfo(executorToDecom) + + // The setup of this job is similar to the one above: 2 stage job with first stage having + // long and short tasks. Except that we want the shuffle output to be regenerated on a + // fetch failure instead of an executor lost. Since it is hard to "trigger a fetch failure", + // we manually raise the FetchFailed exception when the 2nd stage's task runs and require that + // fetch failure to trigger a recomputation. + logInfo(s"Will try to decommission the task running on executor $executorToDecom") + val listener = new RootStageAwareListener { + override def handleRootTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + val taskInfo = taskEnd.taskInfo + if (taskInfo.executorId == executorToDecom && taskInfo.attemptNumber == 0 && + taskEnd.stageAttemptId == 0) { + decommissionWorkerOnMaster(workerToDecom, + "decommission worker after task on it is done") + } + } + } + TestUtils.withListener(sc, listener) { _ => + val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((_, _) => { + val executorId = SparkEnv.get.executorId + val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1 + Thread.sleep(sleepTimeSeconds * 1000L) + List(1).iterator + }, preservesPartitioning = true) + .repartition(1).mapPartitions(iter => { + val context = TaskContext.get() + if (context.attemptNumber == 0 && context.stageAttemptNumber() == 0) { + // MapIndex is explicitly -1 to force the entire host to be decommissioned + // However, this will cause both the tasks in the preceding stage since the host here is + // "localhost" (shortcoming of this single-machine unit test in that all the workers + // are actually on the same host) + throw new FetchFailedException(BlockManagerId(executorToDecom, + workerToDecom.host, workerToDecom.port), 0, 0, -1, 0, "Forcing fetch failure") + } + val sumVal: List[Int] = List(iter.sum) + sumVal.iterator + }, preservesPartitioning = true) + .sum() + assert(jobResult === 2) + } + // 6 tasks: 2 from first stage, 2 rerun again from first stage, 2nd stage attempt 1 and 2. + val tasksSeen = listener.getTasksFinished() + assert(tasksSeen.size === 6, s"Expected 6 tasks but got $tasksSeen") + } + + private abstract class RootStageAwareListener extends SparkListener { + private var rootStageId: Option[Int] = None + private val tasksFinished = new ConcurrentLinkedQueue[String]() + private val jobDone = new AtomicBoolean(false) + val rootTasksStarted = new ConcurrentLinkedQueue[TaskInfo]() + val rootTasksEnded = new ConcurrentLinkedQueue[TaskInfo]() + + protected def isRootStageId(stageId: Int): Boolean = + (rootStageId.isDefined && rootStageId.get == stageId) + + override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = { + if (stageSubmitted.stageInfo.parentIds.isEmpty && rootStageId.isEmpty) { + rootStageId = Some(stageSubmitted.stageInfo.stageId) + } + } + + override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = { + jobEnd.jobResult match { + case JobSucceeded => jobDone.set(true) + } + } + + protected def handleRootTaskEnd(end: SparkListenerTaskEnd) = {} + + protected def handleRootTaskStart(start: SparkListenerTaskStart) = {} + + override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = { + if (isRootStageId(taskStart.stageId)) { + rootTasksStarted.add(taskStart.taskInfo) + handleRootTaskStart(taskStart) + } + } + + override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = { + val taskSignature = s"${taskEnd.stageId}:${taskEnd.stageAttemptId}:" + + s"${taskEnd.taskInfo.index}:${taskEnd.taskInfo.attemptNumber}" + logInfo(s"Task End $taskSignature") + tasksFinished.add(taskSignature) + if (isRootStageId(taskEnd.stageId)) { + rootTasksEnded.add(taskEnd.taskInfo) + handleRootTaskEnd(taskEnd) + } + } + + def getTasksFinished(): Seq[String] = { + assert(jobDone.get(), "Job isn't successfully done yet") + tasksFinished.asScala.toSeq + } + } + + private def getExecutorToWorkerAssignments: Map[String, WorkerInfo] = { + val executorIdToWorkerInfo = mutable.HashMap[String, WorkerInfo]() + master.workers.foreach { wi => + assert(wi.executors.size <= 1, "There should be at most one executor per worker") + // Cast the executorId to string since the TaskInfo.executorId is a string + wi.executors.values.foreach { e => + val executorIdString = e.id.toString + val oldWorkerInfo = executorIdToWorkerInfo.put(executorIdString, wi) + assert(oldWorkerInfo.isEmpty, + s"Executor $executorIdString already present on another worker ${oldWorkerInfo}") + } + } + executorIdToWorkerInfo.toMap + } + + private def makeMaster(): Master = { + val master = new Master( + masterRpcEnv, + masterRpcEnv.address, + 0, + masterAndWorkerSecurityManager, + masterAndWorkerConf) + masterRpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + master + } + + private def createWorkers(numWorkers: Int, cores: Int = 1, memory: Int = 1024): Unit = { + val workerRpcEnvs = (0 until numWorkers).map { i => + RpcEnv.create( + Worker.SYSTEM_NAME + i, + "localhost", + 0, + masterAndWorkerConf, + masterAndWorkerSecurityManager) + } + workers.clear() + val rpcAddressToRpcEnv: mutable.HashMap[RpcAddress, RpcEnv] = mutable.HashMap.empty + workerRpcEnvs.foreach { rpcEnv => + val workDir = Utils.createTempDir(namePrefix = this.getClass.getSimpleName()).toString + val worker = new Worker(rpcEnv, 0, cores, memory, Array(masterRpcEnv.address), + Worker.ENDPOINT_NAME, workDir, masterAndWorkerConf, masterAndWorkerSecurityManager) + rpcEnv.setupEndpoint(Worker.ENDPOINT_NAME, worker) + workers.append(worker) + val oldRpcEnv = rpcAddressToRpcEnv.put(rpcEnv.address, rpcEnv) + logInfo(s"Created a worker at ${rpcEnv.address} with workdir $workDir") + assert(oldRpcEnv.isEmpty, s"Detected duplicate rpcEnv ${oldRpcEnv} for ${rpcEnv.address}") + } + workerIdToRpcEnvs.clear() + // Wait until all workers register with master successfully + eventually(timeout(1.minute), interval(1.seconds)) { + val workersOnMaster = getMasterState.workers + val numWorkersCurrently = workersOnMaster.length + logInfo(s"Waiting for $numWorkers workers to come up: So far $numWorkersCurrently") + assert(numWorkersCurrently === numWorkers) + workersOnMaster.foreach { workerInfo => + val rpcAddress = RpcAddress(workerInfo.host, workerInfo.port) + val rpcEnv = rpcAddressToRpcEnv(rpcAddress) + assert(rpcEnv != null, s"Cannot find the worker for $rpcAddress") + val oldRpcEnv = workerIdToRpcEnvs.put(workerInfo.id, rpcEnv) + assert(oldRpcEnv.isEmpty, s"Detected duplicate rpcEnv ${oldRpcEnv} for worker " + + s"${workerInfo.id}") + } + } + logInfo(s"Created ${workers.size} workers") + } + + private def getMasterState: MasterStateResponse = { + master.self.askSync[MasterStateResponse](RequestMasterState) + } + + private def getApplications(): Seq[ApplicationInfo] = { + getMasterState.activeApps + } + + def decommissionWorkerOnMaster(workerInfo: WorkerInfo, reason: String): Unit = { + logInfo(s"Trying to decommission worker ${workerInfo.id} for reason `$reason`") + master.self.send(WorkerDecommission(workerInfo.id, workerInfo.endpoint)) + } + + def killWorkerAfterTimeout(workerInfo: WorkerInfo, secondsToWait: Int): Unit = { + val env = workerIdToRpcEnvs(workerInfo.id) + Thread.sleep(secondsToWait * 1000L) + env.shutdown() + env.awaitTermination() + } + + def createSparkContext(extraConfs: (String, String)*): SparkContext = { + val conf = new SparkConf() + .setMaster(masterRpcEnv.address.toSparkURL) + .setAppName("test") + .setAll(extraConfs) + sc = new SparkContext(conf) + val appId = sc.applicationId + eventually(timeout(1.minute), interval(1.seconds)) { + val apps = getApplications() + assert(apps.size === 1) + assert(apps.head.id === appId) + assert(apps.head.getExecutorLimit === Int.MaxValue) + } + sc + } + + private class ExternalShuffleServiceHolder() { + // The external shuffle service can start with default configs and not get polluted by the + // other configs used in this test. + private val transportConf = SparkTransportConf.fromSparkConf(new SparkConf(), + "shuffle", numUsableCores = 2) + private val rpcHandler = new ExternalBlockHandler(transportConf, null) + private val transportContext = new TransportContext(transportConf, rpcHandler) + private val server = transportContext.createServer() + + def getPort: Int = server.getPort + + def close(): Unit = { + Utils.tryLogNonFatalError { + server.close() + } + Utils.tryLogNonFatalError { + rpcHandler.close() + } + Utils.tryLogNonFatalError { + transportContext.close() + } + } + } +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 45af0d086890..c829006923c4 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -178,6 +178,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi override def executorDecommission( executorId: String, decommissionInfo: ExecutorDecommissionInfo): Unit = {} + override def getExecutorDecommissionInfo( + executorId: String): Option[ExecutorDecommissionInfo] = None } /** @@ -785,6 +787,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi override def executorDecommission( executorId: String, decommissionInfo: ExecutorDecommissionInfo): Unit = {} + override def getExecutorDecommissionInfo( + executorId: String): Option[ExecutorDecommissionInfo] = None } val noKillScheduler = new DAGScheduler( sc, diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala index b2a5f77b4b04..07d88672290f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -101,4 +101,6 @@ private class DummyTaskScheduler extends TaskScheduler { override def executorDecommission( executorId: String, decommissionInfo: ExecutorDecommissionInfo): Unit = {} + override def getExecutorDecommissionInfo( + executorId: String): Option[ExecutorDecommissionInfo] = None } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 9ca3ce9d43ca..e5836458e7f9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -1802,6 +1802,53 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(2 == taskDescriptions.head.resources(GPU).addresses.size) } + private def setupSchedulerForDecommissionTests(): TaskSchedulerImpl = { + val taskScheduler = setupSchedulerWithMaster( + s"local[2]", + config.CPUS_PER_TASK.key -> 1.toString) + taskScheduler.submitTasks(FakeTask.createTaskSet(2)) + val multiCoreWorkerOffers = IndexedSeq(WorkerOffer("executor0", "host0", 1), + WorkerOffer("executor1", "host1", 1)) + val taskDescriptions = taskScheduler.resourceOffers(multiCoreWorkerOffers).flatten + assert(taskDescriptions.map(_.executorId).sorted === Seq("executor0", "executor1")) + taskScheduler + } + + test("scheduler should keep the decommission info where host was decommissioned") { + val scheduler = setupSchedulerForDecommissionTests() + + scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("0", false)) + scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("1", true)) + scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("0 new", false)) + scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("1 new", false)) + + assert(scheduler.getExecutorDecommissionInfo("executor0") + === Some(ExecutorDecommissionInfo("0 new", false))) + assert(scheduler.getExecutorDecommissionInfo("executor1") + === Some(ExecutorDecommissionInfo("1", true))) + assert(scheduler.getExecutorDecommissionInfo("executor2").isEmpty) + } + + test("scheduler should ignore decommissioning of removed executors") { + val scheduler = setupSchedulerForDecommissionTests() + + // executor 0 is decommissioned after loosing + assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty) + scheduler.executorLost("executor0", ExecutorExited(0, false, "normal")) + assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty) + scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("", false)) + assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty) + + // executor 1 is decommissioned before loosing + assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty) + scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("", false)) + assert(scheduler.getExecutorDecommissionInfo("executor1").isDefined) + scheduler.executorLost("executor1", ExecutorExited(0, false, "normal")) + assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty) + scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("", false)) + assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty) + } + /** * Used by tests to simulate a task failure. This calls the failure handler explicitly, to ensure * that all the state is updated when this method returns. Otherwise, there's no way to know when