diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index 234c9870548c7..14e2df4ed0702 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -213,6 +213,14 @@ from the other deployment modes. See the [configuration page](configuration.html (typically 6-10%). + + spark.kubernetes.driver.labels + (none) + + Custom labels that will be added to the driver pod. This should be a comma-separated list of label key-value pairs, + where each label is in the format key=value. + + ## Current Limitations diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/Client.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/Client.scala index 6d7de973a52c2..073afcbba7b52 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/Client.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/Client.scala @@ -77,6 +77,8 @@ private[spark] class Client( private val serviceAccount = sparkConf.get("spark.kubernetes.submit.serviceAccountName", "default") + private val customLabels = sparkConf.get("spark.kubernetes.driver.labels", "") + private implicit val retryableExecutionContext = ExecutionContext .fromExecutorService( Executors.newSingleThreadExecutor(new ThreadFactoryBuilder() @@ -85,6 +87,7 @@ private[spark] class Client( .build())) def run(): Unit = { + val parsedCustomLabels = parseCustomLabels(customLabels) var k8ConfBuilder = new ConfigBuilder() .withApiVersion("v1") .withMasterUrl(master) @@ -109,14 +112,15 @@ private[spark] class Client( .withType("Opaque") .done() try { - val selectors = Map(DRIVER_LAUNCHER_SELECTOR_LABEL -> driverLauncherSelectorValue).asJava + val resolvedSelectors = (Map(DRIVER_LAUNCHER_SELECTOR_LABEL -> driverLauncherSelectorValue) + ++ parsedCustomLabels).asJava val (servicePorts, containerPorts) = configurePorts() val service = kubernetesClient.services().createNew() .withNewMetadata() .withName(kubernetesAppId) .endMetadata() .withNewSpec() - .withSelector(selectors) + .withSelector(resolvedSelectors) .withPorts(servicePorts.asJava) .endSpec() .done() @@ -137,7 +141,7 @@ private[spark] class Client( .asScala .find(status => status.getName == DRIVER_LAUNCHER_CONTAINER_NAME && status.getReady) match { - case Some(status) => + case Some(_) => try { val driverLauncher = getDriverLauncherService( k8ClientConfig, master) @@ -184,7 +188,7 @@ private[spark] class Client( kubernetesClient.pods().createNew() .withNewMetadata() .withName(kubernetesAppId) - .withLabels(selectors) + .withLabels(resolvedSelectors) .endMetadata() .withNewSpec() .withRestartPolicy("OnFailure") @@ -291,7 +295,7 @@ private[spark] class Client( Utils.tryWithResource(kubernetesClient .pods() - .withLabels(selectors) + .withLabels(resolvedSelectors) .watch(podWatcher)) { createDriverPod } } finally { kubernetesClient.secrets().delete(secret) @@ -336,7 +340,7 @@ private[spark] class Client( .getOption("spark.ui.port") .map(_.toInt) .getOrElse(DEFAULT_UI_PORT)) - (servicePorts.toSeq, containerPorts.toSeq) + (servicePorts, containerPorts) } private def buildSubmissionRequest(): KubernetesCreateSubmissionRequest = { @@ -366,7 +370,7 @@ private[spark] class Client( uploadedJarsBase64Contents = uploadJarsBase64Contents) } - def compressJars(maybeFilePaths: Option[String]): Option[TarGzippedData] = { + private def compressJars(maybeFilePaths: Option[String]): Option[TarGzippedData] = { maybeFilePaths .map(_.split(",")) .map(CompressionUtils.createTarGzip(_)) @@ -391,6 +395,23 @@ private[spark] class Client( sslSocketFactory = sslContext.getSocketFactory, trustContext = trustManager) } + + private def parseCustomLabels(labels: String): Map[String, String] = { + labels.split(",").map(_.trim).filterNot(_.isEmpty).map(label => { + label.split("=", 2).toSeq match { + case Seq(k, v) => + require(k != DRIVER_LAUNCHER_SELECTOR_LABEL, "Label with key" + + s" $DRIVER_LAUNCHER_SELECTOR_LABEL cannot be used in" + + " spark.kubernetes.driver.labels, as it is reserved for Spark's" + + " internal configuration.") + (k, v) + case _ => + throw new SparkException("Custom labels set by spark.kubernetes.driver.labels" + + " must be a comma-separated list of key-value pairs, with format =." + + s" Got label: $label. All labels: $labels") + } + }).toMap + } } private object Client { diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/kubernetes/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/kubernetes/integrationtest/KubernetesSuite.scala index 6247a1674f8d6..7b3c2b93b865b 100644 --- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/kubernetes/integrationtest/KubernetesSuite.scala +++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/kubernetes/integrationtest/KubernetesSuite.scala @@ -161,4 +161,38 @@ private[spark] class KubernetesSuite extends SparkFunSuite with BeforeAndAfter { "spark-pi", NAMESPACE, "spark-ui-port") expectationsForStaticAllocation(sparkMetricsService) } + + test("Run with custom labels") { + val args = Array( + "--master", s"k8s://https://${Minikube.getMinikubeIp}:8443", + "--deploy-mode", "cluster", + "--kubernetes-namespace", NAMESPACE, + "--name", "spark-pi", + "--executor-memory", "512m", + "--executor-cores", "1", + "--num-executors", "1", + "--upload-jars", HELPER_JAR, + "--class", MAIN_CLASS, + "--conf", s"spark.kubernetes.submit.caCertFile=${clientConfig.getCaCertFile}", + "--conf", s"spark.kubernetes.submit.clientKeyFile=${clientConfig.getClientKeyFile}", + "--conf", s"spark.kubernetes.submit.clientCertFile=${clientConfig.getClientCertFile}", + "--conf", "spark.kubernetes.executor.docker.image=spark-executor:latest", + "--conf", "spark.kubernetes.driver.docker.image=spark-driver:latest", + "--conf", "spark.kubernetes.driver.labels=label1=label1value,label2=label2value", + EXAMPLES_JAR) + SparkSubmit.main(args) + val driverPodLabels = minikubeKubernetesClient + .pods + .withName("spark-pi") + .get + .getMetadata + .getLabels + // We can't match all of the selectors directly since one of the selectors is based on the + // launch time. + assert(driverPodLabels.size == 3, "Unexpected number of pod labels.") + assert(driverPodLabels.containsKey("driver-launcher-selector"), "Expected driver launcher" + + " selector label to be present.") + assert(driverPodLabels.get("label1") == "label1value", "Unexpected value for label1") + assert(driverPodLabels.get("label2") == "label2value", "Unexpected value for label2") + } }