diff --git a/resource-managers/kubernetes/src/main/scala/org/apache/spark/deploy/kubernetes/Client.scala b/resource-managers/kubernetes/src/main/scala/org/apache/spark/deploy/kubernetes/Client.scala
index df16e742dd7ab..7b43b4faf6201 100644
--- a/resource-managers/kubernetes/src/main/scala/org/apache/spark/deploy/kubernetes/Client.scala
+++ b/resource-managers/kubernetes/src/main/scala/org/apache/spark/deploy/kubernetes/Client.scala
@@ -45,7 +45,6 @@ private[spark] class Client(val args: ClientArguments,
     scheduler.stop()
     shutdownLatch.countDown()
     System.clearProperty("SPARK_KUBERNETES_MODE")
-    System.clearProperty("SPARK_IMAGE_PULLSECRET")
   }
 
   def awaitShutdown(): Unit = {
diff --git a/resource-managers/kubernetes/src/main/scala/org/apache/spark/deploy/kubernetes/SparkJobResource.scala b/resource-managers/kubernetes/src/main/scala/org/apache/spark/deploy/kubernetes/SparkJobResource.scala
index 52e9ff3e05661..0f3f6ec956716 100644
--- a/resource-managers/kubernetes/src/main/scala/org/apache/spark/deploy/kubernetes/SparkJobResource.scala
+++ b/resource-managers/kubernetes/src/main/scala/org/apache/spark/deploy/kubernetes/SparkJobResource.scala
@@ -17,23 +17,49 @@
 
 package org.apache.spark.deploy.kubernetes
 
+import java.nio.file.{Files, Paths}
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.{blocking, ExecutionContext, Future, Promise}
+import scala.util.{Failure, Success, Try}
+import scala.util.control.Breaks.{break, breakable}
+
 import io.fabric8.kubernetes.client.{BaseClient, KubernetesClient}
-import okhttp3.{MediaType, OkHttpClient, Request, RequestBody}
-import org.json4s.{CustomSerializer, DefaultFormats, JString}
-import org.json4s.JsonAST.JNull
+import okhttp3._
+import okio.{Buffer, BufferedSource}
+import org.json4s.{DefaultFormats, Formats}
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 import org.json4s.jackson.Serialization.{read, write}
 
+import org.apache.spark.deploy.kubernetes.SparkJobResource._
 import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
-import org.apache.spark.scheduler.cluster.kubernetes.JobState
-import org.apache.spark.scheduler.cluster.kubernetes.JobState._
+import org.apache.spark.scheduler.cluster.kubernetes.JobResourceCreateCall
+import org.apache.spark.scheduler.cluster.kubernetes.JobResourceRUDCalls
+import org.apache.spark.scheduler.cluster.kubernetes.JobStateSerDe
 
 /*
- * Representation of a Spark Job State in Kubernetes
+ * CRUD + Watch operations on a Spark Job State in Kubernetes
 * */
-object SparkJobResource {
+private[spark] object SparkJobResource {
+
+  implicit val formats: Formats = DefaultFormats + JobStateSerDe
+
+  val kind = "SparkJob"
+  val apiVersion = "apache.io/v1"
+  val apiEndpoint = s"apis/$apiVersion/namespaces/%s/sparkjobs"
+
+  def getHttpClient(client: BaseClient): OkHttpClient = {
+    val field = classOf[BaseClient].getDeclaredField("httpClient")
+    try {
+      field.setAccessible(true)
+      field.get(client).asInstanceOf[OkHttpClient]
+    } finally {
+      field.setAccessible(false)
+    }
+  }
+
   case class Metadata(name: String,
       uid: Option[String] = None,
       labels: Option[Map[String, String]] = None,
@@ -44,120 +70,241 @@ object SparkJobResource {
       metadata: Metadata,
       spec: Map[String, Any])
 
-  case object JobStateSerDe
-    extends CustomSerializer[JobState](format =>
-      ({
-        case JString("SUBMITTED") => JobState.SUBMITTED
-        case JString("QUEUED") => JobState.QUEUED
-        case JString("RUNNING") => JobState.RUNNING
-        case JString("FINISHED") => JobState.FINISHED
-        case JString("KILLED") => JobState.KILLED
-        case JString("FAILED") => JobState.FAILED
-        case JNull =>
-          throw new UnsupportedOperationException("No JobState Specified")
-      }, {
-        case JobState.FAILED => JString("FAILED")
-        case JobState.SUBMITTED => JString("SUBMITTED")
-        case JobState.KILLED => JString("KILLED")
-        case JobState.FINISHED => JString("FINISHED")
-        case JobState.QUEUED => JString("QUEUED")
-        case JobState.RUNNING => JString("RUNNING")
-      }))
+  case class WatchObject(`type`: String, `object`: SparkJobState)
 }
 
-class SparkJobResource(client: KubernetesClient) extends Logging {
-  import SparkJobResource._
+private[spark] class SparkJobCreateResource(client: KubernetesClient, namespace: String)
+  extends JobResourceCreateCall with Logging {
 
-  implicit val formats = DefaultFormats + JobStateSerDe
   private val httpClient = getHttpClient(client.asInstanceOf[BaseClient])
-  private val kind = "SparkJob"
-  private val apiVersion = "apache.io/v1"
-  private val apiEndpoint = s"${client.getMasterUrl}apis/$apiVersion/" +
-    s"namespaces/${client.getNamespace}/sparkjobs"
-
-  private def getHttpClient(client: BaseClient): OkHttpClient = {
-    val field = classOf[BaseClient].getDeclaredField("httpClient")
-    try {
-      field.setAccessible(true)
-      field.get(client).asInstanceOf[OkHttpClient]
-    } finally {
-      field.setAccessible(false)
-    }
-  }
 
-  /*
-   * using a Map as an argument here allows adding more info into the Object if needed
-   * */
-  def createJobObject(name: String, keyValuePairs: Map[String, Any]): Unit = {
+  /**
+   * Using a Map as an argument here allows adding more info into the Object if needed.
+   * This is currently called on the client machine. We can avoid the token use.
+   * */
+  override def createJobObject(name: String, keyValuePairs: Map[String, Any]): Unit = {
     val resourceObject = SparkJobState(apiVersion, kind, Metadata(name), keyValuePairs)
     val payload = parse(write(resourceObject))
     val requestBody = RequestBody
       .create(MediaType.parse("application/json"), compact(render(payload)))
-    val request =
-      new Request.Builder().post(requestBody).url(apiEndpoint).build()
-
+    val request = new Request.Builder()
+      .post(requestBody)
+      .url(s"${client.getMasterUrl}${apiEndpoint.format(namespace)}")
+      .build()
+    logDebug(s"Create Request: $request")
     val response = httpClient.newCall(request).execute()
-    if (response.code() == 201) {
-      logInfo(
-        s"Successfully posted resource $name: " +
-          s"${pretty(render(parse(write(resourceObject))))}")
-    } else {
+    if (!response.isSuccessful) {
+      response.body().close()
       val msg =
        s"Failed to post resource $name. ${response.toString}. ${compact(render(payload))}"
       logError(msg)
       throw new SparkException(msg)
     }
+    response.body().close()
+    logDebug(s"Successfully posted resource $name: " +
+      s"${pretty(render(parse(write(resourceObject))))}")
   }
+}
+
+private[spark] class SparkJobRUDResource(
+    client: KubernetesClient,
+    namespace: String,
+    ec: ExecutionContext) extends JobResourceRUDCalls with Logging {
+
+  private val protocol = "https://"
+
+  private val httpClient = getHttpClient(client.asInstanceOf[BaseClient])
+
+  private var watchSource: BufferedSource = _
+
+  private lazy val buffer = new Buffer()
 
-  def updateJobObject(name: String, value: String, fieldPath: String): Unit = {
+  // Since this will be running inside a pod,
+  // we can access the pod's token and use it with the Authorization header when
+  // making REST calls to the k8s API.
+  private val kubeToken = {
+    val path = Paths.get("/var/run/secrets/kubernetes.io/serviceaccount/token")
+    val tok = Try(new String(Files.readAllBytes(path))) match {
+      case Success(some) => Option(some)
+      case Failure(e: Throwable) => logError(s"${e.getMessage}")
+        None
+    }
+    tok.map(t => t).getOrElse{
+      // Log a warning just in case, but this should almost certainly never happen
+      logWarning("Error while retrieving pod token")
+      ""
+    }
+  }
+
+  // We can also get the host from the environment variable.
+  private val k8sServiceHost = {
+    val host = Try(sys.env("KUBERNETES_SERVICE_HOST")) match {
+      case Success(h) => Option(h)
+      case Failure(_) => None
+    }
+    host.map(h => h).getOrElse{
+      // Log a warning just in case, but this should almost certainly never happen
+      logWarning("Error while retrieving k8s host address")
+      "127.0.0.1"
+    }
+  }
+
+  // The port also comes from an environment variable.
+  private val k8sPort = {
+    val port = Try(sys.env("KUBERNETES_PORT_443_TCP_PORT")) match {
+      case Success(p) => Option(p)
+      case Failure(_) => None
+    }
+    port.map(p => p).getOrElse{
+      // Log a warning just in case, but this should almost certainly never happen
+      logWarning("Error while retrieving k8s host port")
+      "8001"
+    }
+  }
+
+  private def executeBlocking(cb: => WatchObject): Future[WatchObject] = {
+    val p = Promise[WatchObject]()
+    ec.execute(new Runnable {
+      override def run(): Unit = {
+        try {
+          p.trySuccess(blocking(cb))
+        } catch {
+          case e: Throwable => p.tryFailure(e)
+        }
+      }
+    })
+    p.future
+  }
+
+  // Serves as a way to interrupt the watcher thread.
+  // This closes the source the watcher is reading from and as a result triggers promise completion.
+  def stopWatcher(): Unit = {
+    if (watchSource != null) {
+      buffer.close()
+      watchSource.close()
+    }
+  }
+
+  override def updateJobObject(name: String, value: String, fieldPath: String): Unit = {
     val payload = List(
       ("op" -> "replace") ~ ("path" -> fieldPath) ~ ("value" -> value))
     val requestBody = RequestBody.create(
       MediaType.parse("application/json-patch+json"), compact(render(payload)))
     val request = new Request.Builder()
-      .post(requestBody)
-      .url(s"$apiEndpoint/$name")
+      .addHeader("Authorization", s"Bearer $kubeToken")
+      .patch(requestBody)
+      .url(s"$protocol$k8sServiceHost:$k8sPort/${apiEndpoint.format(namespace)}/$name")
       .build()
+    logDebug(s"Update Request: $request")
     val response = httpClient.newCall(request).execute()
-    if (response.code() == 200) {
-      logInfo(s"Successfully patched resource $name")
-    } else {
+    if (!response.isSuccessful) {
+      response.body().close()
      val msg =
        s"Failed to patch resource $name. ${response.message()}. ${compact(render(payload))}"
      logError(msg)
-      throw new SparkException(msg)
+      throw new SparkException(s"${response.code()} ${response.message()}")
     }
+    response.body().close()
+    logDebug(s"Successfully patched resource $name.")
   }
 
-  def deleteJobObject(name: String): Unit = {
-    val request =
-      new Request.Builder().delete().url(s"$apiEndpoint/$name").build()
+  override def deleteJobObject(name: String): Unit = {
+    val request = new Request.Builder()
+      .addHeader("Authorization", s"Bearer $kubeToken")
+      .delete()
+      .url(s"$protocol$k8sServiceHost:$k8sPort/${apiEndpoint.format(namespace)}/$name")
+      .build()
+    logDebug(s"Delete Request: $request")
     val response = httpClient.newCall(request).execute()
-    if (response.code() == 200) {
-      logInfo(s"Successfully deleted resource $name")
-    } else {
+    if (!response.isSuccessful) {
+      response.body().close()
       val msg =
-        s"Failed to delete resource $name. ${response.message()}. ${request}"
+        s"Failed to delete resource $name. ${response.message()}. $request"
       logError(msg)
       throw new SparkException(msg)
     }
+    response.body().close()
+    logInfo(s"Successfully deleted resource $name")
   }
 
   def getJobObject(name: String): SparkJobState = {
-    val request =
-      new Request.Builder().get().url(s"$apiEndpoint/$name").build()
+    val request = new Request.Builder()
+      .addHeader("Authorization", s"Bearer $kubeToken")
+      .get()
+      .url(s"$protocol$k8sServiceHost:$k8sPort/${apiEndpoint.format(namespace)}/$name")
+      .build()
+    logDebug(s"Get Request: $request")
     val response = httpClient.newCall(request).execute()
-    if (response.code() == 200) {
-      logInfo(s"Successfully retrieved resource $name")
-      read[SparkJobState](response.body().string())
-    } else {
+    if (!response.isSuccessful) {
+      response.body().close()
       val msg = s"Failed to retrieve resource $name. ${response.message()}"
       logError(msg)
       throw new SparkException(msg)
     }
+    logInfo(s"Successfully retrieved resource $name")
+    read[SparkJobState](response.body().string())
+  }
+
+  /**
+   * This method has a helper method that blocks to watch the object.
+   * The future is completed on a Delete event or source exhaustion.
+   * This method relies on the assumption of one sparkjob per namespace.
+   */
+  override def watchJobObject(): Future[WatchObject] = {
+    val watchClient = httpClient.newBuilder().readTimeout(0, TimeUnit.MILLISECONDS).build()
+    val request = new Request.Builder()
+      .addHeader("Authorization", s"Bearer $kubeToken")
+      .get()
+      .url(s"$protocol$k8sServiceHost:$k8sPort/${apiEndpoint.format(namespace)}?watch=true")
+      .build()
+    logDebug(s"Watch Request: $request")
+    val resp = watchClient.newCall(request).execute()
+    if (!resp.isSuccessful) {
+      resp.body().close()
+      val msg = s"Failed to start watch on resource ${resp.code()} ${resp.message()}"
+      logWarning(msg)
+      throw new SparkException(msg)
+    }
+    logInfo(s"Starting watch on jobResource")
+    watchJobObjectUtil(resp)
   }
 
+  /**
+   * This method has a blocking call - wait on SSE - inside it.
+   * However, it is sent off in a new thread.
+   */
+  private def watchJobObjectUtil(response: Response): Future[WatchObject] = {
+    @volatile var wo: WatchObject = null
+    watchSource = response.body().source()
+    executeBlocking {
+      breakable {
+        // This will block until there are bytes to read or the source is exhausted.
+        while (!watchSource.exhausted()) {
+          watchSource.read(buffer, 8192) match {
+            case -1 =>
+              cleanUpListener(watchSource, buffer)
+              throw new SparkException("Source is exhausted and object state is unknown")
+            case _ =>
+              wo = read[WatchObject](buffer.readUtf8())
+              wo match {
+                case WatchObject("DELETED", w) =>
+                  logInfo(s"${w.metadata.name} has been deleted")
+                  cleanUpListener(watchSource, buffer)
+                case WatchObject(e, _) => logInfo(s"$e event. Still watching")
+              }
+          }
+        }
+      }
+      wo
+    }
+  }
+
+  private def cleanUpListener(source: BufferedSource, buffer: Buffer): Unit = {
+    buffer.close()
+    source.close()
+    break()
+  }
 }
diff --git a/resource-managers/kubernetes/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/JobResourceCreateCall.scala b/resource-managers/kubernetes/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/JobResourceCreateCall.scala
new file mode 100644
index 0000000000000..42127486eb569
--- /dev/null
+++ b/resource-managers/kubernetes/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/JobResourceCreateCall.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster.kubernetes
+
+import scala.concurrent.Future
+
+import org.apache.spark.deploy.kubernetes.SparkJobResource.{SparkJobState, WatchObject}
+
+/**
+ * Isolated this since the method is called on the client machine.
+ */
+trait JobResourceCreateCall {
+  def createJobObject(name: String, keyValuePairs: Map[String, Any]): Unit
+}
+
+/**
+ * RUD and W - Read, Update, Delete & Watch
+ */
+trait JobResourceRUDCalls {
+  def deleteJobObject(name: String): Unit
+
+  def getJobObject(name: String): SparkJobState
+
+  def updateJobObject(name: String, value: String, fieldPath: String): Unit
+
+  def watchJobObject(): Future[WatchObject]
+}
diff --git a/resource-managers/kubernetes/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/JobStateSerDe.scala b/resource-managers/kubernetes/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/JobStateSerDe.scala
new file mode 100644
index 0000000000000..9fb9b3742ab0d
--- /dev/null
+++ b/resource-managers/kubernetes/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/JobStateSerDe.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster.kubernetes
+
+import org.json4s.{CustomSerializer, JString}
+import org.json4s.JsonAST.JNull
+
+import org.apache.spark.scheduler.cluster.kubernetes.JobState.JobState
+
+/**
+ * JobState Serializer and Deserializer
+ */
+object JobStateSerDe extends CustomSerializer[JobState](_ =>
+  ({
+    case JString("SUBMITTED") => JobState.SUBMITTED
+    case JString("QUEUED") => JobState.QUEUED
+    case JString("RUNNING") => JobState.RUNNING
+    case JString("FINISHED") => JobState.FINISHED
+    case JString("KILLED") => JobState.KILLED
+    case JString("FAILED") => JobState.FAILED
+    case JNull =>
+      throw new UnsupportedOperationException("No JobState Specified")
+  }, {
+    case JobState.FAILED => JString("FAILED")
+    case JobState.SUBMITTED => JString("SUBMITTED")
+    case JobState.KILLED => JString("KILLED")
+    case JobState.FINISHED => JString("FINISHED")
+    case JobState.QUEUED => JString("QUEUED")
+    case JobState.RUNNING => JString("RUNNING")
+  })
+)
diff --git a/resource-managers/kubernetes/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterScheduler.scala b/resource-managers/kubernetes/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterScheduler.scala
index 8aaa5193db968..6f1b37601d171 100644
--- a/resource-managers/kubernetes/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterScheduler.scala
+++ b/resource-managers/kubernetes/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterScheduler.scala
@@ -22,14 +22,14 @@ import java.util.Date
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
+import scala.util.{Failure, Random, Success, Try}
 
 import io.fabric8.kubernetes.api.model.{Pod, PodBuilder, PodFluent, ServiceBuilder}
 import io.fabric8.kubernetes.client.{ConfigBuilder, DefaultKubernetesClient, KubernetesClient, KubernetesClientException}
 
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.deploy.Command
-import org.apache.spark.deploy.kubernetes.ClientArguments
+import org.apache.spark.deploy.kubernetes.{ClientArguments, SparkJobCreateResource}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.util.Utils
@@ -49,19 +49,19 @@ private[spark] class KubernetesClusterScheduler(conf: SparkConf)
   private val DEFAULT_CORES = 1.0
 
   logInfo("Created KubernetesClusterScheduler instance")
-  val client = setupKubernetesClient()
-  val driverName = s"spark-driver-${Random.alphanumeric take 5 mkString ""}".toLowerCase()
-  val svcName = s"spark-svc-${Random.alphanumeric take 5 mkString ""}".toLowerCase()
-  val nameSpace = conf.get(
+  private val client = setupKubernetesClient()
+  private val driverName = s"spark-driver-${Random.alphanumeric take 5 mkString ""}".toLowerCase()
+  private val svcName = s"spark-svc-${Random.alphanumeric take 5 mkString ""}".toLowerCase()
+  private val nameSpace = conf.get(
     "spark.kubernetes.namespace",
     KubernetesClusterScheduler.defaultNameSpace)
-  val serviceAccountName = conf.get(
+  private val serviceAccountName = conf.get(
     "spark.kubernetes.serviceAccountName",
     KubernetesClusterScheduler.defaultServiceAccountName)
 
   // Anything that should either not be passed to driver config in the cluster, or
   // that is going to be explicitly managed as command argument to the driver pod
-  val confBlackList = scala.collection.Set(
+  private val confBlackList = scala.collection.Set(
     "spark.master",
     "spark.app.name",
     "spark.submit.deployMode",
@@ -69,20 +69,41 @@ private[spark] class KubernetesClusterScheduler(conf: SparkConf)
     "spark.dynamicAllocation.enabled",
     "spark.shuffle.service.enabled")
 
-  val instances = conf.get(EXECUTOR_INSTANCES).getOrElse(1)
+  private val instances = conf.get(EXECUTOR_INSTANCES).getOrElse(1)
 
   // image needs to support shim scripts "/opt/driver.sh" and "/opt/executor.sh"
-  private val sparkDriverImage = conf.getOption("spark.kubernetes.sparkImage").getOrElse {
+  private val sparkImage = conf.getOption("spark.kubernetes.sparkImage").getOrElse {
     // TODO: this needs to default to some standard Apache Spark image
     throw new SparkException("Spark image not set. Please configure spark.kubernetes.sparkImage")
   }
 
+  private val sparkJobResource = new SparkJobCreateResource(client, nameSpace)
+
   private val imagePullSecret = conf.get("spark.kubernetes.imagePullSecret", "")
   private val isImagePullSecretSet = isSecretRunning(imagePullSecret)
 
   logWarning("Instances: " + instances)
 
   def start(args: ClientArguments): Unit = {
+    val sparkJobResourceName =
+      s"sparkJob-$nameSpace-${Random.alphanumeric take 5 mkString ""}".toLowerCase()
+    val keyValuePairs = Map(
+      "num-executors" -> instances,
+      "image" -> sparkImage,
+      "state" -> JobState.QUEUED,
+      "spark-driver" -> driverName,
+      "spark-svc" -> svcName)
+
+    Try(sparkJobResource.createJobObject(sparkJobResourceName, keyValuePairs)) match {
+      case Success(_) =>
+        conf.set("spark.kubernetes.jobResourceName", sparkJobResourceName)
+        conf.set("spark.kubernetes.jobResourceSet", "true")
+        logInfo(s"Object with name: $sparkJobResourceName posted to k8s successfully")
+      case Failure(e: Throwable) => // log and carry on
+        conf.set("spark.kubernetes.jobResourceSet", "false")
+        logInfo(s"Failed to post object $sparkJobResourceName due to ${e.getMessage}")
+    }
+
     startDriver(client, args)
   }
 
@@ -104,8 +125,8 @@ private[spark] class KubernetesClusterScheduler(conf: SparkConf)
     val clientJarUri = args.userJar
 
     // This is the kubernetes master we're launching on.
-    val kubernetesHost = "k8s://" + client.getMasterUrl().getHost()
-    logInfo("Using as kubernetes-master: " + kubernetesHost.toString())
+    val kubernetesHost = "k8s://" + client.getMasterUrl.getHost
+    logInfo("Using as kubernetes-master: " + kubernetesHost.toString)
 
     val submitArgs = scala.collection.mutable.ArrayBuffer.empty[String]
     submitArgs ++= Vector(
@@ -116,7 +137,10 @@ private[spark] class KubernetesClusterScheduler(conf: SparkConf)
       s"--conf=spark.executor.jar=$clientJarUri",
       s"--conf=spark.executor.instances=$instances",
       s"--conf=spark.kubernetes.namespace=$nameSpace",
-      s"--conf=spark.kubernetes.sparkImage=$sparkDriverImage")
+      s"--conf=spark.kubernetes.sparkImage=$sparkImage")
+
+    submitArgs ++= conf.getAll.collect {
+      case (name, value) if !confBlackList.contains(name) => s"--conf $name=$value" }
 
     if (conf.getBoolean("spark.dynamicAllocation.enabled", false)) {
       submitArgs ++= Vector(
@@ -165,7 +189,8 @@ private[spark] class KubernetesClusterScheduler(conf: SparkConf)
     try {
       client.secrets().inNamespace(nameSpace).withName(name).get() != null
     } catch {
-      case e: KubernetesClientException => false
+      case e: KubernetesClientException => logError(e.getMessage)
+        false // is this enough to throw a SparkException? For now, default to false
     }
   }
 
@@ -183,7 +208,7 @@ private[spark] class KubernetesClusterScheduler(conf: SparkConf)
       .withServiceAccount(serviceAccountName)
       .addNewContainer()
       .withName("spark-driver")
-      .withImage(sparkDriverImage)
+      .withImage(sparkImage)
       .withImagePullPolicy("Always")
       .withCommand(s"/opt/driver.sh")
       .withArgs(submitArgs: _*)
@@ -194,15 +219,15 @@ private[spark] class KubernetesClusterScheduler(conf: SparkConf)
 
   private def buildPodUtil(pod: PodFluent.SpecNested[PodBuilder]): Pod = {
     if (isImagePullSecretSet) {
-      System.setProperty("SPARK_IMAGE_PULLSECRET", imagePullSecret)
       pod.addNewImagePullSecret(imagePullSecret).endSpec().build()
     } else {
+      conf.remove("spark.kubernetes.imagePullSecret")
       pod.endSpec().build()
     }
   }
 
   def setupKubernetesClient(): KubernetesClient = {
-    val sparkHost = new java.net.URI(conf.get("spark.master")).getHost()
+    val sparkHost = new java.net.URI(conf.get("spark.master")).getHost
     var config = new ConfigBuilder().withNamespace(nameSpace)
 
     if (sparkHost != "default") {
diff --git a/resource-managers/kubernetes/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala b/resource-managers/kubernetes/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala
index ff676c0394768..40dc66d342584 100644
--- a/resource-managers/kubernetes/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala
+++ b/resource-managers/kubernetes/src/main/scala/org/apache/spark/scheduler/cluster/kubernetes/KubernetesClusterSchedulerBackend.scala
@@ -20,46 +20,57 @@ package org.apache.spark.scheduler.cluster.kubernetes
 import scala.collection.{concurrent, mutable}
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.Future
-import scala.util.Random
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
+import scala.util.{Failure, Random, Success, Try}
 
 import io.fabric8.kubernetes.api.model.{Pod, PodBuilder, PodFluent}
 import io.fabric8.kubernetes.client.DefaultKubernetesClient
 
 import org.apache.spark.{SparkConf, SparkContext, SparkException}
-import org.apache.spark.deploy.kubernetes.SparkJobResource
+import org.apache.spark.deploy.kubernetes.SparkJobResource.WatchObject
+import org.apache.spark.deploy.kubernetes.SparkJobRUDResource
 import org.apache.spark.internal.config._
 import org.apache.spark.rpc.RpcEndpointAddress
 import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.cluster._
 import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
 
 private[spark] class KubernetesClusterSchedulerBackend(
-  scheduler: TaskSchedulerImpl,
-  sc: SparkContext)
-  extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) {
+    scheduler: TaskSchedulerImpl,
+    sc: SparkContext) extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) {
 
   private val client = new DefaultKubernetesClient()
 
-  val DEFAULT_NUMBER_EXECUTORS = 2
-  val podPrefix = s"spark-executor-${Random.alphanumeric take 5 mkString ""}".toLowerCase()
+  private val DEFAULT_NUMBER_EXECUTORS = 2
+
+  private val jobResourceName = conf.get("spark.kubernetes.jobResourceName", "")
 
   // using a concurrent TrieMap gets rid of possible concurrency issues
   // key is executor id, value is pod name
-  private var executorToPod = new concurrent.TrieMap[String, String] // active executors
-  private var shutdownToPod = new concurrent.TrieMap[String, String] // pending shutdown
-  private var executorID = 0
+  private val executorToPod = new concurrent.TrieMap[String, String] // active executors
+  private val shutdownToPod = new concurrent.TrieMap[String, String] // pending shutdown
 
-  val sparkImage = conf.get("spark.kubernetes.sparkImage")
-  val clientJarUri = conf.get("spark.executor.jar")
-  val ns = conf.get(
+  private val sparkImage = conf.get("spark.kubernetes.sparkImage")
+  private val ns = conf.get(
     "spark.kubernetes.namespace",
     KubernetesClusterScheduler.defaultNameSpace)
-  val dynamicExecutors = Utils.isDynamicAllocationEnabled(conf)
-  val sparkJobResource = new SparkJobResource(client)
+  private val dynamicExecutors = Utils.isDynamicAllocationEnabled(conf)
 
-  private val imagePullSecret = System.getProperty("SPARK_IMAGE_PULLSECRET", "")
+  private implicit val resourceWatcherPool = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonFixedThreadPool(2, "resource-watcher-pool"))
+
+  private val sparkJobResource = new SparkJobRUDResource(client, ns, resourceWatcherPool)
+
+  private val imagePullSecret = conf.get("spark.kubernetes.imagePullSecret", "")
+
+  private val executorBaseName =
+    s"spark-executor-${Random.alphanumeric take 5 mkString ""}".toLowerCase
+
+  private var workingWithJobResource =
+    conf.get("spark.kubernetes.jobResourceSet", "false").toBoolean
+
+  private var watcherFuture: Future[WatchObject] = _
 
   // executor back-ends take their configuration this way
   if (dynamicExecutors) {
@@ -69,34 +80,80 @@ private[spark] class KubernetesClusterSchedulerBackend(
   override def start(): Unit = {
     super.start()
-    val keyValuePairs = Map("num-executors" -> getInitialTargetExecutorNumber(sc.conf),
-      "image" -> sparkImage,
-      "state" -> JobState.SUBMITTED)
-
-    try {
-      logInfo(s"Creating Job Resource with name. $podPrefix")
-      sparkJobResource.createJobObject(podPrefix, keyValuePairs)
-    } catch {
-      case e: SparkException =>
-        logWarning(s"SparkJob object not created. ${e.getMessage}")
-        // SparkJob should continue if this fails as discussed on thread.
-        // TODO: we should short-circuit on things like update or delete
-    }
+    startLogic()
     createExecutorPods(getInitialTargetExecutorNumber(sc.getConf))
   }
 
+  private def startLogic(): Unit = {
+    if (workingWithJobResource) {
+      logInfo(s"Updating Job Resource with name. $jobResourceName")
+      Try(sparkJobResource
+        .updateJobObject(jobResourceName, JobState.SUBMITTED.toString, "/spec/state")) match {
+        case Success(_) => startWatcher()
+        case Failure(e: SparkException) if e.getMessage startsWith "404" =>
+          logWarning(s"Possible deletion of jobResource before backend start")
+          workingWithJobResource = false
+        case Failure(e: SparkException) =>
+          logWarning(s"SparkJob object not updated. ${e.getMessage}")
+          // SparkJob should continue if this fails as discussed
+          // Maybe some retry + backoff mechanism?
+      }
+    }
+  }
+
+  private def startWatcher(): Unit = {
+    resourceWatcherPool.execute(new Runnable {
+      override def run(): Unit = {
+        watcherFuture = sparkJobResource.watchJobObject()
+        watcherFuture onComplete {
+          case Success(w: WatchObject) if w.`type` == "DELETED" =>
+            logInfo("TPR Object deleted externally. Cleaning up")
+            stopUtil()
+            // TODO: are there other todo's for a clean kill while job is running?
+          case Success(w: WatchObject) =>
+            // Log a warning just in case, but this should almost certainly never happen
+            logWarning(s"Unexpected response received. $w")
+            deleteJobResource()
+            workingWithJobResource = false
+          case Failure(e: Throwable) =>
+            logWarning(e.getMessage)
+            deleteJobResource()
+            workingWithJobResource = false // in case watcher fails early on
+        }
+      }
+    })
+  }
+
   override def stop(): Unit = {
+    if (workingWithJobResource) {
+      sparkJobResource.stopWatcher()
+      try {
+        ThreadUtils.awaitResult(watcherFuture, Duration.Inf)
+      } catch {
+        case _ : Throwable =>
+      }
+    }
+    stopUtil()
+
+  }
+
+  private def stopUtil() = {
+    resourceWatcherPool.shutdown()
     // Kill all executor pods indiscriminately
     killExecutorPods(executorToPod.toVector)
     killExecutorPods(shutdownToPod.toVector)
     // TODO: pods that failed during build up due to some error are left behind.
-    try{
-      sparkJobResource.deleteJobObject(podPrefix)
+    super.stop()
+  }
+
+  private def deleteJobResource(): Unit = {
+    try {
+      sparkJobResource.deleteJobObject(jobResourceName)
     } catch {
       case e: SparkException =>
-        logWarning(s"SparkJob object not deleted. ${e.getMessage}")
+        logError(s"SparkJob object not deleted. ${e.getMessage}")
         // what else do we need to do here ?
     }
-    super.stop()
   }
 
   // Dynamic allocation interfaces
@@ -122,11 +179,15 @@ private[spark] class KubernetesClusterSchedulerBackend(
     }
 
     // TODO: be smarter about when to update.
-    try {
-      sparkJobResource.updateJobObject(podPrefix, requestedTotal.toString, "/spec/num-executors")
-    } catch {
-      case e: SparkException => logWarning(s"SparkJob Object not updated. ${e.getMessage}")
+    if (workingWithJobResource) {
+      Try(sparkJobResource
+        .updateJobObject(jobResourceName, requestedTotal.toString, "/spec/num-executors")) match {
+        case Success(_) => logInfo(s"Object with name: $jobResourceName updated successfully")
+        case Failure(e: SparkException) =>
+          logWarning(s"SparkJob Object not updated. ${e.getMessage}")
+      }
     }
+
     // TODO: are there meaningful failure modes here?
     Future.successful(true)
   }
@@ -139,13 +200,12 @@ private[spark] class KubernetesClusterSchedulerBackend(
 
   private def createExecutorPods(n: Int) {
     for (i <- 1 to n) {
-      executorID += 1
-      executorToPod += ((executorID.toString, createExecutorPod(executorID)))
+      executorToPod += ((i.toString, createExecutorPod(i)))
     }
   }
 
   def shutdownExecutors(idPodPairs: Seq[(String, String)]) {
-    val active = getExecutorIds.toSet
+    val active = getExecutorIds().toSet
 
     // Check for any finished shutting down and kill the pods
     val shutdown = shutdownToPod.toVector.filter { case (e, _) => !active.contains(e) }
@@ -205,7 +265,7 @@ private[spark] class KubernetesClusterSchedulerBackend(
   def createExecutorPod(executorNum: Int): String = {
     // create a single k8s executor pod.
     val labelMap = Map("type" -> "spark-executor")
-    val podName = s"$podPrefix-$executorNum"
+    val podName = s"$executorBaseName-$executorNum"
 
     val submitArgs = mutable.ArrayBuffer.empty[String]