1717
1818package org .apache .spark .deploy
1919
20+ import java .io .{ByteArrayInputStream , DataInputStream , DataOutputStream , ByteArrayOutputStream }
2021import java .lang .reflect .Method
22+ import java .net .URI
23+ import java .nio .ByteBuffer
2124import java .security .PrivilegedExceptionAction
25+ import java .util .concurrent .atomic .AtomicBoolean
26+ import java .util .concurrent .{TimeUnit , ThreadFactory , Executors }
2227
2328import org .apache .hadoop .conf .Configuration
24- import org .apache .hadoop .fs .{FileStatus , FileSystem , Path }
29+ import org .apache .hadoop .fs .{FileUtil , FileStatus , FileSystem , Path }
2530import org .apache .hadoop .fs .FileSystem .Statistics
26- import org .apache .hadoop .mapred .JobConf
31+ import org .apache .hadoop .hdfs .security .token .delegation .DelegationTokenIdentifier
32+ import org .apache .hadoop .mapred .{Master , JobConf }
2733import org .apache .hadoop .mapreduce .{JobContext , TaskAttemptContext }
2834import org .apache .hadoop .security .Credentials
2935import org .apache .hadoop .security .UserGroupInformation
3036
31- import org .apache .spark .{ Logging , SparkContext , SparkConf , SparkException }
37+ import org .apache .spark ._
3238import org .apache .spark .annotation .DeveloperApi
33- import org .apache .spark .util .Utils
39+ import org .apache .spark .util .{ SerializableBuffer , Utils }
3440
3541import scala .collection .JavaConversions ._
3642
@@ -40,9 +46,14 @@ import scala.collection.JavaConversions._
4046 */
4147@ DeveloperApi
4248class SparkHadoopUtil extends Logging {
  // SparkConf built from system properties; also read later by the delegation-token
  // refresh thread (see scheduleLoginFromKeytab / getNameNodesToAccess).
  val sparkConf = new SparkConf()
  // Hadoop Configuration derived from the Spark configuration via newConfiguration.
  val conf: Configuration = newConfiguration(sparkConf)
  // Make Hadoop's security layer use the same configuration (class-load side effect).
  UserGroupInformation.setConfiguration(conf)

  // Keytab file name and principal for keytab-based (re-)login. Both are set together
  // by setPrincipalAndKeytabForLogin and consumed by scheduleLoginFromKeytab.
  private var keytabFile: Option[String] = None
  private var loginPrincipal: Option[String] = None
  // Set to true once the refresh thread has performed the initial keytab login, so
  // subsequent scheduled runs only renew tokens instead of logging in again.
  private val loggedInViaKeytab = new AtomicBoolean(false)
4657 /**
4758 * Runs the given function with a Hadoop UserGroupInformation as a thread local variable
4859 * (distributed to child threads), used for authenticating HDFS and YARN calls.
@@ -121,6 +132,100 @@ class SparkHadoopUtil extends Logging {
121132 UserGroupInformation .loginUserFromKeytab(principalName, keytabFilename)
122133 }
123134
135+ def setPrincipalAndKeytabForLogin (principal : String , keytab : String ): Unit = {
136+ loginPrincipal = Option (principal)
137+ keytabFile = Option (keytab)
138+ }
139+
140+ private [spark] def scheduleLoginFromKeytab (callback : (SerializableBuffer ) => Unit ): Unit = {
141+
142+ loginPrincipal match {
143+ case Some (principal) =>
144+ val keytab = keytabFile.get
145+ val remoteFs = FileSystem .get(conf)
146+ val remoteKeytabPath = new Path (
147+ remoteFs.getHomeDirectory, System .getenv(" SPARK_STAGING_DIR" ) + Path .SEPARATOR + keytab)
148+ val localFS = FileSystem .getLocal(conf)
149+ // At this point, SparkEnv is likely no initialized, so create a dir, put the keytab there.
150+ val tempDir = Utils .createTempDir()
151+ val localURI = new URI (tempDir.getAbsolutePath + Path .SEPARATOR + keytab)
152+ val qualifiedURI = new URI (localFS.makeQualified(new Path (localURI)).toString)
153+ FileUtil .copy(
154+ remoteFs, remoteKeytabPath, localFS, new Path (qualifiedURI), false , false , conf)
155+ // Get the current credentials, find out when they expire.
156+ val creds = UserGroupInformation .getCurrentUser.getCredentials
157+ val credStream = new ByteArrayOutputStream ()
158+ creds.writeTokenStorageToStream(new DataOutputStream (credStream))
159+ val in = new DataInputStream (new ByteArrayInputStream (credStream.toByteArray))
160+ val tokenIdentifier = new DelegationTokenIdentifier ()
161+ tokenIdentifier.readFields(in)
162+ val timeToRenewal = (0.6 * (tokenIdentifier.getMaxDate - System .currentTimeMillis())).toLong
163+ Executors .newSingleThreadScheduledExecutor(new ThreadFactory {
164+ override def newThread (r : Runnable ): Thread = {
165+ val t = new Thread (r)
166+ t.setName(" Delegation Token Refresh Thread" )
167+ t.setDaemon(true )
168+ t
169+ }
170+ }).scheduleWithFixedDelay(new Runnable {
171+ override def run (): Unit = {
172+ if (! loggedInViaKeytab.get()) {
173+ loginUserFromKeytab(principal, tempDir.getAbsolutePath + Path .SEPARATOR + keytab)
174+ loggedInViaKeytab.set(true )
175+ }
176+ val nns = getNameNodesToAccess(sparkConf) + remoteKeytabPath
177+ val newCredentials = new Credentials ()
178+ obtainTokensForNamenodes(nns, conf, newCredentials)
179+ // Now write this out via Akka to executors.
180+ val outputStream = new ByteArrayOutputStream ()
181+ newCredentials.writeTokenStorageToStream(new DataOutputStream (outputStream))
182+ callback(new SerializableBuffer (ByteBuffer .wrap(outputStream.toByteArray)))
183+ }
184+ }, timeToRenewal, timeToRenewal, TimeUnit .MILLISECONDS )
185+
186+ }
187+ }
188+
189+ /**
190+ * Get the list of namenodes the user may access.
191+ */
192+ def getNameNodesToAccess (sparkConf : SparkConf ): Set [Path ] = {
193+ sparkConf.get(" spark.yarn.access.namenodes" , " " )
194+ .split(" ," )
195+ .map(_.trim())
196+ .filter(! _.isEmpty)
197+ .map(new Path (_))
198+ .toSet
199+ }
200+
201+ def getTokenRenewer (conf : Configuration ): String = {
202+ val delegTokenRenewer = Master .getMasterPrincipal(conf)
203+ logDebug(" delegation token renewer is: " + delegTokenRenewer)
204+ if (delegTokenRenewer == null || delegTokenRenewer.length() == 0 ) {
205+ val errorMessage = " Can't get Master Kerberos principal for use as renewer"
206+ logError(errorMessage)
207+ throw new SparkException (errorMessage)
208+ }
209+ delegTokenRenewer
210+ }
211+
212+ /**
213+ * Obtains tokens for the namenodes passed in and adds them to the credentials.
214+ */
215+ def obtainTokensForNamenodes (
216+ paths : Set [Path ],
217+ conf : Configuration ,
218+ creds : Credentials ): Unit = {
219+ if (UserGroupInformation .isSecurityEnabled()) {
220+ val delegTokenRenewer = getTokenRenewer(conf)
221+ paths.foreach { dst =>
222+ val dstFs = dst.getFileSystem(conf)
223+ logDebug(" getting token for namenode: " + dst)
224+ dstFs.addDelegationTokens(delegTokenRenewer, creds)
225+ }
226+ }
227+ }
228+
124229 /**
125230 * Returns a function that can be called to find Hadoop FileSystem bytes read. If
126231 * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will
0 commit comments