diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 24d1a8f9eceae..941f1915959aa 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -227,6 +227,11 @@ class SparkContext(config: SparkConf) extends Logging {
   /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
   val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf)
 
+  // Perform user authentication when Hadoop security is turned on
+  if (SparkHadoopUtil.get.isSecurityEnabled()) {
+    SparkHadoopUtil.get.doUserAuthentication(this)
+  }
+
   // Optionally log Spark events
   private[spark] val eventLogger: Option[EventLoggingListener] = {
     if (conf.getBoolean("spark.eventLog.enabled", false)) {
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index fe0ad9ebbca12..206bbe93ff831 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -18,13 +18,19 @@ package org.apache.spark.deploy
 
 import java.security.PrivilegedExceptionAction
+import java.util.{Collection, TimerTask, Timer}
+import java.io.{File, IOException}
+import java.net.URI
 
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.security.token.{TokenIdentifier, Token}
+import org.apache.hadoop.fs.permission.FsPermission
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.mapred.JobConf
 import org.apache.hadoop.security.Credentials
 import org.apache.hadoop.security.UserGroupInformation
 
-import org.apache.spark.{Logging, SparkContext, SparkConf, SparkException}
+import org.apache.spark._
 import org.apache.spark.annotation.DeveloperApi
 
 import scala.collection.JavaConversions._
@@ -38,6 +44,8 @@ class SparkHadoopUtil extends Logging {
   val conf: Configuration = newConfiguration(new SparkConf())
   UserGroupInformation.setConfiguration(conf)
 
+  val sparkConf = new SparkConf()
+
   /**
    * Runs the given function with a Hadoop UserGroupInformation as a thread local variable
    * (distributed to child threads), used for authenticating HDFS and YARN calls.
@@ -117,6 +125,170 @@ class SparkHadoopUtil extends Logging {
   def getSecretKeyFromUserCredentials(key: String): Array[Byte] = { null }
 
+  /**
+   * Return whether Hadoop security is enabled.
+   *
+   * @return Whether Hadoop security is enabled
+   */
+  def isSecurityEnabled(): Boolean = {
+    UserGroupInformation.isSecurityEnabled
+  }
+
+  /**
+   * Do user authentication when Hadoop security is turned on. Used by the driver.
+   *
+   * @param sc Spark context
+   */
+  def doUserAuthentication(sc: SparkContext) {
+    getAuthenticationType match {
+      case "keytab" => {
+        // Authentication through a Kerberos keytab file. Necessary for
+        // long-running services like Shark/Spark Streaming.
+        scheduleKerberosRenewTask(sc)
+      }
+      case _ => {
+        // No explicit login needed. Authentication is assumed to have been done
+        // before Spark was launched, e.g., the user has already authenticated
+        // with Kerberos through kinit.
+        // Obtain a Hadoop delegation token and store the token in a file.
+        // Add the token file so it gets downloaded by every slave node.
+        sc.addFile(initDelegationToken().toString)
+      }
+    }
+  }
+
+  /**
+   * Get the user to whom the task belongs.
+   *
+   * @param userName Name of the user to whom the task belongs
+   * @return The user to whom the task belongs
+   */
+  def getTaskUser(userName: String): UserGroupInformation = {
+    val ugi = UserGroupInformation.createRemoteUser(userName)
+    // Change the authentication method to Kerberos
+    ugi.setAuthenticationMethod(
+      UserGroupInformation.AuthenticationMethod.KERBEROS)
+    // Get and add Hadoop delegation tokens for the user
+    val iter = getDelegationTokens().iterator()
+    while (iter.hasNext) {
+      ugi.addToken(iter.next())
+    }
+
+    ugi
+  }
+
+  /**
+   * Get the type of Hadoop security authentication.
+   *
+   * @return Type of Hadoop security authentication
+   */
+  private def getAuthenticationType: String = {
+    sparkConf.get("spark.hadoop.security.authentication")
+  }
+
+  /**
+   * Schedule a timer task for automatically renewing Kerberos credentials.
+   *
+   * @param sc Spark context
+   */
+  private def scheduleKerberosRenewTask(sc: SparkContext): Unit = {
+    val kerberosRenewTimer = new Timer()
+    val kerberosRenewTimerTask = new TimerTask {
+      def run(): Unit = {
+        try {
+          kerberosLoginFromKeytab()
+          // Renew the Hadoop delegation token and store the token in a file.
+          // Add the token file so it gets downloaded by every slave node.
+          sc.addFile(initDelegationToken().toString)
+        } catch {
+          case ioe: IOException => {
+            logError("Failed to login from Kerberos keytab", ioe)
+          }
+        }
+      }
+    }
+
+    val interval = sparkConf.getLong(
+      "spark.hadoop.security.kerberos.renewInterval", 21600000)
+    kerberosRenewTimer.schedule(kerberosRenewTimerTask, 0, interval)
+    logInfo("Scheduled timer task for renewing Kerberos credentials")
+  }
+
+  /**
+   * Log a user in from a keytab file. Loads the user's credentials from a
+   * keytab file and logs the user in.
+   */
+  private def kerberosLoginFromKeytab(): Unit = {
+    val user = System.getProperty("user.name")
+    val home = System.getProperty("user.home")
+    val defaultKeytab = home + Path.SEPARATOR + user + ".keytab"
+    val keytab = sparkConf.get(
+      "spark.hadoop.security.kerberos.keytab", defaultKeytab)
+      .replaceAll("_USER", user).replaceAll("_HOME", home)
+    val principal = sparkConf.get(
+      "spark.hadoop.security.kerberos.principal", user).replaceAll("_USER", user)
+      .replaceAll("_HOME", home)
+
+    // Keytab file not found
+    if (!new File(keytab).exists()) {
+      throw new IOException("Keytab file %s not found".format(keytab))
+    }
+
+    loginUserFromKeytab(principal, keytab)
+  }
+
+  /**
+   * Initialize a Hadoop delegation token, store the token in a file,
+   * and add it to the SparkContext so executors can get it.
+   *
+   * @return URI of the token file
+   */
+  private def initDelegationToken(): URI = {
+    val localFS = FileSystem.getLocal(conf)
+    // Store the token file under the user's home directory
+    val tokenFile = new Path(localFS.getHomeDirectory, sparkConf.get(
+      "spark.hadoop.security.token.name", "spark.token"))
+    if (localFS.exists(tokenFile)) {
+      localFS.delete(tokenFile, false)
+    }
+
+    // Get a new token and write it to the given token file
+    val currentUser = UserGroupInformation.getCurrentUser
+    val fs = FileSystem.get(conf)
+    val token: Token[_ <: TokenIdentifier] =
+      fs.getDelegationToken(currentUser.getShortUserName)
+        .asInstanceOf[Token[_ <: TokenIdentifier]]
+    val cred = new Credentials()
+    cred.addToken(token.getService, token)
+    cred.writeTokenStorageFile(tokenFile, conf)
+    // Make sure the token file is readable only by its owner
+    localFS.setPermission(tokenFile, FsPermission.createImmutable(0400))
+
+    logInfo("Stored Hadoop delegation token for user %s to file %s".format(
+      currentUser.getShortUserName, tokenFile.toUri.toString))
+    tokenFile.toUri
+  }
+
+  /**
+   * Get delegation tokens from the token file added through SparkContext.addFile().
+   *
+   * @return Collection of delegation tokens
+   */
+  private def getDelegationTokens(): Collection[Token[_ <: TokenIdentifier]] = {
+    // Get the token file added through SparkContext.addFile()
+    val source = new File(SparkFiles.get(sparkConf.get(
+      "spark.hadoop.security.token.name", "spark.token")))
+    if (source.exists()) {
+      val sourcePath = new Path("file://" + source.getAbsolutePath)
+      // Read credentials from the token file
+      Credentials.readTokenStorageFile(sourcePath, conf).getAllTokens
+    } else {
+      throw new IOException(
+        "Token file %s does not exist".format(source.getAbsolutePath))
+    }
+  }
+
   def loginUserFromKeytab(principalName: String, keytabFilename: String) {
     UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename)
   }
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index dd903dc65d204..b39c0ba94f7c4 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -21,6 +21,7 @@ import java.io.File
 import java.lang.management.ManagementFactory
 import java.nio.ByteBuffer
 import java.util.concurrent._
+import java.security.PrivilegedExceptionAction
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable.{ArrayBuffer, HashMap}
@@ -158,7 +159,8 @@ private[spark] class Executor(
     try {
       SparkEnv.set(env)
       Accumulators.clear()
-      val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
+      val (userName, taskFiles, taskJars, taskBytes) =
+        Task.deserializeWithDependencies(serializedTask)
       updateDependencies(taskFiles, taskJars)
       task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
 
@@ -178,7 +180,19 @@ private[spark] class Executor(
 
       // Run the actual task and measure its runtime.
       taskStart = System.currentTimeMillis()
-      val value = task.run(taskId.toInt)
+      var value: Any = None
+      if (SparkHadoopUtil.get.isSecurityEnabled()) {
+        // Get the user to whom the task belongs
+        val ugi = SparkHadoopUtil.get.getTaskUser(userName)
+        // Run the task as the user to whom it belongs
+        ugi.doAs(new PrivilegedExceptionAction[Unit] {
+          def run(): Unit = {
+            value = task.run(taskId.toInt)
+          }
+        })
+      } else {
+        value = task.run(taskId.toInt)
+      }
       val taskFinish = System.currentTimeMillis()
 
       // If the task has been killed, let's fail it.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 6aa0cca06878d..fa3961a2e6ebf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -107,6 +107,7 @@ private[spark] object Task {
    * Serialize a task and the current app dependencies (files and JARs added to the SparkContext)
    */
   def serializeWithDependencies(
+      userName: String,
       task: Task[_],
       currentFiles: HashMap[String, Long],
       currentJars: HashMap[String, Long],
@@ -116,6 +117,9 @@ private[spark] object Task {
     val out = new ByteArrayOutputStream(4096)
     val dataOut = new DataOutputStream(out)
 
+    // Write the name of the user launching the task
+    dataOut.writeUTF(userName)
+
     // Write currentFiles
     dataOut.writeInt(currentFiles.size)
     for ((name, timestamp) <- currentFiles) {
@@ -142,14 +146,17 @@ private[spark] object Task {
    * and return the task itself as a serialized ByteBuffer. The caller can then update its
    * ClassLoaders and deserialize the task.
    *
-   * @return (taskFiles, taskJars, taskBytes)
+   * @return (userName, taskFiles, taskJars, taskBytes)
    */
   def deserializeWithDependencies(serializedTask: ByteBuffer)
-    : (HashMap[String, Long], HashMap[String, Long], ByteBuffer) = {
+    : (String, HashMap[String, Long], HashMap[String, Long], ByteBuffer) = {
 
     val in = new ByteBufferInputStream(serializedTask)
     val dataIn = new DataInputStream(in)
 
+    // Read the name of the user launching the task
+    val userName = dataIn.readUTF()
+
     // Read task's files
     val taskFiles = new HashMap[String, Long]()
     val numFiles = dataIn.readInt()
@@ -166,6 +173,6 @@ private[spark] object Task {
 
     // Create a sub-buffer for the rest of the data, which is the serialized Task object
     val subBuffer = serializedTask.slice()  // ByteBufferInputStream will have read just up to task
-    (taskFiles, taskJars, subBuffer)
+    (userName, taskFiles, taskJars, subBuffer)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index d9d53faf843ff..4c834e961a383 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -433,10 +433,11 @@ private[spark] class TaskSetManager(
       }
       // Serialize and return the task
       val startTime = clock.getTime()
+      val userName = System.getProperty("user.name")
       // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
       // we assume the task can be serialized without exceptions.
       val serializedTask = Task.serializeWithDependencies(
-        task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+        userName, task, sched.sc.addedFiles, sched.sc.addedJars, ser)
       if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 &&
           !emittedTaskSizeWarning) {
         emittedTaskSizeWarning = true
diff --git a/docs/configuration.md b/docs/configuration.md
index 36178efb97103..259e1444d4659 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -722,6 +722,44 @@ Apart from these, the following properties are also available, and may be useful
   Number of cores to allocate for each task.
+
   spark.task.maxFailures
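
For reference, a minimal sketch of spark-defaults.conf entries that would exercise the keytab path added by this patch. The property keys are the ones the patch reads through SparkHadoopUtil's SparkConf; the keytab location, principal, and interval are illustrative placeholders (the patch expands _USER and _HOME from the user.name and user.home system properties, and 21600000 ms, i.e. 6 hours, is the patch's default renew interval).

    # Authenticate from a Kerberos keytab and periodically renew credentials
    spark.hadoop.security.authentication          keytab
    spark.hadoop.security.kerberos.keytab         _HOME/_USER.keytab
    spark.hadoop.security.kerberos.principal      _USER
    spark.hadoop.security.kerberos.renewInterval  21600000
    # Name of the delegation token file distributed to executors via SparkContext.addFile()
    spark.hadoop.security.token.name              spark.token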