@@ -20,8 +20,7 @@ package org.apache.spark.deploy.yarn
2020import java .io ._
2121import java .net .URI
2222import java .nio .ByteBuffer
23- import java .util .concurrent .atomic .{AtomicReference , AtomicBoolean }
24- import java .util .concurrent .{TimeUnit , ThreadFactory , Executors }
23+ import java .util .concurrent .{TimeUnit , Executors }
2524import java .util .regex .Matcher
2625import java .util .regex .Pattern
2726
@@ -49,10 +48,10 @@ import org.apache.spark.util.{SerializableBuffer, Utils}
4948 */
5049class YarnSparkHadoopUtil extends SparkHadoopUtil {
5150
52- private var keytabFile : Option [ String ] = None
53- private var loginPrincipal : Option [ String ] = None
54- private val loggedInViaKeytab = new AtomicBoolean ( false )
55- private val loggedInUGI = new AtomicReference [ UserGroupInformation ]( null )
51+ private var keytab : String = null
52+ private var principal : String = null
53+ @ volatile private var loggedInViaKeytab = false
54+ @ volatile private var loggedInUGI : UserGroupInformation = null
5655
5756 override def transferCredentials (source : UserGroupInformation , dest : UserGroupInformation ) {
5857 dest.addCredentials(source.getCredentials())
@@ -94,58 +93,61 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
9493 }
9594
9695 override def setPrincipalAndKeytabForLogin (principal : String , keytab : String ): Unit = {
97- loginPrincipal = Option ( principal)
98- keytabFile = Option ( keytab)
96+ this .principal = principal
97+ this .keytab = keytab
9998 }
10099
101100 private [spark] override def scheduleLoginFromKeytab (
102- callback : (SerializableBuffer ) => Unit ): Unit = {
103-
104- loginPrincipal match {
105- case Some (principal) =>
106- val keytab = keytabFile.get
107- val remoteFs = FileSystem .get(conf)
108- val remoteKeytabPath = new Path (
109- remoteFs.getHomeDirectory, System .getenv(" SPARK_STAGING_DIR" ) + Path .SEPARATOR + keytab)
110- val localFS = FileSystem .getLocal(conf)
111- // At this point, SparkEnv is likely no initialized, so create a dir, put the keytab there.
112- val tempDir = Utils .createTempDir()
113- val localURI = new URI (tempDir.getAbsolutePath + Path .SEPARATOR + keytab)
114- val qualifiedURI = new URI (localFS.makeQualified(new Path (localURI)).toString)
115- FileUtil .copy(
116- remoteFs, remoteKeytabPath, localFS, new Path (qualifiedURI), false , false , conf)
117- // Get the current credentials, find out when they expire.
118- val creds = UserGroupInformation .getCurrentUser.getCredentials
119- val credStream = new ByteArrayOutputStream ()
120- creds.writeTokenStorageToStream(new DataOutputStream (credStream))
121- val in = new DataInputStream (new ByteArrayInputStream (credStream.toByteArray))
122- val tokenIdentifier = new DelegationTokenIdentifier ()
123- tokenIdentifier.readFields(in)
124- val timeToRenewal = (0.6 * (tokenIdentifier.getMaxDate - System .currentTimeMillis())).toLong
125- Executors .newSingleThreadScheduledExecutor(new ThreadFactory {
126- override def newThread (r : Runnable ): Thread = {
127- val t = new Thread (r)
128- t.setName(" Delegation Token Refresh Thread" )
129- t.setDaemon(true )
130- t
131- }
132- }).scheduleWithFixedDelay(new Runnable {
133- override def run (): Unit = {
134- if (! loggedInViaKeytab.get()) {
135- loggedInUGI.set(UserGroupInformation .loginUserFromKeytabAndReturnUGI(
136- principal, tempDir.getAbsolutePath + Path .SEPARATOR + keytab))
137- loggedInViaKeytab.set(true )
101+ callback : (String ) => Unit ): Unit = {
102+ if (principal != null ) {
103+ val stagingDir = System .getenv(" SPARK_YARN_STAGING_DIR" )
104+ val remoteFs = FileSystem .get(conf)
105+ val remoteKeytabPath = new Path (
106+ remoteFs.getHomeDirectory, stagingDir + Path .SEPARATOR + keytab)
107+ val localFS = FileSystem .getLocal(conf)
108+ // At this point, SparkEnv is likely no initialized, so create a dir, put the keytab there.
109+ val tempDir = Utils .createTempDir()
110+ Utils .chmod700(tempDir)
111+ val localURI = new URI (tempDir.getAbsolutePath + Path .SEPARATOR + keytab)
112+ val qualifiedURI = new URI (localFS.makeQualified(new Path (localURI)).toString)
113+ FileUtil .copy(
114+ remoteFs, remoteKeytabPath, localFS, new Path (qualifiedURI), false , false , conf)
115+ // Get the current credentials, find out when they expire.
116+ val creds = {
117+ if (loggedInUGI == null ) {
118+ UserGroupInformation .getCurrentUser.getCredentials
119+ } else {
120+ loggedInUGI.getCredentials
121+ }
122+ }
123+ val credStream = new ByteArrayOutputStream ()
124+ creds.writeTokenStorageToStream(new DataOutputStream (credStream))
125+ val in = new DataInputStream (new ByteArrayInputStream (credStream.toByteArray))
126+ val tokenIdentifier = new DelegationTokenIdentifier ()
127+ tokenIdentifier.readFields(in)
128+ val timeToRenewal = (0.6 * (tokenIdentifier.getMaxDate - System .currentTimeMillis())).toLong
129+ Executors .newSingleThreadScheduledExecutor(
130+ Utils .namedThreadFactory(" Delegation Token Refresh Thread" )).scheduleWithFixedDelay(
131+ new Runnable {
132+ override def run (): Unit = {
133+ if (! loggedInViaKeytab) {
134+ loggedInUGI = UserGroupInformation .loginUserFromKeytabAndReturnUGI(
135+ principal, tempDir.getAbsolutePath + Path .SEPARATOR + keytab)
136+ loggedInViaKeytab = true
137+ }
138+ val nns = getNameNodesToAccess(sparkConf) + remoteKeytabPath
139+ val newCredentials = loggedInUGI.getCredentials
140+ obtainTokensForNamenodes(nns, conf, newCredentials)
141+ val tokenPath = new Path (remoteFs.getHomeDirectory, stagingDir + Path .SEPARATOR +
142+ " credentials - " + System .currentTimeMillis())
143+ val stream = remoteFs.create(tokenPath, true )
144+ // Now write this out via Akka to executors.
145+ newCredentials.writeTokenStorageToStream(stream)
146+ stream.hflush()
147+ stream.close()
148+ callback(tokenPath.toString)
138149 }
139- val nns = getNameNodesToAccess(sparkConf) + remoteKeytabPath
140- val newCredentials = loggedInUGI.get().getCredentials
141- obtainTokensForNamenodes(nns, conf, newCredentials)
142- // Now write this out via Akka to executors.
143- val outputStream = new ByteArrayOutputStream ()
144- newCredentials.writeTokenStorageToStream(new DataOutputStream (outputStream))
145- callback(new SerializableBuffer (ByteBuffer .wrap(outputStream.toByteArray)))
146- }
147- }, timeToRenewal, timeToRenewal, TimeUnit .MILLISECONDS )
148- case None =>
150+ }, timeToRenewal, timeToRenewal, TimeUnit .MILLISECONDS )
149151 }
150152 }
151153
0 commit comments