@@ -20,14 +20,17 @@ package org.apache.spark.deploy.yarn
 import java.io._
 import java.net.URI
 import java.nio.ByteBuffer
-import java.util.concurrent.{TimeUnit, Executors}
+import java.util.concurrent.{Executors, TimeUnit}
 import java.util.regex.Matcher
 import java.util.regex.Pattern

 import scala.collection.mutable.HashMap
+import scala.collection.JavaConversions._
 import scala.util.Try

+import org.apache.hadoop.fs.Options.Rename
 import org.apache.hadoop.fs.{FileUtil, Path, FileSystem}
+import org.apache.hadoop.hdfs.DistributedFileSystem
 import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier
 import org.apache.hadoop.io.Text
 import org.apache.hadoop.mapred.{Master, JobConf}
@@ -41,7 +44,7 @@ import org.apache.hadoop.conf.Configuration

 import org.apache.spark.{SparkException, SecurityManager, SparkConf}
 import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.util.{SerializableBuffer, Utils}
+import org.apache.spark.util.Utils

 /**
  * Contains util methods to interact with Hadoop from spark.
@@ -52,6 +55,13 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
   private var principal: String = null
   @volatile private var loggedInViaKeytab = false
   @volatile private var loggedInUGI: UserGroupInformation = null
+  @volatile private var lastCredentialsRefresh = 0L
+  private lazy val delegationTokenRenewer =
+    Executors.newSingleThreadScheduledExecutor(
+      Utils.namedThreadFactory("Delegation Token Refresh Thread"))
+  private lazy val delegationTokenExecuterUpdaterThread = new Runnable {
+    override def run(): Unit = updateCredentialsIfRequired()
+  }

   override def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) {
     dest.addCredentials(source.getCredentials())
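The `delegationTokenRenewer`/updater pair added above follows a self-rescheduling pattern: a single named daemon thread runs the refresh task, and the task itself schedules its own next run, so the delay can be recomputed from the freshest token validity on every iteration. A minimal, self-contained sketch of that pattern (the object name and the placeholder delay are illustrative, not part of this patch):

```scala
import java.util.concurrent.{Executors, ThreadFactory, TimeUnit}

object TokenRenewerSketch {
  // Named daemon thread factory, comparable to what Utils.namedThreadFactory provides.
  private val factory = new ThreadFactory {
    override def newThread(r: Runnable): Thread = {
      val t = new Thread(r, "Delegation Token Refresh Thread")
      t.setDaemon(true)
      t
    }
  }
  private val renewer = Executors.newSingleThreadScheduledExecutor(factory)

  // The task schedules its own next run instead of using a fixed-rate timer,
  // so each iteration can pick a fresh delay.
  private val updater: Runnable = new Runnable {
    override def run(): Unit = {
      // ... refresh credentials here ...
      val nextDelayMs = 60L * 1000 // placeholder; the patch derives this from token validity
      renewer.schedule(this, nextDelayMs, TimeUnit.MILLISECONDS)
    }
  }

  def start(initialDelayMs: Long): Unit = {
    renewer.schedule(updater, initialDelayMs, TimeUnit.MILLISECONDS)
  }
}
```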
@@ -92,57 +102,118 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
     if (credentials != null) credentials.getSecretKey(new Text(key)) else null
   }

-  override def setPrincipalAndKeytabForLogin(principal: String, keytab: String): Unit = {
-    this.principal = principal
-    this.keytab = keytab
+  private[spark] override def scheduleLoginFromKeytab(): Unit = {
+    val principal = System.getenv("SPARK_PRINCIPAL")
+    val keytab = System.getenv("SPARK_KEYTAB")
+    if (principal != null) {
+      val delegationTokenRenewerThread =
+        new Runnable {
+          override def run(): Unit = {
+            if (!loggedInViaKeytab) {
+              // Keytab is copied by YARN to the working directory of the AM, so the
+              // full path is not needed.
+              loggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(
+                principal, keytab)
+              loggedInViaKeytab = true
+            }
+            val nns = getNameNodesToAccess(sparkConf)
+            val newCredentials = loggedInUGI.getCredentials
+            obtainTokensForNamenodes(nns, conf, newCredentials)
+            val remoteFs = FileSystem.get(conf)
+            val stagingDirPath =
+              new Path(remoteFs.getHomeDirectory, System.getenv("SPARK_YARN_STAGING_DIR"))
+            val tokenPathStr = sparkConf.get("spark.yarn.credentials.file")
+            val tokenPath = new Path(stagingDirPath.toString, tokenPathStr)
+            val tempTokenPath = new Path(stagingDirPath.toString, tokenPathStr + ".tmp")
+            val stream = remoteFs.create(tempTokenPath, true)
+            // Write the new tokens to a temp file on HDFS, then swap it into place.
+            newCredentials.writeTokenStorageToStream(stream)
+            stream.hflush()
+            stream.close()
+            remoteFs.delete(tokenPath, true)
+            remoteFs.rename(tempTokenPath, tokenPath)
+            delegationTokenRenewer.schedule(
+              this, (0.75 * (getLatestValidity - System.currentTimeMillis())).toLong,
+              TimeUnit.MILLISECONDS)
+          }
+        }
+      val timeToRenewal = (0.75 * (getLatestValidity - System.currentTimeMillis())).toLong
+      delegationTokenRenewer.schedule(
+        delegationTokenRenewerThread, timeToRenewal, TimeUnit.MILLISECONDS)
+    }
   }

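The write path above publishes new tokens with a temp-file-plus-rename step, so readers never observe a half-written credentials file. A condensed, hedged sketch of that idiom (the helper name is illustrative; note that plain delete-then-rename is not atomic, which is presumably why the patch also imports `Options.Rename` and `DistributedFileSystem`):

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.security.Credentials

// Illustrative helper mirroring the publish step in the Runnable above.
def publishCredentials(creds: Credentials, conf: Configuration, tokenPath: Path): Unit = {
  val fs = FileSystem.get(conf)
  val tmp = new Path(tokenPath.toString + ".tmp")
  val out = fs.create(tmp, true) // overwrite any stale temp file
  try {
    creds.writeTokenStorageToStream(out)
    out.hflush() // push the bytes to the datanode pipeline before closing
  } finally {
    out.close()
  }
  // FileSystem.rename does not overwrite, so remove the old file first.
  fs.delete(tokenPath, true)
  fs.rename(tmp, tokenPath)
}
```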
-  private[spark] override def scheduleLoginFromKeytab(
-      callback: (String) => Unit): Unit = {
-    if (principal != null) {
-      // Get the current credentials, find out when they expire.
-      val creds = {
-        if (loggedInUGI == null) {
-          UserGroupInformation.getCurrentUser.getCredentials
-        } else {
-          loggedInUGI.getCredentials
+  override def updateCredentialsIfRequired(): Unit = {
+    try {
+      val credentialsFile = sparkConf.get("spark.yarn.credentials.file")
+      if (credentialsFile != null && !credentialsFile.isEmpty) {
+        val remoteFs = FileSystem.get(conf)
+        val sparkStagingDir = System.getenv("SPARK_YARN_STAGING_DIR")
+        val stagingDirPath = new Path(remoteFs.getHomeDirectory, sparkStagingDir)
+        val credentialsFilePath = new Path(stagingDirPath, credentialsFile)
+        if (remoteFs.exists(credentialsFilePath)) {
+          val status = remoteFs.getFileStatus(credentialsFilePath)
+          val modTimeAtStart = status.getModificationTime
+          if (modTimeAtStart > lastCredentialsRefresh) {
+            val newCredentials = getCredentialsFromHDFSFile(remoteFs, credentialsFilePath)
+            val newStatus = remoteFs.getFileStatus(credentialsFilePath)
+            // The file was updated after we started reading it; come back later
+            // and try again.
+            if (newStatus.getModificationTime != modTimeAtStart) {
+              delegationTokenRenewer
+                .schedule(delegationTokenExecuterUpdaterThread, 1, TimeUnit.HOURS)
+            } else {
+              UserGroupInformation.getCurrentUser.addCredentials(newCredentials)
+              lastCredentialsRefresh = status.getModificationTime
+              val totalValidity = getLatestValidity - lastCredentialsRefresh
+              val timeToRunRenewal = lastCredentialsRefresh + (0.8 * totalValidity).toLong
+              val timeFromNowToRenewal = timeToRunRenewal - System.currentTimeMillis()
+              delegationTokenRenewer.schedule(delegationTokenExecuterUpdaterThread,
+                timeFromNowToRenewal, TimeUnit.MILLISECONDS)
+            }
+          } else {
+            // Check every hour to see if new credentials have arrived.
+            delegationTokenRenewer.schedule(
+              delegationTokenExecuterUpdaterThread, 1, TimeUnit.HOURS)
+          }
         }
       }
-      val credStream = new ByteArrayOutputStream()
-      creds.writeTokenStorageToStream(new DataOutputStream(credStream))
-      val in = new DataInputStream(new ByteArrayInputStream(credStream.toByteArray))
-      val tokenIdentifier = new DelegationTokenIdentifier()
-      tokenIdentifier.readFields(in)
-      val timeToRenewal = (0.6 * (tokenIdentifier.getMaxDate - System.currentTimeMillis())).toLong
-      Executors.newSingleThreadScheduledExecutor(
-        Utils.namedThreadFactory("Delegation Token Refresh Thread")).scheduleWithFixedDelay(
-        new Runnable {
-          override def run(): Unit = {
-            if (!loggedInViaKeytab) {
-              // Keytab is copied by YARN to the working directory of the AM, so full path is
-              // not needed.
-              loggedInUGI = UserGroupInformation.loginUserFromKeytabAndReturnUGI(
-                principal, keytab)
-              loggedInViaKeytab = true
-            }
-            val nns = getNameNodesToAccess(sparkConf)
-            val newCredentials = loggedInUGI.getCredentials
-            obtainTokensForNamenodes(nns, conf, newCredentials)
-            val remoteFs = FileSystem.get(conf)
-            val stagingDir = System.getenv("SPARK_YARN_STAGING_DIR")
-            val tokenPath = new Path(remoteFs.getHomeDirectory, stagingDir + Path.SEPARATOR +
-              "credentials - " + System.currentTimeMillis())
-            val stream = remoteFs.create(tokenPath, true)
-            // Now write this out via Akka to executors.
-            newCredentials.writeTokenStorageToStream(stream)
-            stream.hflush()
-            stream.close()
-            callback(tokenPath.toString)
-          }
-        }, timeToRenewal, timeToRenewal, TimeUnit.MILLISECONDS)
+    } catch {
+      // The file may get deleted while we are reading it, so catch everything
+      // and schedule a retry.
+      case e: Exception =>
+        logWarning(
+          "Error encountered while trying to update credentials, will try again in 1 hour", e)
+        delegationTokenRenewer.schedule(delegationTokenExecuterUpdaterThread, 1, TimeUnit.HOURS)
     }
   }

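On the read side, `updateCredentialsIfRequired` guards against the writer swapping the file mid-read by comparing the modification time before and after the read. A self-contained sketch of that double-check (the helper name and the `Option` return are illustrative, not the patch's API):

```scala
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.security.Credentials

// Returns the credentials and their mod time, or None if the file was
// replaced while we were reading it (the caller should retry later).
def readIfStable(fs: FileSystem, file: Path, lastRefresh: Long): Option[(Credentials, Long)] = {
  val before = fs.getFileStatus(file).getModificationTime
  if (before <= lastRefresh) {
    return None // nothing newer than what we already loaded
  }
  val in = fs.open(file)
  val creds = new Credentials()
  try {
    creds.readFields(in)
  } finally {
    in.close()
  }
  val after = fs.getFileStatus(file).getModificationTime
  if (after == before) Some((creds, before)) else None // raced with the writer
}
```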
+  private[spark] def getCredentialsFromHDFSFile(
+      remoteFs: FileSystem,
+      tokenPath: Path
+    ): Credentials = {
+    val stream = remoteFs.open(tokenPath)
+    try {
+      val newCredentials = new Credentials()
+      newCredentials.readFields(stream)
+      newCredentials
+    } finally {
+      // Close the stream even if readFields throws, so we don't leak it.
+      stream.close()
+    }
+  }
+
+  private[spark] def getLatestValidity: Long = {
+    val creds = UserGroupInformation.getCurrentUser.getCredentials
+    var latestValidity: Long = 0
+    creds.getAllTokens
+      .filter(_.getKind == DelegationTokenIdentifier.HDFS_DELEGATION_KIND)
+      .foreach { t =>
+        val identifier = new DelegationTokenIdentifier()
+        identifier.readFields(new DataInputStream(new ByteArrayInputStream(t.getIdentifier)))
+        latestValidity = {
+          if (latestValidity < identifier.getMaxDate) {
+            identifier.getMaxDate
+          } else {
+            latestValidity
+          }
+        }
+      }
+    latestValidity
+  }
+
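`getLatestValidity` recovers the `maxDate` of the freshest HDFS delegation token; both scheduling sites then renew well before that deadline. A one-liner capturing that arithmetic, with the fraction as an assumed parameter (the patch uses 0.75 on the keytab login path and 0.8 on the executor update path):

```scala
// Delay until the next renewal, as a fraction of the remaining token lifetime.
// Clamped at zero so an already-expired token triggers an immediate refresh.
def renewalDelayMs(latestValidity: Long, fraction: Double = 0.75): Long =
  math.max(0L, (fraction * (latestValidity - System.currentTimeMillis())).toLong)
```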
   /**
    * Get the list of namenodes the user may access.
    */
@@ -172,7 +243,8 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
   def obtainTokensForNamenodes(
       paths: Set[Path],
       conf: Configuration,
-      creds: Credentials): Unit = {
+      creds: Credentials
+    ): Unit = {
     if (UserGroupInformation.isSecurityEnabled()) {
       val delegTokenRenewer = getTokenRenewer(conf)
       paths.foreach { dst =>
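The hunk ends mid-method, but a body of this shape typically asks each path's filesystem for delegation tokens on behalf of the renewer. A hedged, standalone sketch using only public Hadoop APIs (the body below is an assumption for illustration, not the file's exact code):

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.security.{Credentials, UserGroupInformation}

// Illustrative stand-in for the elided body of obtainTokensForNamenodes.
def fetchTokensForPaths(
    paths: Set[Path],
    conf: Configuration,
    renewer: String,
    creds: Credentials): Unit = {
  if (UserGroupInformation.isSecurityEnabled) {
    paths.foreach { dst =>
      // Each FileSystem adds its delegation tokens into the shared Credentials.
      dst.getFileSystem(conf).addDelegationTokens(renewer, creds)
    }
  }
}
```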