diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md
index 234c9870548c7..14e2df4ed0702 100644
--- a/docs/running-on-kubernetes.md
+++ b/docs/running-on-kubernetes.md
@@ -213,6 +213,14 @@ from the other deployment modes. See the [configuration page](configuration.html
(typically 6-10%).
+<tr>
+  <td><code>spark.kubernetes.driver.labels</code></td>
+  <td>(none)</td>
+  <td>
+    Custom labels that will be added to the driver pod. This should be a comma-separated list of label key-value pairs,
+    where each label is in the format <code>key=value</code>.
+  </td>
+</tr>
## Current Limitations
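
As a usage illustration of the property documented above, a minimal sketch follows; the `environment` and `team` keys are invented for the example and carry no special meaning to Spark:

```scala
import org.apache.spark.SparkConf

// Hedged sketch: attach two custom labels to the driver pod.
// "environment" and "team" are illustrative keys, not keys Spark interprets;
// the whole value is one comma-separated string of key=value pairs.
val conf = new SparkConf()
  .set("spark.kubernetes.driver.labels", "environment=staging,team=data-eng")
```
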
diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/Client.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/Client.scala
index 6d7de973a52c2..073afcbba7b52 100644
--- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/Client.scala
+++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/kubernetes/Client.scala
@@ -77,6 +77,8 @@ private[spark] class Client(
private val serviceAccount = sparkConf.get("spark.kubernetes.submit.serviceAccountName",
"default")
+ private val customLabels = sparkConf.get("spark.kubernetes.driver.labels", "")
+
private implicit val retryableExecutionContext = ExecutionContext
.fromExecutorService(
Executors.newSingleThreadExecutor(new ThreadFactoryBuilder()
@@ -85,6 +87,7 @@ private[spark] class Client(
.build()))
def run(): Unit = {
+ val parsedCustomLabels = parseCustomLabels(customLabels)
var k8ConfBuilder = new ConfigBuilder()
.withApiVersion("v1")
.withMasterUrl(master)
@@ -109,14 +112,15 @@ private[spark] class Client(
.withType("Opaque")
.done()
try {
- val selectors = Map(DRIVER_LAUNCHER_SELECTOR_LABEL -> driverLauncherSelectorValue).asJava
+ val resolvedSelectors = (Map(DRIVER_LAUNCHER_SELECTOR_LABEL -> driverLauncherSelectorValue)
+ ++ parsedCustomLabels).asJava
val (servicePorts, containerPorts) = configurePorts()
val service = kubernetesClient.services().createNew()
.withNewMetadata()
.withName(kubernetesAppId)
.endMetadata()
.withNewSpec()
- .withSelector(selectors)
+ .withSelector(resolvedSelectors)
.withPorts(servicePorts.asJava)
.endSpec()
.done()
@@ -137,7 +141,7 @@ private[spark] class Client(
.asScala
.find(status =>
status.getName == DRIVER_LAUNCHER_CONTAINER_NAME && status.getReady) match {
- case Some(status) =>
+ case Some(_) =>
try {
val driverLauncher = getDriverLauncherService(
k8ClientConfig, master)
@@ -184,7 +188,7 @@ private[spark] class Client(
kubernetesClient.pods().createNew()
.withNewMetadata()
.withName(kubernetesAppId)
- .withLabels(selectors)
+ .withLabels(resolvedSelectors)
.endMetadata()
.withNewSpec()
.withRestartPolicy("OnFailure")
@@ -291,7 +295,7 @@ private[spark] class Client(
Utils.tryWithResource(kubernetesClient
.pods()
- .withLabels(selectors)
+ .withLabels(resolvedSelectors)
.watch(podWatcher)) { createDriverPod }
} finally {
kubernetesClient.secrets().delete(secret)
@@ -336,7 +340,7 @@ private[spark] class Client(
.getOption("spark.ui.port")
.map(_.toInt)
.getOrElse(DEFAULT_UI_PORT))
- (servicePorts.toSeq, containerPorts.toSeq)
+ (servicePorts, containerPorts)
}
private def buildSubmissionRequest(): KubernetesCreateSubmissionRequest = {
@@ -366,7 +370,7 @@ private[spark] class Client(
uploadedJarsBase64Contents = uploadJarsBase64Contents)
}
- def compressJars(maybeFilePaths: Option[String]): Option[TarGzippedData] = {
+ private def compressJars(maybeFilePaths: Option[String]): Option[TarGzippedData] = {
maybeFilePaths
.map(_.split(","))
.map(CompressionUtils.createTarGzip(_))
@@ -391,6 +395,23 @@ private[spark] class Client(
sslSocketFactory = sslContext.getSocketFactory,
trustContext = trustManager)
}
+
+ private def parseCustomLabels(labels: String): Map[String, String] = {
+ labels.split(",").map(_.trim).filterNot(_.isEmpty).map(label => {
+ label.split("=", 2).toSeq match {
+ case Seq(k, v) =>
+ require(k != DRIVER_LAUNCHER_SELECTOR_LABEL, "Label with key" +
+ s" $DRIVER_LAUNCHER_SELECTOR_LABEL cannot be used in" +
+ " spark.kubernetes.driver.labels, as it is reserved for Spark's" +
+ " internal configuration.")
+ (k, v)
+ case _ =>
+ throw new SparkException("Custom labels set by spark.kubernetes.driver.labels" +
+ " must be a comma-separated list of key-value pairs, with format =." +
+ s" Got label: $label. All labels: $labels")
+ }
+ }).toMap
+ }
}
private object Client {
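
To make the parsing rules above concrete, here is a self-contained sketch; the reserved-label constant is stubbed locally (in `Client.scala` it lives on the companion object), and the exception type is swapped for a plain `IllegalArgumentException` so the snippet runs without Spark on the classpath:

```scala
object LabelParsingSketch {
  val DRIVER_LAUNCHER_SELECTOR_LABEL = "driver-launcher-selector"

  def parseCustomLabels(labels: String): Map[String, String] =
    labels.split(",").map(_.trim).filterNot(_.isEmpty).map { label =>
      label.split("=", 2).toSeq match {
        case Seq(k, v) =>
          require(k != DRIVER_LAUNCHER_SELECTOR_LABEL, s"Label key $k is reserved.")
          (k, v)
        case _ =>
          throw new IllegalArgumentException(s"Malformed label: $label")
      }
    }.toMap

  def main(args: Array[String]): Unit = {
    // split("=", 2) keeps any further '=' characters inside the value.
    println(parseCustomLabels("a=1, b=x=y")) // Map(a -> 1, b -> x=y)
    // The default "" for spark.kubernetes.driver.labels yields no labels.
    println(parseCustomLabels(""))           // Map()
  }
}
```

Because the custom map sits on the right of `++`, a custom label would win any key collision with the internal selector map, which is why the reserved key is rejected up front by the `require`.
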
diff --git a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/kubernetes/integrationtest/KubernetesSuite.scala b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/kubernetes/integrationtest/KubernetesSuite.scala
index 6247a1674f8d6..7b3c2b93b865b 100644
--- a/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/kubernetes/integrationtest/KubernetesSuite.scala
+++ b/resource-managers/kubernetes/integration-tests/src/test/scala/org/apache/spark/deploy/kubernetes/integrationtest/KubernetesSuite.scala
@@ -161,4 +161,38 @@ private[spark] class KubernetesSuite extends SparkFunSuite with BeforeAndAfter {
"spark-pi", NAMESPACE, "spark-ui-port")
expectationsForStaticAllocation(sparkMetricsService)
}
+
+ test("Run with custom labels") {
+ val args = Array(
+ "--master", s"k8s://https://${Minikube.getMinikubeIp}:8443",
+ "--deploy-mode", "cluster",
+ "--kubernetes-namespace", NAMESPACE,
+ "--name", "spark-pi",
+ "--executor-memory", "512m",
+ "--executor-cores", "1",
+ "--num-executors", "1",
+ "--upload-jars", HELPER_JAR,
+ "--class", MAIN_CLASS,
+ "--conf", s"spark.kubernetes.submit.caCertFile=${clientConfig.getCaCertFile}",
+ "--conf", s"spark.kubernetes.submit.clientKeyFile=${clientConfig.getClientKeyFile}",
+ "--conf", s"spark.kubernetes.submit.clientCertFile=${clientConfig.getClientCertFile}",
+ "--conf", "spark.kubernetes.executor.docker.image=spark-executor:latest",
+ "--conf", "spark.kubernetes.driver.docker.image=spark-driver:latest",
+ "--conf", "spark.kubernetes.driver.labels=label1=label1value,label2=label2value",
+ EXAMPLES_JAR)
+ SparkSubmit.main(args)
+ val driverPodLabels = minikubeKubernetesClient
+ .pods
+ .withName("spark-pi")
+ .get
+ .getMetadata
+ .getLabels
+ // We can't match all of the selectors directly since one of the selectors is based on the
+ // launch time.
+ assert(driverPodLabels.size == 3, "Unexpected number of pod labels.")
+ assert(driverPodLabels.containsKey("driver-launcher-selector"), "Expected driver launcher" +
+ " selector label to be present.")
+ assert(driverPodLabels.get("label1") == "label1value", "Unexpected value for label1")
+ assert(driverPodLabels.get("label2") == "label2value", "Unexpected value for label2")
+ }
}
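
The custom labels are what make the driver pod addressable by user-chosen selectors. As a hedged follow-on to the test above (same label key and value; the stock fabric8 client construction shown here may need cluster-specific configuration in practice), one could locate the driver pod like so:

```scala
import io.fabric8.kubernetes.client.DefaultKubernetesClient

// Select the driver pod by a custom label instead of its generated name.
val client = new DefaultKubernetesClient()
val driverPods = client.pods()
  .inNamespace("default")
  .withLabel("label1", "label1value")
  .list()
  .getItems
```
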