diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterManager.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterManager.scala
index df2d94ec85216..764e351e70286 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterManager.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterManager.scala
@@ -29,7 +29,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle.kubernetes.KubernetesExternalShuffleClientImpl
 import org.apache.spark.scheduler.{ExternalClusterManager, SchedulerBackend, TaskScheduler, TaskSchedulerImpl}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 private[spark] class KubernetesClusterManager extends ExternalClusterManager with Logging {
 
@@ -132,12 +132,18 @@ private[spark] class KubernetesClusterManager extends ExternalClusterManager wit
       executorInitContainerBootstrap,
       executorInitContainerSecretVolumePlugin,
       kubernetesShuffleManager)
+    val allocatorExecutor = ThreadUtils
+      .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator")
+    val requestExecutorsService = ThreadUtils.newDaemonCachedThreadPool(
+      "kubernetes-executor-requests")
     new KubernetesClusterSchedulerBackend(
-      sc.taskScheduler.asInstanceOf[TaskSchedulerImpl],
-      sc,
+      scheduler.asInstanceOf[TaskSchedulerImpl],
+      sc.env.rpcEnv,
       executorPodFactory,
       kubernetesShuffleManager,
-      kubernetesClient)
+      kubernetesClient,
+      allocatorExecutor,
+      requestExecutorsService)
   }
 
   override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = {
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala
index 1593a9a842e98..54612c80cdf28 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler.cluster.kubernetes
 import java.io.Closeable
 import java.net.InetAddress
 import java.util.Collections
-import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, ExecutorService, ScheduledExecutorService, ThreadPoolExecutor, TimeUnit}
 import java.util.concurrent.atomic.{AtomicInteger, AtomicLong, AtomicReference}
 
 import io.fabric8.kubernetes.api.model._
@@ -29,25 +29,28 @@ import scala.collection.{concurrent, mutable}
 import scala.collection.JavaConverters._
 import scala.concurrent.{ExecutionContext, Future}
 
-import org.apache.spark.{SparkContext, SparkEnv, SparkException}
+import org.apache.spark.{SparkEnv, SparkException}
 import org.apache.spark.deploy.kubernetes.config._
 import org.apache.spark.deploy.kubernetes.constants._
 import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEndpointAddress, RpcEnv}
 import org.apache.spark.scheduler.{ExecutorExited, SlaveLost, TaskSchedulerImpl}
 import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RetrieveSparkAppConfig, SparkAppConfig}
 import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
