Commit 0985b4e

Write tokens to HDFS and read them back when required, rather than sending them over the wire.
1 parent d79b2b9 commit 0985b4e

9 files changed: +100 -85 lines changed

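The commit replaces serialized-token Akka messages with an HDFS round trip: the driver writes refreshed tokens to a file, and executors read them back on demand, so only a path string crosses the wire. A minimal sketch of that round trip using the stock Hadoop `Credentials` API (the helper names and `tokenPath` are illustrative, not from this commit):

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.security.Credentials

// Driver side: persist freshly obtained tokens to a path on HDFS.
def writeTokens(creds: Credentials, tokenPath: Path, conf: Configuration): Unit = {
  val fs = FileSystem.get(conf)
  val out = fs.create(tokenPath, true) // FSDataOutputStream extends DataOutputStream
  creds.writeTokenStorageToStream(out)
  out.hflush()
  out.close()
}

// Executor side: read the tokens back from the same path.
def readTokens(tokenPath: Path, conf: Configuration): Credentials = {
  val fs = FileSystem.get(conf)
  val in = fs.open(tokenPath) // FSDataInputStream extends DataInputStream
  val creds = new Credentials()
  creds.readTokenStorageStream(in)
  in.close()
  creds
}
```

Keeping the token bytes out of the actor system keeps the `UpdateCredentials` message small; the per-file changes below wire this pattern through submit, the YARN client, the scheduler backend, and the executor.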

bin/utils.sh

Lines changed: 1 addition & 1 deletion
```diff
@@ -36,7 +36,7 @@ function gatherSparkSubmitOpts() {
     --conf | --repositories | --properties-file | --driver-memory | --driver-java-options | \
     --driver-library-path | --driver-class-path | --executor-memory | --driver-cores | \
     --total-executor-cores | --executor-cores | --queue | --num-executors | --archives | \
-    --proxy-user)
+    --proxy-user | --principal | --keytab)
       if [[ $# -lt 2 ]]; then
         "$SUBMIT_USAGE_FUNCTION"
         exit 1;
```
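This shell change just teaches the option gatherer that `--principal` and `--keytab` each consume a following value, so an invocation along the lines of `spark-submit --master yarn-cluster --principal user@EXAMPLE.COM --keytab /etc/security/user.keytab ...` (an illustrative command, not from this commit) passes both flags through instead of tripping the usage message.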

core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -122,9 +122,9 @@ class SparkHadoopUtil extends Logging {
     UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename)
   }
 
-  def setPrincipalAndKeytabForLogin(principal: String, keytab: String): Unit = {}
+  def setPrincipalAndKeytabForLogin(principal: String, keytab: String): Unit = ???
 
-  private[spark] def scheduleLoginFromKeytab(callback: (SerializableBuffer) => Unit): Unit = {}
+  private[spark] def scheduleLoginFromKeytab(callback: (String) => Unit): Unit = {}
 
   /**
    * Returns a function that can be called to find Hadoop FileSystem bytes read. If
```
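`???` is Scala's built-in `Predef.???`, which throws `scala.NotImplementedError` when called. Switching the base-class body from a silent no-op `{}` to `???` makes a non-YARN backend fail loudly if it ever tries to set a login principal; `YarnSparkHadoopUtil` (changed below) overrides it with a real implementation. A toy illustration with hypothetical class names:

```scala
class Base {
  // Fails fast unless a subclass provides an implementation.
  def setPrincipalAndKeytabForLogin(principal: String, keytab: String): Unit = ???
}

class Yarn extends Base {
  override def setPrincipalAndKeytabForLogin(principal: String, keytab: String): Unit = {
    // Store the values for the later scheduled re-login.
  }
}

new Yarn().setPrincipalAndKeytabForLogin("user@EXAMPLE.COM", "/tmp/u.keytab") // fine
new Base().setPrincipalAndKeytabForLogin("user@EXAMPLE.COM", "/tmp/u.keytab") // throws NotImplementedError
```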

core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala

Lines changed: 2 additions & 0 deletions
```diff
@@ -372,6 +372,8 @@ object SparkSubmit {
       OptionAssigner(args.files, YARN, CLUSTER, clOption = "--files"),
       OptionAssigner(args.archives, YARN, CLUSTER, clOption = "--archives"),
       OptionAssigner(args.jars, YARN, CLUSTER, clOption = "--addJars"),
+      OptionAssigner(args.principal, YARN, CLUSTER, clOption = "--principal"),
+      OptionAssigner(args.keytab, YARN, CLUSTER, clOption = "--keytab"),
 
       // Other options
       OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES,
```

core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala

Lines changed: 10 additions & 0 deletions
```diff
@@ -58,6 +58,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
   var action: SparkSubmitAction = null
   val sparkProperties: HashMap[String, String] = new HashMap[String, String]()
   var proxyUser: String = null
+  var principal: String = null
+  var keytab: String = null
 
   // Standalone cluster mode only
   var supervise: Boolean = false
@@ -410,6 +412,14 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
       proxyUser = value
       parse(tail)
 
+    case ("--principal") :: value :: tail =>
+      principal = value
+      parse(tail)
+
+    case ("--keytab") :: value :: tail =>
+      keytab = value
+      parse(tail)
+
     case ("--help" | "-h") :: tail =>
      printUsageAndExit(0)
 
```
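The new cases follow the parser's existing pattern: match the flag and its value as the head of the argument list, record the value, and recurse on the tail. A stripped-down sketch of that cons-pattern style (simplified, not the real parser):

```scala
// Illustrative argument loop in the style of SparkSubmitArguments.parse.
def parse(args: List[String], opts: Map[String, String] = Map.empty): Map[String, String] =
  args match {
    case "--principal" :: value :: tail => parse(tail, opts + ("principal" -> value))
    case "--keytab" :: value :: tail    => parse(tail, opts + ("keytab" -> value))
    case Nil                            => opts
    case unknown :: _                   => sys.error(s"Unrecognized option: $unknown")
  }

parse(List("--principal", "user@EXAMPLE.COM", "--keytab", "/tmp/u.keytab"))
// => Map(principal -> user@EXAMPLE.COM, keytab -> /tmp/u.keytab)
```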

core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala

Lines changed: 7 additions & 4 deletions
```diff
@@ -17,8 +17,8 @@
 
 package org.apache.spark.executor
 
-import java.net.URL
 import java.io.{ByteArrayInputStream, DataInputStream}
+import java.net.URL
 import java.nio.ByteBuffer
 
 import scala.collection.mutable
@@ -27,6 +27,7 @@ import scala.concurrent.Await
 import akka.actor.{Actor, ActorSelection, Props}
 import akka.pattern.Patterns
 import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent}
+import org.apache.hadoop.fs.{Path, FileSystem}
 import org.apache.hadoop.security.Credentials
 
 import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv}
@@ -109,12 +110,14 @@ private[spark] class CoarseGrainedExecutorBackend(
       context.system.shutdown()
 
     // Add new credentials received from the driver to the current user.
-    case UpdateCredentials(newCredentials) =>
+    case UpdateCredentials(newCredentialsPath) =>
       logInfo("New credentials received from driver, adding the credentials to the current user")
       val credentials = new Credentials()
-      credentials.readTokenStorageStream(
-        new DataInputStream(new ByteArrayInputStream(newCredentials.value.array())))
+      val remoteFs = FileSystem.get(SparkHadoopUtil.get.conf)
+      val inStream = remoteFs.open(new Path(newCredentialsPath))
+      credentials.readTokenStorageStream(inStream)
       SparkHadoopUtil.get.addCurrentUserCredentials(credentials)
+      inStream.close()
  }
 
  override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
```
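One detail worth noting in the executor-side change: `inStream.close()` runs after `addCurrentUserCredentials`, which is safe because `readTokenStorageStream` has already consumed the stream, but a `try`/`finally` would also cover the failure path. An illustrative variant using the names from the hunk above (not part of the commit):

```scala
// Defensive version of the executor-side read: the stream is closed
// even if deserializing or adding the credentials throws.
val inStream = remoteFs.open(new Path(newCredentialsPath))
try {
  credentials.readTokenStorageStream(inStream)
  SparkHadoopUtil.get.addCurrentUserCredentials(credentials)
} finally {
  inStream.close()
}
```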

core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala

Lines changed: 1 addition & 1 deletion
```diff
@@ -53,7 +53,7 @@ private[spark] object CoarseGrainedClusterMessages {
 
   // When the delegation tokens are about to expire, the driver creates new tokens and sends them to
   // the executors via this message.
-  case class UpdateCredentials(newCredentials: SerializableBuffer)
+  case class UpdateCredentials(newCredentialsLocation: String)
     extends CoarseGrainedClusterMessage
 
   object StatusUpdate {
```

core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala

Lines changed: 5 additions & 4 deletions
```diff
@@ -27,8 +27,9 @@ import akka.actor._
 import akka.pattern.ask
 import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
 
-import org.apache.spark.deploy.SparkHadoopUtil
+
 import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState}
+import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
 import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils}
@@ -75,11 +76,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
   /**
    * Send new credentials to executors. This is the method that is called when the scheduled
    * login completes, so the new credentials can be sent to the executors.
-   * @param credentials
+   * @param credentialsPath
   */
-  def sendNewCredentialsToExecutors(credentials: SerializableBuffer): Unit = {
+  def sendNewCredentialsToExecutors(credentialsPath: String): Unit = {
     // We don't care about the reply, so going to deadLetters is fine.
-    executorDataMap.values.foreach(_.executorActor ! UpdateCredentials(credentials))
+    executorDataMap.values.foreach(_.executorActor ! UpdateCredentials(credentialsPath))
   }
 
   class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive {
```

yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala

Lines changed: 16 additions & 19 deletions
```diff
@@ -562,25 +562,22 @@ private[spark] class Client(
   }
 
   def setupCredentials(): Unit = {
-    Option(args.principal) match {
-      case Some(principal) =>
-        Option(args.keytab) match {
-          case Some(keytabPath) =>
-            // Generate a file name that can be used for the keytab file, that does not conflict
-            // with any user file.
-            logInfo("Attempting to login to the Kerberos" +
-              s" using principal: $principal and keytab: $keytabPath")
-            val f = new File(keytabPath)
-            keytabFileName = f.getName + "-" + System.currentTimeMillis()
-            val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytabPath)
-            credentials = ugi.getCredentials
-            loginFromKeytab = true
-            logInfo("Successfully logged into Kerberos.")
-          case None =>
-            throw new SparkException("Keytab must be specified when principal is specified.")
-        }
-      case None =>
-        credentials = UserGroupInformation.getCurrentUser.getCredentials
+    if (args.principal != null) {
+      if (args.keytab == null) {
+        throw new SparkException("Keytab must be specified when principal is specified.")
+      }
+      logInfo("Attempting to login to the Kerberos" +
+        s" using principal: ${args.principal} and keytab: ${args.keytab}")
+      val f = new File(args.keytab)
+      // Generate a file name that can be used for the keytab file, that does not conflict
+      // with any user file.
+      keytabFileName = f.getName + "-" + System.currentTimeMillis()
+      val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(args.principal, args.keytab)
+      credentials = ugi.getCredentials
+      loginFromKeytab = true
+      logInfo("Successfully logged into Kerberos.")
+    } else {
+      credentials = UserGroupInformation.getCurrentUser.getCredentials
     }
   }
 
```
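The rewrite flattens the nested `Option` matches into straight null checks with a fail-fast guard, matching the nullable `String` fields added elsewhere in this commit. The timestamp suffix on `keytabFileName` (producing something like `user.keytab-1422000000000`, an illustrative value) keeps the staged keytab from colliding with a user-supplied file of the same name.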

yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala

Lines changed: 56 additions & 54 deletions
```diff
@@ -20,8 +20,7 @@ package org.apache.spark.deploy.yarn
 import java.io._
 import java.net.URI
 import java.nio.ByteBuffer
-import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean}
-import java.util.concurrent.{TimeUnit, ThreadFactory, Executors}
+import java.util.concurrent.{TimeUnit, Executors}
 import java.util.regex.Matcher
 import java.util.regex.Pattern
 
@@ -49,10 +48,10 @@ import org.apache.spark.util.{SerializableBuffer, Utils}
  */
 class YarnSparkHadoopUtil extends SparkHadoopUtil {
 
-  private var keytabFile: Option[String] = None
-  private var loginPrincipal: Option[String] = None
-  private val loggedInViaKeytab = new AtomicBoolean(false)
-  private val loggedInUGI = new AtomicReference[UserGroupInformation](null)
+  private var keytab: String = null
+  private var principal: String = null
+  @volatile private var loggedInViaKeytab = false
+  @volatile private var loggedInUGI: UserGroupInformation = null
 
   override def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) {
     dest.addCredentials(source.getCredentials())
@@ -94,58 +93,61 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
   }
 
   override def setPrincipalAndKeytabForLogin(principal: String, keytab: String): Unit = {
-    loginPrincipal = Option(principal)
-    keytabFile = Option(keytab)
+    this.principal = principal
+    this.keytab = keytab
   }
 
   private[spark] override def scheduleLoginFromKeytab(
-      callback: (SerializableBuffer) => Unit): Unit = {
-
-    loginPrincipal match {
-      case Some(principal) =>
-        val keytab = keytabFile.get
-        val remoteFs = FileSystem.get(conf)
-        val remoteKeytabPath = new Path(
-          remoteFs.getHomeDirectory, System.getenv("SPARK_STAGING_DIR") + Path.SEPARATOR + keytab)
-        val localFS = FileSystem.getLocal(conf)
-        // At this point, SparkEnv is likely not initialized, so create a dir, put the keytab there.
-        val tempDir = Utils.createTempDir()
-        val localURI = new URI(tempDir.getAbsolutePath + Path.SEPARATOR + keytab)
-        val qualifiedURI = new URI(localFS.makeQualified(new Path(localURI)).toString)
-        FileUtil.copy(
-          remoteFs, remoteKeytabPath, localFS, new Path(qualifiedURI), false, false, conf)
-        // Get the current credentials, find out when they expire.
-        val creds = UserGroupInformation.getCurrentUser.getCredentials
-        val credStream = new ByteArrayOutputStream()
-        creds.writeTokenStorageToStream(new DataOutputStream(credStream))
-        val in = new DataInputStream(new ByteArrayInputStream(credStream.toByteArray))
-        val tokenIdentifier = new DelegationTokenIdentifier()
-        tokenIdentifier.readFields(in)
-        val timeToRenewal = (0.6 * (tokenIdentifier.getMaxDate - System.currentTimeMillis())).toLong
-        Executors.newSingleThreadScheduledExecutor(new ThreadFactory {
-          override def newThread(r: Runnable): Thread = {
-            val t = new Thread(r)
-            t.setName("Delegation Token Refresh Thread")
-            t.setDaemon(true)
-            t
-          }
-        }).scheduleWithFixedDelay(new Runnable {
-          override def run(): Unit = {
-            if (!loggedInViaKeytab.get()) {
-              loggedInUGI.set(UserGroupInformation.loginUserFromKeytabAndReturnUGI(
-                principal, tempDir.getAbsolutePath + Path.SEPARATOR + keytab))
-              loggedInViaKeytab.set(true)
+      callback: (String) => Unit): Unit = {
+    if (principal != null) {
+      val stagingDir = System.getenv("SPARK_YARN_STAGING_DIR")
+      val remoteFs = FileSystem.get(conf)
+      val remoteKeytabPath = new Path(
+        remoteFs.getHomeDirectory, stagingDir + Path.SEPARATOR + keytab)
+      val localFS = FileSystem.getLocal(conf)
+      // At this point, SparkEnv is likely not initialized, so create a dir, put the keytab there.
+      val tempDir = Utils.createTempDir()
+      Utils.chmod700(tempDir)
+      val localURI = new URI(tempDir.getAbsolutePath + Path.SEPARATOR + keytab)
+      val qualifiedURI = new URI(localFS.makeQualified(new Path(localURI)).toString)
+      FileUtil.copy(
+        remoteFs, remoteKeytabPath, localFS, new Path(qualifiedURI), false, false, conf)
+      // Get the current credentials, find out when they expire.
+      val creds = {
+        if (loggedInUGI == null) {
+          UserGroupInformation.getCurrentUser.getCredentials
+        } else {
+          loggedInUGI.getCredentials
+        }
+      }
+      val credStream = new ByteArrayOutputStream()
+      creds.writeTokenStorageToStream(new DataOutputStream(credStream))
+      val in = new DataInputStream(new ByteArrayInputStream(credStream.toByteArray))
+      val tokenIdentifier = new DelegationTokenIdentifier()
+      tokenIdentifier.readFields(in)
+      val timeToRenewal = (0.6 * (tokenIdentifier.getMaxDate - System.currentTimeMillis())).toLong
+      Executors.newSingleThreadScheduledExecutor(
+        Utils.namedThreadFactory("Delegation Token Refresh Thread")).scheduleWithFixedDelay(
+        new Runnable {
+          override def run(): Unit = {
+            if (!loggedInViaKeytab) {
+              loggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(
+                principal, tempDir.getAbsolutePath + Path.SEPARATOR + keytab)
+              loggedInViaKeytab = true
+            }
+            val nns = getNameNodesToAccess(sparkConf) + remoteKeytabPath
+            val newCredentials = loggedInUGI.getCredentials
+            obtainTokensForNamenodes(nns, conf, newCredentials)
+            val tokenPath = new Path(remoteFs.getHomeDirectory, stagingDir + Path.SEPARATOR +
+              "credentials - " + System.currentTimeMillis())
+            val stream = remoteFs.create(tokenPath, true)
+            // Write the new credentials to HDFS; only the path goes to the executors.
+            newCredentials.writeTokenStorageToStream(stream)
+            stream.hflush()
+            stream.close()
+            callback(tokenPath.toString)
            }
-            val nns = getNameNodesToAccess(sparkConf) + remoteKeytabPath
-            val newCredentials = loggedInUGI.get().getCredentials
-            obtainTokensForNamenodes(nns, conf, newCredentials)
-            // Now write this out via Akka to executors.
-            val outputStream = new ByteArrayOutputStream()
-            newCredentials.writeTokenStorageToStream(new DataOutputStream(outputStream))
-            callback(new SerializableBuffer(ByteBuffer.wrap(outputStream.toByteArray)))
-          }
-        }, timeToRenewal, timeToRenewal, TimeUnit.MILLISECONDS)
-      case None =>
+        }, timeToRenewal, timeToRenewal, TimeUnit.MILLISECONDS)
     }
   }
 
```
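The renewal cadence comes from round-tripping the current credentials through a byte stream, reading a `DelegationTokenIdentifier` back out, and scheduling at 60% of the time left until the token's max date. A condensed sketch of just that computation, mirroring the patch (including its assumption that the serialized stream starts with an HDFS delegation token):

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier
import org.apache.hadoop.security.Credentials

def millisUntilRenewal(creds: Credentials): Long = {
  // Serialize the credentials so the token fields can be read back out.
  val buf = new ByteArrayOutputStream()
  creds.writeTokenStorageToStream(new DataOutputStream(buf))
  val in = new DataInputStream(new ByteArrayInputStream(buf.toByteArray))
  val tokenIdentifier = new DelegationTokenIdentifier()
  tokenIdentifier.readFields(in) // as in the patch: assumes an HDFS token comes first
  // Renew once 60% of the remaining lifetime has elapsed, per the patch's 0.6 factor.
  (0.6 * (tokenIdentifier.getMaxDate - System.currentTimeMillis())).toLong
}
```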
