@@ -77,6 +77,8 @@ private[spark] class Client(
7777 private val serviceAccount = sparkConf.get(" spark.kubernetes.submit.serviceAccountName" ,
7878 " default" )
7979
80+ private val customLabels = sparkConf.get(" spark.kubernetes.driver.labels" , " " )
81+
8082 private implicit val retryableExecutionContext = ExecutionContext
8183 .fromExecutorService(
8284 Executors .newSingleThreadExecutor(new ThreadFactoryBuilder ()
@@ -85,6 +87,7 @@ private[spark] class Client(
8587 .build()))
8688
8789 def run (): Unit = {
90+ val parsedCustomLabels = parseCustomLabels(customLabels)
8891 var k8ConfBuilder = new ConfigBuilder ()
8992 .withApiVersion(" v1" )
9093 .withMasterUrl(master)
@@ -109,14 +112,15 @@ private[spark] class Client(
109112 .withType(" Opaque" )
110113 .done()
111114 try {
112- val selectors = Map (DRIVER_LAUNCHER_SELECTOR_LABEL -> driverLauncherSelectorValue).asJava
115+ val resolvedSelectors = (Map (DRIVER_LAUNCHER_SELECTOR_LABEL -> driverLauncherSelectorValue)
116+ ++ parsedCustomLabels).asJava
113117 val (servicePorts, containerPorts) = configurePorts()
114118 val service = kubernetesClient.services().createNew()
115119 .withNewMetadata()
116120 .withName(kubernetesAppId)
117121 .endMetadata()
118122 .withNewSpec()
119- .withSelector(selectors )
123+ .withSelector(resolvedSelectors )
120124 .withPorts(servicePorts.asJava)
121125 .endSpec()
122126 .done()
@@ -137,7 +141,7 @@ private[spark] class Client(
137141 .asScala
138142 .find(status =>
139143 status.getName == DRIVER_LAUNCHER_CONTAINER_NAME && status.getReady) match {
140- case Some (status ) =>
144+ case Some (_ ) =>
141145 try {
142146 val driverLauncher = getDriverLauncherService(
143147 k8ClientConfig, master)
@@ -184,7 +188,7 @@ private[spark] class Client(
184188 kubernetesClient.pods().createNew()
185189 .withNewMetadata()
186190 .withName(kubernetesAppId)
187- .withLabels(selectors )
191+ .withLabels(resolvedSelectors )
188192 .endMetadata()
189193 .withNewSpec()
190194 .withRestartPolicy(" OnFailure" )
@@ -291,7 +295,7 @@ private[spark] class Client(
291295
292296 Utils .tryWithResource(kubernetesClient
293297 .pods()
294- .withLabels(selectors )
298+ .withLabels(resolvedSelectors )
295299 .watch(podWatcher)) { createDriverPod }
296300 } finally {
297301 kubernetesClient.secrets().delete(secret)
@@ -336,7 +340,7 @@ private[spark] class Client(
336340 .getOption(" spark.ui.port" )
337341 .map(_.toInt)
338342 .getOrElse(DEFAULT_UI_PORT ))
339- (servicePorts.toSeq , containerPorts.toSeq )
343+ (servicePorts, containerPorts)
340344 }
341345
342346 private def buildSubmissionRequest (): KubernetesCreateSubmissionRequest = {
@@ -366,7 +370,7 @@ private[spark] class Client(
366370 uploadedJarsBase64Contents = uploadJarsBase64Contents)
367371 }
368372
369- def compressJars (maybeFilePaths : Option [String ]): Option [TarGzippedData ] = {
373+ private def compressJars (maybeFilePaths : Option [String ]): Option [TarGzippedData ] = {
370374 maybeFilePaths
371375 .map(_.split(" ," ))
372376 .map(CompressionUtils .createTarGzip(_))
@@ -391,6 +395,23 @@ private[spark] class Client(
391395 sslSocketFactory = sslContext.getSocketFactory,
392396 trustContext = trustManager)
393397 }
398+
399+ private def parseCustomLabels (labels : String ): Map [String , String ] = {
400+ labels.split(" ," ).map(_.trim).filterNot(_.isEmpty).map(label => {
401+ label.split(" =" , 2 ).toSeq match {
402+ case Seq (k, v) =>
403+ require(k != DRIVER_LAUNCHER_SELECTOR_LABEL , " Label with key" +
404+ s " $DRIVER_LAUNCHER_SELECTOR_LABEL cannot be used in " +
405+ " spark.kubernetes.driver.labels, as it is reserved for Spark's" +
406+ " internal configuration." )
407+ (k, v)
408+ case _ =>
409+ throw new SparkException (" Custom labels set by spark.kubernetes.driver.labels" +
410+ " must be a comma-separated list of key-value pairs, with format <key>=<value>." +
411+ s " Got label: $label. All labels: $labels" )
412+ }
413+ }).toMap
414+ }
394415}
395416
396417private object Client {
0 commit comments