-import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.util.Utils
 
 private[spark] class KubernetesClusterSchedulerBackend(
     scheduler: TaskSchedulerImpl,
-    val sc: SparkContext,
+    rpcEnv: RpcEnv,
     executorPodFactory: ExecutorPodFactory,
     shuffleManager: Option[KubernetesExternalShuffleManager],
-    kubernetesClient: KubernetesClient)
-  extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) {
+    kubernetesClient: KubernetesClient,
+    allocatorExecutor: ScheduledExecutorService,
+    requestExecutorsService: ExecutorService)
+  extends CoarseGrainedSchedulerBackend(scheduler, rpcEnv) {
 
   import KubernetesClusterSchedulerBackend._
 
+  private val EXECUTOR_ID_COUNTER = new AtomicLong(0L)
   private val RUNNING_EXECUTOR_PODS_LOCK = new Object
   // Indexed by executor IDs and guarded by RUNNING_EXECUTOR_PODS_LOCK.
   private val runningExecutorsToPods = new mutable.HashMap[String, Pod]
@@ -57,10 +60,10 @@ private[spark] class KubernetesClusterSchedulerBackend(
   private val EXECUTOR_PODS_BY_IPS_LOCK = new Object
   // Indexed by executor IP addrs and guarded by EXECUTOR_PODS_BY_IPS_LOCK
   private val executorPodsByIPs = new mutable.HashMap[String, Pod]
-  private val failedPods: concurrent.Map[String, ExecutorExited] = new
-    ConcurrentHashMap[String, ExecutorExited]().asScala
-  private val executorsToRemove = Collections.newSetFromMap[String](
-    new ConcurrentHashMap[String, java.lang.Boolean]()).asScala
+  private val podsWithKnownExitReasons: concurrent.Map[String, ExecutorExited] =
+    new ConcurrentHashMap[String, ExecutorExited]().asScala
+  private val disconnectedPodsByExecutorIdPendingRemoval =
+    new ConcurrentHashMap[String, Pod]().asScala
 
   private val kubernetesNamespace = conf.get(KUBERNETES_NAMESPACE)
 
@@ -68,10 +71,8 @@ private[spark] class KubernetesClusterSchedulerBackend(
     .get(KUBERNETES_DRIVER_POD_NAME)
     .getOrElse(
       throw new SparkException("Must specify the driver pod name"))
-
   private val executorPodNamePrefix = conf.get(KUBERNETES_EXECUTOR_POD_NAME_PREFIX)
-
   private implicit val requestExecutorContext = ExecutionContext.fromExecutorService(
-    ThreadUtils.newDaemonCachedThreadPool("kubernetes-executor-requests"))
+    requestExecutorsService)
   private val driverPod = try {
     kubernetesClient.pods().inNamespace(kubernetesNamespace).
@@ -93,9 +94,9 @@ private[spark] class KubernetesClusterSchedulerBackend(
   protected var totalExpectedExecutors = new AtomicInteger(0)
 
   private val driverUrl = RpcEndpointAddress(
-    sc.getConf.get("spark.driver.host"),
-    sc.getConf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT),
-    CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString
+    conf.get("spark.driver.host"),
+    conf.getInt("spark.driver.port", DEFAULT_DRIVER_PORT),
+    CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString
 
   private val initialExecutors = getInitialTargetExecutorNumber()
 
@@ -109,21 +110,14 @@ private[spark] class KubernetesClusterSchedulerBackend(
     s"${KUBERNETES_ALLOCATION_BATCH_SIZE} " +
     s"is ${podAllocationSize}, should be a positive integer")
 
-  private val allocator = ThreadUtils
-    .newDaemonSingleThreadScheduledExecutor("kubernetes-pod-allocator")
+  private val allocatorRunnable = new Runnable {
 
-  private val allocatorRunnable: Runnable = new Runnable {
-
-    // Number of times we are allowed check for the loss reason for an executor before we give up
-    // and assume the executor failed for good, and attribute it to a framework fault.
-    private val MAX_EXECUTOR_LOST_REASON_CHECKS = 10
-    private val executorsToRecover = new mutable.HashSet[String]
     // Maintains a map of executor id to count of checks performed to learn the loss reason
    // for an executor.
-    private val executorReasonChecks = new mutable.HashMap[String, Int]
+    private val executorReasonCheckAttemptCounts = new mutable.HashMap[String, Int]
 
    override def run(): Unit = {
-      removeFailedExecutors()
+      handleDisconnectedExecutors()
       RUNNING_EXECUTOR_PODS_LOCK.synchronized {
         if (totalRegisteredExecutors.get() < runningExecutorsToPods.size) {
           logDebug("Waiting for pending executors before scaling")
@@ -132,7 +126,7 @@ private[spark] class KubernetesClusterSchedulerBackend(
         } else {
           val nodeToLocalTaskCount = getNodesWithLocalTaskCounts
           for (i <- 0 until math.min(
-              totalExpectedExecutors.get - runningExecutorsToPods.size, podAllocationSize)) {
+            totalExpectedExecutors.get - runningExecutorsToPods.size, podAllocationSize)) {
             val (executorId, pod) = allocateNewExecutorPod(nodeToLocalTaskCount)
             runningExecutorsToPods.put(executorId, pod)
             runningPodsToExecutors.put(pod.getMetadata.getName, executorId)
@@ -143,43 +137,47 @@ private[spark] class KubernetesClusterSchedulerBackend(
       }
     }
 
-    def removeFailedExecutors(): Unit = {
-      val localRunningExecutorsToPods = RUNNING_EXECUTOR_PODS_LOCK.synchronized {
-        runningExecutorsToPods.toMap
-      }
-      executorsToRemove.foreach { case (executorId) =>
-        localRunningExecutorsToPods.get(executorId).map { pod: Pod =>
-          failedPods.get(pod.getMetadata.getName).map { executorExited: ExecutorExited =>
-            logDebug(s"Removing executor $executorId with loss reason " + executorExited.message)
-            removeExecutor(executorId, executorExited)
-            if (!executorExited.exitCausedByApp) {
-              executorsToRecover.add(executorId)
-            }
-          }.getOrElse(removeExecutorOrIncrementLossReasonCheckCount(executorId))
-        }.getOrElse(removeExecutorOrIncrementLossReasonCheckCount(executorId))
-
-        executorsToRecover.foreach(executorId => {
-          executorsToRemove -= executorId
-          executorReasonChecks -= executorId
-          RUNNING_EXECUTOR_PODS_LOCK.synchronized {
-            runningExecutorsToPods.remove(executorId).map { pod: Pod =>
-              kubernetesClient.pods().delete(pod)
-              runningPodsToExecutors.remove(pod.getMetadata.getName)
-            }.getOrElse(logWarning(s"Unable to remove pod for unknown executor $executorId"))
+    def handleDisconnectedExecutors(): Unit = {
+      // For each disconnected executor, synchronize with the loss reasons that may have been found
+      // by the executor pod watcher. If the loss reason was discovered by the watcher,
+      // inform the parent class with removeExecutor.
+      val disconnectedPodsByExecutorIdPendingRemovalCopy =
+          Map.empty ++ disconnectedPodsByExecutorIdPendingRemoval
+      disconnectedPodsByExecutorIdPendingRemovalCopy.foreach { case (executorId, executorPod) =>
+        val knownExitReason = podsWithKnownExitReasons.remove(executorPod.getMetadata.getName)
+        knownExitReason.fold {
+          removeExecutorOrIncrementLossReasonCheckCount(executorId)
+        } { executorExited =>
+          logDebug(s"Removing executor $executorId with loss reason " + executorExited.message)
+          removeExecutor(executorId, executorExited)
+          // We keep around executors that have exit conditions caused by the application. This
+          // allows them to be debugged later on. Otherwise, mark them to be deleted from the
+          // API server.
+          if (!executorExited.exitCausedByApp) {
+            deleteExecutorFromClusterAndDataStructures(executorId)
           }
-        })
-        executorsToRecover.clear()
+        }
       }
     }
 
     def removeExecutorOrIncrementLossReasonCheckCount(executorId: String): Unit = {
-      val reasonCheckCount = executorReasonChecks.getOrElse(executorId, 0)
-      if (reasonCheckCount > MAX_EXECUTOR_LOST_REASON_CHECKS) {
-        removeExecutor(executorId, SlaveLost("Executor lost for unknown reasons"))
-        executorsToRecover.add(executorId)
-        executorReasonChecks -= executorId
+      val reasonCheckCount = executorReasonCheckAttemptCounts.getOrElse(executorId, 0)
+      if (reasonCheckCount >= MAX_EXECUTOR_LOST_REASON_CHECKS) {
+        removeExecutor(executorId, SlaveLost("Executor lost for unknown reasons."))
+        deleteExecutorFromClusterAndDataStructures(executorId)
      } else {
-        executorReasonChecks.put(executorId, reasonCheckCount + 1)
+        executorReasonCheckAttemptCounts.put(executorId, reasonCheckCount + 1)
+      }
+    }
+
+    def deleteExecutorFromClusterAndDataStructures(executorId: String): Unit = {
+      disconnectedPodsByExecutorIdPendingRemoval -= executorId
+      executorReasonCheckAttemptCounts -= executorId
+      RUNNING_EXECUTOR_PODS_LOCK.synchronized {
+        runningExecutorsToPods.remove(executorId).map { pod =>
+          kubernetesClient.pods().delete(pod)
+          runningPodsToExecutors.remove(pod.getMetadata.getName)
+        }.getOrElse(logWarning(s"Unable to remove pod for unknown executor $executorId"))
       }
     }
   }
@@ -214,18 +212,18 @@ private[spark] class KubernetesClusterSchedulerBackend(
         .withLabel(SPARK_APP_ID_LABEL, applicationId())
         .watch(new ExecutorPodsWatcher()))
 
-    allocator.scheduleWithFixedDelay(
-      allocatorRunnable, 0, podAllocationInterval, TimeUnit.SECONDS)
+    allocatorExecutor.scheduleWithFixedDelay(
+      allocatorRunnable, 0L, podAllocationInterval, TimeUnit.SECONDS)
     shuffleManager.foreach(_.start(applicationId()))
 
-    if (!Utils.isDynamicAllocationEnabled(sc.conf)) {
+    if (!Utils.isDynamicAllocationEnabled(conf)) {
       doRequestTotalExecutors(initialExecutors)
     }
   }
 
   override def stop(): Unit = {
     // stop allocation of new resources and caches.
-    allocator.shutdown()
+    allocatorExecutor.shutdown()
     shuffleManager.foreach(_.stop())
 
     // send stop message to executors so they shut down cleanly
@@ -298,7 +296,7 @@ private[spark] class KubernetesClusterSchedulerBackend(
       executorId,
       applicationId(),
       driverUrl,
-      sc.conf.getExecutorEnv,
+      conf.getExecutorEnv,
       driverPod,
       nodeToLocalTaskCount)
     try {
@@ -318,11 +316,14 @@ private[spark] class KubernetesClusterSchedulerBackend(
   override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future[Boolean] {
     RUNNING_EXECUTOR_PODS_LOCK.synchronized {
       for (executor <- executorIds) {
-        runningExecutorsToPods.remove(executor) match {
-          case Some(pod) =>
-            kubernetesClient.pods().delete(pod)
-            runningPodsToExecutors.remove(pod.getMetadata.getName)
-          case None => logWarning(s"Unable to remove pod for unknown executor $executor")
+        val maybeRemovedExecutor = runningExecutorsToPods.remove(executor)
+        maybeRemovedExecutor.foreach { executorPod =>
+          kubernetesClient.pods().delete(executorPod)
+          disconnectedPodsByExecutorIdPendingRemoval(executor) = executorPod
+          runningPodsToExecutors.remove(executorPod.getMetadata.getName)
+        }
+        if (maybeRemovedExecutor.isEmpty) {
+          logWarning(s"Unable to remove pod for unknown executor $executor")
         }
       }
     }
@@ -396,10 +397,9 @@ private[spark] class KubernetesClusterSchedulerBackend(
     }
 
     def handleErroredPod(pod: Pod): Unit = {
-      val alreadyReleased = isPodAlreadyReleased(pod)
       val containerExitStatus = getExecutorExitStatus(pod)
       // container was probably actively killed by the driver.
-      val exitReason = if (alreadyReleased) {
+      val exitReason = if (isPodAlreadyReleased(pod)) {
         ExecutorExited(containerExitStatus, exitCausedByApp = false,
           s"Container in pod " + pod.getMetadata.getName +
             " exited from explicit termination request.")
@@ -411,17 +411,23 @@ private[spark] class KubernetesClusterSchedulerBackend(
           // Here we can't be sure that that exit was caused by the application but this seems
          // to be the right default since we know the pod was not explicitly deleted by
           // the user.
-          "Pod exited with following container exit status code " + containerExitStatus
+          s"Pod ${pod.getMetadata.getName}'s executor container exited with exit status" +
+            s" code $containerExitStatus."
         }
         ExecutorExited(containerExitStatus, exitCausedByApp = true, containerExitReason)
       }
-      failedPods.put(pod.getMetadata.getName, exitReason)
+      podsWithKnownExitReasons.put(pod.getMetadata.getName, exitReason)
     }
 
     def handleDeletedPod(pod: Pod): Unit = {
-      val exitReason = ExecutorExited(getExecutorExitStatus(pod), exitCausedByApp = false,
-        "Pod " + pod.getMetadata.getName + " deleted or lost.")
-      failedPods.put(pod.getMetadata.getName, exitReason)
+      val exitMessage = if (isPodAlreadyReleased(pod)) {
+        s"Container in pod ${pod.getMetadata.getName} exited from explicit termination request."
+      } else {
+        s"Pod ${pod.getMetadata.getName} deleted or lost."
+      }
+      val exitReason = ExecutorExited(
+        getExecutorExitStatus(pod), exitCausedByApp = false, exitMessage)
+      podsWithKnownExitReasons.put(pod.getMetadata.getName, exitReason)
     }
   }
 
@@ -433,12 +439,15 @@ private[spark] class KubernetesClusterSchedulerBackend(
       rpcEnv: RpcEnv,
       sparkProperties: Seq[(String, String)])
     extends DriverEndpoint(rpcEnv, sparkProperties) {
-    private val externalShufflePort = conf.getInt("spark.shuffle.service.port", 7337)
 
     override def onDisconnected(rpcAddress: RpcAddress): Unit = {
       addressToExecutorId.get(rpcAddress).foreach { executorId =>
         if (disableExecutor(executorId)) {
-          executorsToRemove.add(executorId)
+          RUNNING_EXECUTOR_PODS_LOCK.synchronized {
+            runningExecutorsToPods.get(executorId).foreach { pod =>
+              disconnectedPodsByExecutorIdPendingRemoval(executorId) = pod
+            }
+          }
         }
       }
     }
@@ -448,7 +457,7 @@ private[spark] class KubernetesClusterSchedulerBackend(
       new PartialFunction[Any, Unit]() {
         override def isDefinedAt(msg: Any): Boolean = {
           msg match {
-            case RetrieveSparkAppConfig(executorId) =>
+            case RetrieveSparkAppConfig(_) =>
               shuffleManager.isDefined
             case _ => false
           }
@@ -477,11 +486,12 @@ private[spark] class KubernetesClusterSchedulerBackend(
 }
 
 private object KubernetesClusterSchedulerBackend {
-  private val DEFAULT_STATIC_PORT = 10000
-  private val EXECUTOR_ID_COUNTER = new AtomicLong(0L)
   private val VMEM_EXCEEDED_EXIT_CODE = -103
   private val PMEM_EXCEEDED_EXIT_CODE = -104
   private val UNKNOWN_EXIT_CODE = -111
+  // Number of times we are allowed to check for the loss reason for an executor before we give up
+  // and assume the executor failed for good, and attribute it to a framework fault.
+  val MAX_EXECUTOR_LOST_REASON_CHECKS = 10
 
   def memLimitExceededLogMessage(diagnostics: String): String = {
     s"Pod/Container killed for exceeding memory limits. $diagnostics" +
diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackendSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackendSuite.scala
new file mode 100644
index 0000000000000..b30d1c2543bea
--- /dev/null
+++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackendSuite.scala
@@ -0,0 +1,383 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.kubernetes
+
+import java.util.concurrent.{ExecutorService, ScheduledExecutorService, TimeUnit}
+
+import io.fabric8.kubernetes.api.model.{DoneablePod, Pod, PodBuilder, PodList}
+import io.fabric8.kubernetes.client.{KubernetesClient, Watch, Watcher}
+import io.fabric8.kubernetes.client.Watcher.Action
+import io.fabric8.kubernetes.client.dsl.{FilterWatchListDeletable, MixedOperation, NonNamespaceOperation, PodResource}
+import org.mockito.{AdditionalAnswers, ArgumentCaptor, Mock, MockitoAnnotations}
+import org.mockito.Matchers.{any, eq => mockitoEq}
+import org.mockito.Mockito.{doNothing, never, times, verify, when}
+import org.scalatest.BeforeAndAfter
+import org.scalatest.mock.MockitoSugar._
+import scala.collection.JavaConverters._
+import scala.concurrent.Future
+
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.deploy.kubernetes.config._
+import org.apache.spark.deploy.kubernetes.constants._
+import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEndpoint, RpcEndpointAddress, RpcEndpointRef, RpcEnv, RpcTimeout}
+import org.apache.spark.scheduler.{ExecutorExited, LiveListenerBus, SlaveLost, TaskSchedulerImpl}
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages.{RegisterExecutor, RemoveExecutor}
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+
+private[spark] class KubernetesClusterSchedulerBackendSuite
+    extends SparkFunSuite with BeforeAndAfter {
+
+  private val APP_ID = "test-spark-app"
+  private val DRIVER_POD_NAME = "spark-driver-pod"
+  private val NAMESPACE = "test-namespace"
+  private val SPARK_DRIVER_HOST = "localhost"
+  private val SPARK_DRIVER_PORT = 7077
+  private val POD_ALLOCATION_INTERVAL = 60L
+  private val DRIVER_URL = RpcEndpointAddress(
+    SPARK_DRIVER_HOST, SPARK_DRIVER_PORT, CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString
+  private val FIRST_EXECUTOR_POD = new PodBuilder()
+    .withNewMetadata()
+      .withName("pod1")
+      .endMetadata()
+    .withNewSpec()
+      .withNodeName("node1")
+      .endSpec()
+    .withNewStatus()
+      .withHostIP("192.168.99.100")
+      .endStatus()
+    .build()
+  private val SECOND_EXECUTOR_POD = new PodBuilder()
+    .withNewMetadata()
+      .withName("pod2")
+      .endMetadata()
+    .withNewSpec()
+      .withNodeName("node2")
+      .endSpec()
+    .withNewStatus()
+      .withHostIP("192.168.99.101")
+      .endStatus()
+    .build()
+
+  private type PODS = MixedOperation[Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]
+  private type LABELLED_PODS = FilterWatchListDeletable[
+    Pod, PodList, java.lang.Boolean, Watch, Watcher[Pod]]
+  private type IN_NAMESPACE_PODS = NonNamespaceOperation[
+    Pod, PodList, DoneablePod, PodResource[Pod, DoneablePod]]
+
+  @Mock
+  private var sparkContext: SparkContext = _
+
+  @Mock
+  private var listenerBus: LiveListenerBus = _
+
+  @Mock
+  private var taskSchedulerImpl: TaskSchedulerImpl = _
+
+  @Mock
+  private var allocatorExecutor: ScheduledExecutorService = _
+
+  @Mock
+  private var requestExecutorsService: ExecutorService = _
+
+  @Mock
+  private var executorPodFactory: ExecutorPodFactory = _
+
+  @Mock
+  private var shuffleManager: KubernetesExternalShuffleManager = _
+
+  @Mock
+  private var kubernetesClient: KubernetesClient = _
+
+  @Mock
+  private var podOperations: PODS = _
+
+  @Mock
+  private var podsWithLabelOperations: LABELLED_PODS = _
+
+  @Mock
+  private var podsInNamespace: IN_NAMESPACE_PODS = _
+
+  @Mock
+  private var podsWithDriverName: PodResource[Pod, DoneablePod] = _
+
+  @Mock
+  private var rpcEnv: RpcEnv = _
+
+  @Mock
+  private var driverEndpointRef: RpcEndpointRef = _
+
+  @Mock
+  private var executorPodsWatch: Watch = _
+
+  private var sparkConf: SparkConf = _
+  private var executorPodsWatcherArgument: ArgumentCaptor[Watcher[Pod]] = _
+  private var allocatorRunnable: ArgumentCaptor[Runnable] = _
+  private var requestExecutorRunnable: ArgumentCaptor[Runnable] = _
+  private var driverEndpoint: ArgumentCaptor[RpcEndpoint] = _
+
+  private val driverPod = new PodBuilder()
+    .withNewMetadata()
+      .withName(DRIVER_POD_NAME)
+      .addToLabels(SPARK_APP_ID_LABEL, APP_ID)
+      .addToLabels(SPARK_ROLE_LABEL, SPARK_POD_DRIVER_ROLE)
+      .endMetadata()
+    .build()
+
+  before {
+    MockitoAnnotations.initMocks(this)
+    sparkConf = new SparkConf()
+      .set("spark.app.id", APP_ID)
+      .set(KUBERNETES_DRIVER_POD_NAME, DRIVER_POD_NAME)
+      .set(KUBERNETES_NAMESPACE, NAMESPACE)
+      .set("spark.driver.host", SPARK_DRIVER_HOST)
+      .set("spark.driver.port", SPARK_DRIVER_PORT.toString)
+      .set(KUBERNETES_ALLOCATION_BATCH_DELAY, POD_ALLOCATION_INTERVAL)
+    executorPodsWatcherArgument = ArgumentCaptor.forClass(classOf[Watcher[Pod]])
+    allocatorRunnable = ArgumentCaptor.forClass(classOf[Runnable])
+    requestExecutorRunnable = ArgumentCaptor.forClass(classOf[Runnable])
+    driverEndpoint = ArgumentCaptor.forClass(classOf[RpcEndpoint])
+    when(sparkContext.conf).thenReturn(sparkConf)
+    when(sparkContext.listenerBus).thenReturn(listenerBus)
+    when(taskSchedulerImpl.sc).thenReturn(sparkContext)
+    when(kubernetesClient.pods()).thenReturn(podOperations)
+    when(podOperations.withLabel(SPARK_APP_ID_LABEL, APP_ID)).thenReturn(podsWithLabelOperations)
+    when(podsWithLabelOperations.watch(executorPodsWatcherArgument.capture()))
+      .thenReturn(executorPodsWatch)
+    when(podOperations.inNamespace(NAMESPACE)).thenReturn(podsInNamespace)
+    when(podsInNamespace.withName(DRIVER_POD_NAME)).thenReturn(podsWithDriverName)
+    when(podsWithDriverName.get()).thenReturn(driverPod)
+    when(allocatorExecutor.scheduleWithFixedDelay(
+      allocatorRunnable.capture(),
+      mockitoEq(0L),
+      mockitoEq(POD_ALLOCATION_INTERVAL),
+      mockitoEq(TimeUnit.SECONDS))).thenReturn(null)
+    // Creating Futures in Scala backed by a Java executor service resolves to running
+    // ExecutorService#execute (as opposed to submit)
+    doNothing().when(requestExecutorsService).execute(requestExecutorRunnable.capture())
+    when(rpcEnv.setupEndpoint(
+      mockitoEq(CoarseGrainedSchedulerBackend.ENDPOINT_NAME), driverEndpoint.capture()))
+      .thenReturn(driverEndpointRef)
+    when(driverEndpointRef.ask[Boolean]
+      (any(classOf[Any]))
+      (any())).thenReturn(mock[Future[Boolean]])
+  }
+
+  test("Basic lifecycle expectations when starting and stopping the scheduler.") {
+    val scheduler = newSchedulerBackend(true)
+    scheduler.start()
+    verify(shuffleManager).start(APP_ID)
+    assert(executorPodsWatcherArgument.getValue != null)
+    assert(allocatorRunnable.getValue != null)
+    scheduler.stop()
+    verify(shuffleManager).stop()
+    verify(executorPodsWatch).close()
+  }
+
+  test("Static allocation should request executors upon first allocator run.") {
+    sparkConf
+      .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2)
+      .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2)
+    val scheduler = newSchedulerBackend(true)
+    scheduler.start()
+    requestExecutorRunnable.getValue.run()
+    expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
+    expectPodCreationWithId(2, SECOND_EXECUTOR_POD)
+    when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg())
+    allocatorRunnable.getValue.run()
+    verify(podOperations).create(FIRST_EXECUTOR_POD)
+    verify(podOperations).create(SECOND_EXECUTOR_POD)
+  }
+
+  test("Killing executors deletes the executor pods") {
+    sparkConf
+      .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 2)
+      .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2)
+    val scheduler = newSchedulerBackend(true)
+    scheduler.start()
+    requestExecutorRunnable.getValue.run()
+    expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
+    expectPodCreationWithId(2, SECOND_EXECUTOR_POD)
+    when(podOperations.create(any(classOf[Pod])))
+      .thenAnswer(AdditionalAnswers.returnsFirstArg())
+    allocatorRunnable.getValue.run()
+    scheduler.doKillExecutors(Seq("2"))
+    requestExecutorRunnable.getAllValues.asScala.last.run()
+    verify(podOperations).delete(SECOND_EXECUTOR_POD)
+    verify(podOperations, never()).delete(FIRST_EXECUTOR_POD)
+  }
+
+  test("Executors should be requested in batches.") {
+    sparkConf
+      .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1)
+      .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 2)
+    val scheduler = newSchedulerBackend(true)
+    scheduler.start()
+    requestExecutorRunnable.getValue.run()
+    when(podOperations.create(any(classOf[Pod])))
+      .thenAnswer(AdditionalAnswers.returnsFirstArg())
+    expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
+    expectPodCreationWithId(2, SECOND_EXECUTOR_POD)
+    allocatorRunnable.getValue.run()
+    verify(podOperations).create(FIRST_EXECUTOR_POD)
+    verify(podOperations, never()).create(SECOND_EXECUTOR_POD)
+    val registerFirstExecutorMessage = RegisterExecutor(
+      "1", mock[RpcEndpointRef], "localhost", 1, Map.empty[String, String])
+    when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty)
+    driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext])
+      .apply(registerFirstExecutorMessage)
+    allocatorRunnable.getValue.run()
+    verify(podOperations).create(SECOND_EXECUTOR_POD)
+  }
+
+  test("Deleting executors and then running an allocator pass after finding the loss reason" +
+    " should only delete the pod once.") {
+    sparkConf
+      .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1)
+      .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1)
+    val scheduler = newSchedulerBackend(true)
+    scheduler.start()
+    requestExecutorRunnable.getValue.run()
+    when(podOperations.create(any(classOf[Pod])))
+      .thenAnswer(AdditionalAnswers.returnsFirstArg())
+    expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
+    allocatorRunnable.getValue.run()
+    val executorEndpointRef = mock[RpcEndpointRef]
+    when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000))
+    val registerFirstExecutorMessage = RegisterExecutor(
+      "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String])
+    when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty)
+    driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext])
+      .apply(registerFirstExecutorMessage)
+    scheduler.doRequestTotalExecutors(0)
+    requestExecutorRunnable.getAllValues.asScala.last.run()
+    scheduler.doKillExecutors(Seq("1"))
+    requestExecutorRunnable.getAllValues.asScala.last.run()
+    verify(podOperations, times(1)).delete(FIRST_EXECUTOR_POD)
+    driverEndpoint.getValue.onDisconnected(executorEndpointRef.address)
+
+    val exitedPod = exitPod(FIRST_EXECUTOR_POD, 0)
+    executorPodsWatcherArgument.getValue.eventReceived(Action.DELETED, exitedPod)
+    allocatorRunnable.getValue.run()
+    verify(podOperations, times(1)).delete(FIRST_EXECUTOR_POD)
+    verify(driverEndpointRef, times(1)).ask[Boolean](
+      RemoveExecutor("1", ExecutorExited(
+        0,
+        exitCausedByApp = false,
+        s"Container in pod ${exitedPod.getMetadata.getName} exited from" +
+          s" explicit termination request.")))
+  }
+
+  test("Executors that disconnect from application errors are noted as exits caused by app.") {
+    sparkConf
+      .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1)
+      .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1)
+    val scheduler = newSchedulerBackend(true)
+    scheduler.start()
+    expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
+    when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg())
+    requestExecutorRunnable.getValue.run()
+    allocatorRunnable.getValue.run()
+    val executorEndpointRef = mock[RpcEndpointRef]
+    when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000))
+    val registerFirstExecutorMessage = RegisterExecutor(
+      "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String])
+    when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty)
+    driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext])
+      .apply(registerFirstExecutorMessage)
+    driverEndpoint.getValue.onDisconnected(executorEndpointRef.address)
+    executorPodsWatcherArgument.getValue.eventReceived(
+      Action.ERROR, exitPod(FIRST_EXECUTOR_POD, 1))
+
+    expectPodCreationWithId(2, SECOND_EXECUTOR_POD)
+    scheduler.doRequestTotalExecutors(1)
+    requestExecutorRunnable.getValue.run()
+    allocatorRunnable.getAllValues.asScala.last.run()
+    verify(driverEndpointRef).ask[Boolean](
+      RemoveExecutor("1", ExecutorExited(
+        1,
+        exitCausedByApp = true,
+        s"Pod ${FIRST_EXECUTOR_POD.getMetadata.getName}'s executor container exited with" +
+          " exit status code 1.")))
+    verify(podOperations, never()).delete(FIRST_EXECUTOR_POD)
+  }
+
+  test("Executors should only try to get the loss reason a number of times before giving up and" +
+    " removing the executor.") {
+    sparkConf
+      .set(KUBERNETES_ALLOCATION_BATCH_SIZE, 1)
+      .set(org.apache.spark.internal.config.EXECUTOR_INSTANCES, 1)
+    val scheduler = newSchedulerBackend(true)
+    scheduler.start()
+    expectPodCreationWithId(1, FIRST_EXECUTOR_POD)
+    when(podOperations.create(any(classOf[Pod]))).thenAnswer(AdditionalAnswers.returnsFirstArg())
+    requestExecutorRunnable.getValue.run()
+    allocatorRunnable.getValue.run()
+    val executorEndpointRef = mock[RpcEndpointRef]
+    when(executorEndpointRef.address).thenReturn(RpcAddress("pod.example.com", 9000))
+    val registerFirstExecutorMessage = RegisterExecutor(
+      "1", executorEndpointRef, "localhost:9000", 1, Map.empty[String, String])
+    when(taskSchedulerImpl.resourceOffers(any())).thenReturn(Seq.empty)
+    driverEndpoint.getValue.receiveAndReply(mock[RpcCallContext])
+      .apply(registerFirstExecutorMessage)
+    driverEndpoint.getValue.onDisconnected(executorEndpointRef.address)
+    1 to KubernetesClusterSchedulerBackend.MAX_EXECUTOR_LOST_REASON_CHECKS foreach { _ =>
+      allocatorRunnable.getValue.run()
+      verify(podOperations, never()).delete(FIRST_EXECUTOR_POD)
+    }
+    expectPodCreationWithId(2, SECOND_EXECUTOR_POD)
+    allocatorRunnable.getValue.run()
+    verify(podOperations).delete(FIRST_EXECUTOR_POD)
+    verify(driverEndpointRef).ask[Boolean](
+      RemoveExecutor("1", SlaveLost("Executor lost for unknown reasons.")))
+  }
+
+  private def newSchedulerBackend(externalShuffle: Boolean): KubernetesClusterSchedulerBackend = {
+    new KubernetesClusterSchedulerBackend(
+      taskSchedulerImpl,
+      rpcEnv,
+      executorPodFactory,
+      if (externalShuffle) Some(shuffleManager) else None,
+      kubernetesClient,
+      allocatorExecutor,
+      requestExecutorsService)
+  }
+
+  private def exitPod(basePod: Pod, exitCode: Int): Pod = {
+    new PodBuilder(basePod)
+      .editStatus()
+        .addNewContainerStatus()
+          .withNewState()
+            .withNewTerminated()
+              .withExitCode(exitCode)
+              .endTerminated()
+            .endState()
+          .endContainerStatus()
+        .endStatus()
+      .build()
+  }
+
+  private def expectPodCreationWithId(executorId: Int, expectedPod: Pod): Unit = {
+    when(executorPodFactory.createExecutorPod(
+      executorId.toString,
+      APP_ID,
+      DRIVER_URL,
+      sparkConf.getExecutorEnv,
+      driverPod,
+      Map.empty)).thenReturn(expectedPod)
+  }
+
+}