
Commit 2b0d745

[SPARK-5342][YARN] Allow long running Spark apps to run on secure YARN/HDFS.
Spark apps running on secure YARN/HDFS currently cannot write data to HDFS once 7 days have passed, since delegation tokens cannot be renewed beyond that point. This means Spark Streaming apps cannot run on secure YARN. This commit adds basic functionality to fix the issue. In this patch:
- new parameters are added - principal and keytab, which can be used to log in to a KDC
- the client logs in and then gets tokens to start the AM
- the keytab is copied to the staging directory
- the AM waits until 60% of the tokens' lifetime has elapsed and then logs in using the keytab
- after each such 60% interval, new tokens are created and sent to the executors
1 parent ccba5bc commit 2b0d745
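
The renewal schedule described above is driven by a single calculation: take the max date of the HDFS delegation token and re-login once 60% of the remaining lifetime has elapsed. A minimal sketch of that arithmetic, with a hypothetical fresh token and the default 7-day maximum lifetime:

    // Hypothetical: a token minted just now, with HDFS's default 7-day max lifetime.
    val maxDate = System.currentTimeMillis() + 7L * 24 * 60 * 60 * 1000

    // The AM schedules the keytab re-login at 60% of the remaining lifetime,
    // i.e. roughly 4.2 days from now for a fresh token.
    val timeToRenewal = (0.6 * (maxDate - System.currentTimeMillis())).toLong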

File tree

6 files changed: 159 additions & 29 deletions

core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala

Lines changed: 110 additions & 5 deletions
@@ -17,20 +17,26 @@
 
 package org.apache.spark.deploy
 
+import java.io.{ByteArrayInputStream, DataInputStream, DataOutputStream, ByteArrayOutputStream}
 import java.lang.reflect.Method
+import java.net.URI
+import java.nio.ByteBuffer
 import java.security.PrivilegedExceptionAction
+import java.util.concurrent.atomic.AtomicBoolean
+import java.util.concurrent.{TimeUnit, ThreadFactory, Executors}
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
+import org.apache.hadoop.fs.{FileUtil, FileStatus, FileSystem, Path}
 import org.apache.hadoop.fs.FileSystem.Statistics
-import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier
+import org.apache.hadoop.mapred.{Master, JobConf}
 import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
 import org.apache.hadoop.security.Credentials
 import org.apache.hadoop.security.UserGroupInformation
 
-import org.apache.spark.{Logging, SparkContext, SparkConf, SparkException}
+import org.apache.spark._
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SerializableBuffer, Utils}
 
 import scala.collection.JavaConversions._
 
@@ -40,9 +46,14 @@ import scala.collection.JavaConversions._
  */
 @DeveloperApi
 class SparkHadoopUtil extends Logging {
-  val conf: Configuration = newConfiguration(new SparkConf())
+  val sparkConf = new SparkConf()
+  val conf: Configuration = newConfiguration(sparkConf)
   UserGroupInformation.setConfiguration(conf)
 
+  private var keytabFile: Option[String] = None
+  private var loginPrincipal: Option[String] = None
+  private val loggedInViaKeytab = new AtomicBoolean(false)
+
   /**
    * Runs the given function with a Hadoop UserGroupInformation as a thread local variable
    * (distributed to child threads), used for authenticating HDFS and YARN calls.
@@ -121,6 +132,100 @@ class SparkHadoopUtil extends Logging {
     UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename)
   }
 
+  def setPrincipalAndKeytabForLogin(principal: String, keytab: String): Unit ={
+    loginPrincipal = Option(principal)
+    keytabFile = Option(keytab)
+  }
+
+  private[spark] def scheduleLoginFromKeytab(callback: (SerializableBuffer) => Unit): Unit = {
+
+    loginPrincipal match {
+      case Some(principal) =>
+        val keytab = keytabFile.get
+        val remoteFs = FileSystem.get(conf)
+        val remoteKeytabPath = new Path(
+          remoteFs.getHomeDirectory, System.getenv("SPARK_STAGING_DIR") + Path.SEPARATOR + keytab)
+        val localFS = FileSystem.getLocal(conf)
+        // At this point, SparkEnv is likely no initialized, so create a dir, put the keytab there.
+        val tempDir = Utils.createTempDir()
+        val localURI = new URI(tempDir.getAbsolutePath + Path.SEPARATOR + keytab)
+        val qualifiedURI = new URI(localFS.makeQualified(new Path(localURI)).toString)
+        FileUtil.copy(
+          remoteFs, remoteKeytabPath, localFS, new Path(qualifiedURI), false, false, conf)
+        // Get the current credentials, find out when they expire.
+        val creds = UserGroupInformation.getCurrentUser.getCredentials
+        val credStream = new ByteArrayOutputStream()
+        creds.writeTokenStorageToStream(new DataOutputStream(credStream))
+        val in = new DataInputStream(new ByteArrayInputStream(credStream.toByteArray))
+        val tokenIdentifier = new DelegationTokenIdentifier()
+        tokenIdentifier.readFields(in)
+        val timeToRenewal = (0.6 * (tokenIdentifier.getMaxDate - System.currentTimeMillis())).toLong
+        Executors.newSingleThreadScheduledExecutor(new ThreadFactory {
+          override def newThread(r: Runnable): Thread = {
+            val t = new Thread(r)
+            t.setName("Delegation Token Refresh Thread")
+            t.setDaemon(true)
+            t
+          }
+        }).scheduleWithFixedDelay(new Runnable {
+          override def run(): Unit = {
+            if (!loggedInViaKeytab.get()) {
+              loginUserFromKeytab(principal, tempDir.getAbsolutePath + Path.SEPARATOR + keytab)
+              loggedInViaKeytab.set(true)
+            }
+            val nns = getNameNodesToAccess(sparkConf) + remoteKeytabPath
+            val newCredentials = new Credentials()
+            obtainTokensForNamenodes(nns, conf, newCredentials)
+            // Now write this out via Akka to executors.
+            val outputStream = new ByteArrayOutputStream()
+            newCredentials.writeTokenStorageToStream(new DataOutputStream(outputStream))
+            callback(new SerializableBuffer(ByteBuffer.wrap(outputStream.toByteArray)))
+          }
+        }, timeToRenewal, timeToRenewal, TimeUnit.MILLISECONDS)
+
+    }
+  }
+
+  /**
+   * Get the list of namenodes the user may access.
+   */
+  def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = {
+    sparkConf.get("spark.yarn.access.namenodes", "")
+      .split(",")
+      .map(_.trim())
+      .filter(!_.isEmpty)
+      .map(new Path(_))
+      .toSet
+  }
+
+  def getTokenRenewer(conf: Configuration): String = {
+    val delegTokenRenewer = Master.getMasterPrincipal(conf)
+    logDebug("delegation token renewer is: " + delegTokenRenewer)
+    if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) {
+      val errorMessage = "Can't get Master Kerberos principal for use as renewer"
+      logError(errorMessage)
+      throw new SparkException(errorMessage)
+    }
+    delegTokenRenewer
+  }
+
+  /**
+   * Obtains tokens for the namenodes passed in and adds them to the credentials.
+   */
+  def obtainTokensForNamenodes(
+      paths: Set[Path],
+      conf: Configuration,
+      creds: Credentials): Unit = {
+    if (UserGroupInformation.isSecurityEnabled()) {
+      val delegTokenRenewer = getTokenRenewer(conf)
+      paths.foreach { dst =>
+        val dstFs = dst.getFileSystem(conf)
+        logDebug("getting token for namenode: " + dst)
+        dstFs.addDelegationTokens(delegTokenRenewer, creds)
+      }
+    }
+  }
+
   /**
    * Returns a function that can be called to find Hadoop FileSystem bytes read. If
    * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will
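
The three helpers added to SparkHadoopUtil here (getNameNodesToAccess, getTokenRenewer, obtainTokensForNamenodes) are the same ones Client.scala below now delegates to. A minimal sketch of calling them directly, with an illustrative secondary namenode (the token call is a no-op unless Kerberos security is enabled):

    import org.apache.hadoop.security.Credentials

    import org.apache.spark.SparkConf
    import org.apache.spark.deploy.SparkHadoopUtil

    // Illustrative config: an extra namenode the job needs delegation tokens for.
    val sparkConf = new SparkConf()
      .set("spark.yarn.access.namenodes", "hdfs://nn2.example.com:8020")

    val util = SparkHadoopUtil.get
    val nns = util.getNameNodesToAccess(sparkConf)   // Set(new Path("hdfs://nn2.example.com:8020"))
    val creds = new Credentials()
    // On a Kerberized cluster this adds one delegation token per namenode, with the
    // master principal (getTokenRenewer) registered as the renewer; otherwise it does nothing.
    util.obtainTokensForNamenodes(nns, util.conf, creds)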

core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala

Lines changed: 8 additions & 0 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark.executor
 
 import java.net.URL
+import java.io.{ByteArrayInputStream, DataInputStream}
 import java.nio.ByteBuffer
 
 import scala.collection.mutable
@@ -26,6 +27,7 @@ import scala.concurrent.Await
 import akka.actor.{Actor, ActorSelection, Props}
 import akka.pattern.Patterns
 import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent}
+import org.apache.hadoop.security.Credentials
 
 import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv}
 import org.apache.spark.TaskState.TaskState
@@ -105,6 +107,12 @@ private[spark] class CoarseGrainedExecutorBackend(
       executor.stop()
       context.stop(self)
       context.system.shutdown()
+
+    case UpdateCredentials(newCredentials) =>
+      val credentials = new Credentials()
+      credentials.readTokenStorageStream(
+        new DataInputStream(new ByteArrayInputStream(newCredentials.value.array())))
+      SparkHadoopUtil.get.addCurrentUserCredentials(credentials)
   }
 
   override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
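
The UpdateCredentials handler above is the receiving end of the byte buffer that scheduleLoginFromKeytab produces on the driver; the round trip is plain Hadoop token-storage serialization wrapped in a SerializableBuffer. A self-contained sketch of that encode/decode path (the Credentials instance here is empty and purely illustrative):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
    import java.nio.ByteBuffer

    import org.apache.hadoop.security.Credentials

    import org.apache.spark.util.SerializableBuffer

    // Driver side: serialize freshly obtained credentials into an Akka-friendly buffer.
    val newCredentials = new Credentials()   // would carry the renewed delegation tokens
    val out = new ByteArrayOutputStream()
    newCredentials.writeTokenStorageToStream(new DataOutputStream(out))
    val message = new SerializableBuffer(ByteBuffer.wrap(out.toByteArray))

    // Executor side: decode the buffer back into Credentials and merge them into the current
    // user, which is what SparkHadoopUtil.get.addCurrentUserCredentials does in the handler.
    val received = new Credentials()
    received.readTokenStorageStream(
      new DataInputStream(new ByteArrayInputStream(message.value.array())))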

core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala

Lines changed: 4 additions & 0 deletions
@@ -51,6 +51,10 @@ private[spark] object CoarseGrainedClusterMessages {
   case class StatusUpdate(executorId: String, taskId: Long, state: TaskState,
     data: SerializableBuffer) extends CoarseGrainedClusterMessage
 
+  // Driver to all executors.
+  case class UpdateCredentials(newCredentials: SerializableBuffer)
+    extends CoarseGrainedClusterMessage
+
   object StatusUpdate {
     /** Alternate factory method that takes a ByteBuffer directly for the data field */
     def apply(executorId: String, taskId: Long, state: TaskState, data: ByteBuffer)

core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala

Lines changed: 11 additions & 0 deletions
@@ -27,6 +27,7 @@ import akka.actor._
 import akka.pattern.ask
 import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
 
+import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState}
 import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
@@ -75,6 +76,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
     override protected def log = CoarseGrainedSchedulerBackend.this.log
     private val addressToExecutorId = new HashMap[Address, String]
 
+    // If a principal and keytab have been set, use that to create new credentials for executors
+    // periodically
+    SparkHadoopUtil.get.scheduleLoginFromKeytab(sendNewCredentialsToExecutors _)
+
     override def preStart() {
       // Listen for remote client disconnection events, since they don't go through Akka's watch()
       context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
@@ -85,6 +90,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
       context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers)
     }
 
+    def sendNewCredentialsToExecutors(credentials: SerializableBuffer): Unit = {
+      executorDataMap.values.foreach{ x =>
+        x.executorActor ! UpdateCredentials(credentials)
+      }
+    }
+
     def receiveWithLogging = {
       case RegisterExecutor(executorId, hostPort, cores, logUrls) =>
         Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
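
On the driver side, sendNewCredentialsToExecutors is simply the callback handed to scheduleLoginFromKeytab, so any function of type SerializableBuffer => Unit satisfies the contract. A hypothetical stand-in callback that only reports the payload size, to show the wiring (scheduleLoginFromKeytab is private[spark], so this compiles only from within Spark's own packages, and it assumes the principal, keytab and SPARK_STAGING_DIR environment have been set up as in the AM flow below):

    import org.apache.spark.deploy.SparkHadoopUtil
    import org.apache.spark.util.SerializableBuffer

    // Hypothetical principal/keytab; in the real flow the AM reads these from the
    // SPARK_PRINCIPAL / SPARK_KEYTAB environment variables that Client.scala exports.
    SparkHadoopUtil.get.setPrincipalAndKeytabForLogin("spark-user@EXAMPLE.COM", "spark-user.keytab")

    // The scheduler backend passes sendNewCredentialsToExecutors _ here; this stand-in
    // just logs how many bytes of refreshed tokens each round produces.
    SparkHadoopUtil.get.scheduleLoginFromKeytab { (buf: SerializableBuffer) =>
      println(s"refreshed delegation tokens: ${buf.value.remaining()} bytes")
    }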

yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala

Lines changed: 5 additions & 0 deletions
@@ -576,6 +576,11 @@ object ApplicationMaster extends Logging {
       master = new ApplicationMaster(amArgs, new YarnRMClient(amArgs))
       System.exit(master.run())
     }
+    // At this point, we have tokens that will expire only after a while, so we now schedule a
+    // login for some time before the tokens expire. Since the SparkContext has already started,
+    // we can now get access to the driver actor as well.
+    SparkHadoopUtil.get.setPrincipalAndKeytabForLogin(
+      System.getenv("SPARK_PRINCIPAL"), System.getenv("SPARK_KEYTAB"))
   }
 
   private[spark] def sparkContextInitialized(sc: SparkContext) = {

yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala

Lines changed: 21 additions & 24 deletions
@@ -70,7 +70,7 @@ private[spark] class Client(
   private val isClusterMode = args.isClusterMode
 
   private var loginFromKeytab = false
-  private var kerberosFileName: String = null
+  private var keytabFileName: String = null
 
 
   def stop(): Unit = yarnClient.stop()
@@ -89,6 +89,7 @@
    * available in the alpha API.
    */
   def submitApplication(): ApplicationId = {
+    // Setup the credentials before doing anything else, so we have don't have issues at any point.
     setupCredentials()
     yarnClient.init(yarnConf)
     yarnClient.start()
@@ -319,6 +320,21 @@
     env("SPARK_YARN_MODE") = "true"
     env("SPARK_YARN_STAGING_DIR") = stagingDir
     env("SPARK_USER") = UserGroupInformation.getCurrentUser().getShortUserName()
+    // If we logged in from keytab, make sure we copy the keytab to the staging directory on
+    // HDFS, and setup the relevant environment vars, so the AM can login again.
+    if (loginFromKeytab) {
+      val fs = FileSystem.get(hadoopConf)
+      val stagingDirPath = new Path(fs.getHomeDirectory, stagingDir)
+      val localUri = new URI(args.keytab)
+      val localPath = getQualifiedLocalPath(localUri, hadoopConf)
+      val destinationPath = new Path(stagingDirPath, keytabFileName)
+      val replication = sparkConf.getInt("spark.yarn.submit.file.replication",
+        fs.getDefaultReplication(destinationPath)).toShort
+      copyFileToRemote(destinationPath, localPath, replication)
+      env("SPARK_PRINCIPAL") = args.principal
+      env("SPARK_KEYTAB") = keytabFileName
+    }
+
 
     // Set the environment variables to be passed on to the executors.
     distCacheMgr.setDistFilesEnv(env)
@@ -553,7 +569,7 @@
     // Generate a file name that can be used for the keytab file, that does not conflict
     // with any user file.
     val f = new File(keytabPath)
-    kerberosFileName = f.getName + "-" + System.currentTimeMillis()
+    keytabFileName = f.getName + "-" + System.currentTimeMillis()
     val ugi = UserGroupInformation.loginUserFromKeytabAndReturnUGI(principal, keytabPath)
     credentials = ugi.getCredentials
     loginFromKeytab = true
@@ -891,23 +907,11 @@
    * Get the list of namenodes the user may access.
    */
   private[yarn] def getNameNodesToAccess(sparkConf: SparkConf): Set[Path] = {
-    sparkConf.get("spark.yarn.access.namenodes", "")
-      .split(",")
-      .map(_.trim())
-      .filter(!_.isEmpty)
-      .map(new Path(_))
-      .toSet
+    SparkHadoopUtil.get.getNameNodesToAccess(sparkConf)
   }
 
   private[yarn] def getTokenRenewer(conf: Configuration): String = {
-    val delegTokenRenewer = Master.getMasterPrincipal(conf)
-    logDebug("delegation token renewer is: " + delegTokenRenewer)
-    if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) {
-      val errorMessage = "Can't get Master Kerberos principal for use as renewer"
-      logError(errorMessage)
-      throw new SparkException(errorMessage)
-    }
-    delegTokenRenewer
+    SparkHadoopUtil.get.getTokenRenewer(conf)
  }
 
   /**
@@ -917,14 +921,7 @@
       paths: Set[Path],
       conf: Configuration,
       creds: Credentials): Unit = {
-    if (UserGroupInformation.isSecurityEnabled()) {
-      val delegTokenRenewer = getTokenRenewer(conf)
-      paths.foreach { dst =>
-        val dstFs = dst.getFileSystem(conf)
-        logDebug("getting token for namenode: " + dst)
-        dstFs.addDelegationTokens(delegTokenRenewer, creds)
-      }
-    }
+    SparkHadoopUtil.get.obtainTokensForNamenodes(paths, conf, creds)
   }
 
   /**
