1717
1818package org .apache .spark .deploy .yarn
1919
20- import java .io .File
20+ import java .io ._
21+ import java .net .URI
22+ import java .nio .ByteBuffer
23+ import java .util .concurrent .atomic .AtomicBoolean
24+ import java .util .concurrent .{TimeUnit , ThreadFactory , Executors }
2125import java .util .regex .Matcher
2226import java .util .regex .Pattern
2327
2428import scala .collection .mutable .HashMap
2529import scala .util .Try
2630
31+ import org .apache .hadoop .fs .{FileUtil , Path , FileSystem }
32+ import org .apache .hadoop .hdfs .security .token .delegation .DelegationTokenIdentifier
2733import org .apache .hadoop .io .Text
28- import org .apache .hadoop .mapred .JobConf
34+ import org .apache .hadoop .mapred .{ Master , JobConf }
2935import org .apache .hadoop .security .Credentials
3036import org .apache .hadoop .security .UserGroupInformation
3137import org .apache .hadoop .yarn .conf .YarnConfiguration
@@ -34,15 +40,19 @@ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
3440import org .apache .hadoop .yarn .api .records .{Priority , ApplicationAccessType }
3541import org .apache .hadoop .conf .Configuration
3642
37- import org .apache .spark .{SecurityManager , SparkConf }
43+ import org .apache .spark .{SparkException , SecurityManager , SparkConf }
3844import org .apache .spark .deploy .SparkHadoopUtil
39- import org .apache .spark .util .Utils
45+ import org .apache .spark .util .{ SerializableBuffer , Utils }
4046
4147/**
4248 * Contains util methods to interact with Hadoop from spark.
4349 */
4450class YarnSparkHadoopUtil extends SparkHadoopUtil {
4551
  // Local path of the keytab file registered via setPrincipalAndKeytabForLogin; None until set.
  private var keytabFile: Option[String] = None
  // Kerberos principal registered via setPrincipalAndKeytabForLogin; None until set.
  private var loginPrincipal: Option[String] = None
  // Set to true once the refresh task has performed loginUserFromKeytab, so login happens only once.
  private val loggedInViaKeytab = new AtomicBoolean(false)
55+
4656 override def transferCredentials (source : UserGroupInformation , dest : UserGroupInformation ) {
4757 dest.addCredentials(source.getCredentials())
4858 }
@@ -82,6 +92,101 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil {
8292 if (credentials != null ) credentials.getSecretKey(new Text (key)) else null
8393 }
8494
95+ override def setPrincipalAndKeytabForLogin (principal : String , keytab : String ): Unit = {
96+ loginPrincipal = Option (principal)
97+ keytabFile = Option (keytab)
98+ }
99+
  /**
   * If a principal/keytab pair was registered via setPrincipalAndKeytabForLogin, copy the keytab
   * from the application's YARN staging directory to a local temp dir and schedule a daemon
   * thread that logs in from the keytab, obtains fresh namenode delegation tokens, and passes
   * the serialized tokens to `callback` for distribution to executors.
   *
   * No-op when no principal was registered.
   */
  private[spark] override def scheduleLoginFromKeytab(
      callback: (SerializableBuffer) => Unit): Unit = {

    loginPrincipal match {
      case Some(principal) =>
        val keytab = keytabFile.get
        val remoteFs = FileSystem.get(conf)
        // The keytab lives in the staging dir under the user's home directory on the cluster FS.
        val remoteKeytabPath = new Path(
          remoteFs.getHomeDirectory, System.getenv("SPARK_STAGING_DIR") + Path.SEPARATOR + keytab)
        val localFS = FileSystem.getLocal(conf)
        // At this point, SparkEnv is likely not initialized, so create a dir, put the keytab there.
        val tempDir = Utils.createTempDir()
        val localURI = new URI(tempDir.getAbsolutePath + Path.SEPARATOR + keytab)
        val qualifiedURI = new URI(localFS.makeQualified(new Path(localURI)).toString)
        FileUtil.copy(
          remoteFs, remoteKeytabPath, localFS, new Path(qualifiedURI), false, false, conf)
        // Get the current credentials, find out when they expire.
        val creds = UserGroupInformation.getCurrentUser.getCredentials
        val credStream = new ByteArrayOutputStream()
        creds.writeTokenStorageToStream(new DataOutputStream(credStream))
        val in = new DataInputStream(new ByteArrayInputStream(credStream.toByteArray))
        // NOTE(review): writeTokenStorageToStream writes a token-storage header before any token
        // bytes, yet readFields below parses the stream as a bare DelegationTokenIdentifier —
        // confirm this actually yields a valid identifier/maxDate.
        val tokenIdentifier = new DelegationTokenIdentifier()
        tokenIdentifier.readFields(in)
        // Schedule renewal at 60% of the token's remaining lifetime.
        // NOTE(review): timeToRenewal can be <= 0 if the token is near/past maxDate — confirm a
        // non-positive delay is acceptable for scheduleWithFixedDelay here.
        val timeToRenewal = (0.6 * (tokenIdentifier.getMaxDate - System.currentTimeMillis())).toLong
        Executors.newSingleThreadScheduledExecutor(new ThreadFactory {
          override def newThread(r: Runnable): Thread = {
            val t = new Thread(r)
            t.setName("Delegation Token Refresh Thread")
            // Daemon so the refresh thread never keeps the JVM alive on its own.
            t.setDaemon(true)
            t
          }
        }).scheduleWithFixedDelay(new Runnable {
          override def run(): Unit = {
            // Perform the keytab login exactly once, on the first execution.
            if (!loggedInViaKeytab.get()) {
              loginUserFromKeytab(principal, tempDir.getAbsolutePath + Path.SEPARATOR + keytab)
              loggedInViaKeytab.set(true)
            }
            // Fetch fresh tokens for all configured namenodes plus the keytab's filesystem path.
            val nns = getNameNodesToAccess(sparkConf) + remoteKeytabPath
            val newCredentials = new Credentials()
            obtainTokensForNamenodes(nns, conf, newCredentials)
            // Now write this out via Akka to executors.
            // NOTE(review): an exception escaping run() silently cancels all future executions of
            // this task — consider wrapping in try/catch with logging.
            val outputStream = new ByteArrayOutputStream()
            newCredentials.writeTokenStorageToStream(new DataOutputStream(outputStream))
            callback(new SerializableBuffer(ByteBuffer.wrap(outputStream.toByteArray)))
          }
        }, timeToRenewal, timeToRenewal, TimeUnit.MILLISECONDS)
      case None =>
        // No principal/keytab registered: nothing to schedule.
    }
  }
149+
150+ /**
151+ * Get the list of namenodes the user may access.
152+ */
153+ def getNameNodesToAccess (sparkConf : SparkConf ): Set [Path ] = {
154+ sparkConf.get(" spark.yarn.access.namenodes" , " " )
155+ .split(" ," )
156+ .map(_.trim())
157+ .filter(! _.isEmpty)
158+ .map(new Path (_))
159+ .toSet
160+ }
161+
162+ def getTokenRenewer (conf : Configuration ): String = {
163+ val delegTokenRenewer = Master .getMasterPrincipal(conf)
164+ logDebug(" delegation token renewer is: " + delegTokenRenewer)
165+ if (delegTokenRenewer == null || delegTokenRenewer.length() == 0 ) {
166+ val errorMessage = " Can't get Master Kerberos principal for use as renewer"
167+ logError(errorMessage)
168+ throw new SparkException (errorMessage)
169+ }
170+ delegTokenRenewer
171+ }
172+
173+ /**
174+ * Obtains tokens for the namenodes passed in and adds them to the credentials.
175+ */
176+ def obtainTokensForNamenodes (
177+ paths : Set [Path ],
178+ conf : Configuration ,
179+ creds : Credentials ): Unit = {
180+ if (UserGroupInformation .isSecurityEnabled()) {
181+ val delegTokenRenewer = getTokenRenewer(conf)
182+ paths.foreach { dst =>
183+ val dstFs = dst.getFileSystem(conf)
184+ logDebug(" getting token for namenode: " + dst)
185+ dstFs.addDelegationTokens(delegTokenRenewer, creds)
186+ }
187+ }
188+ }
189+
85190}
86191
87192object YarnSparkHadoopUtil {
@@ -211,4 +316,5 @@ object YarnSparkHadoopUtil {
211316 def getClassPathSeparator (): String = {
212317 classPathSeparatorField.get(null ).asInstanceOf [String ]
213318 }
319+
214320}
0 commit comments