diff --git a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala
index 51c3d9b158cbe..ecc82d7ac8001 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DependencyUtils.scala
@@ -94,7 +94,7 @@ private[deploy] object DependencyUtils {
       hadoopConf: Configuration,
       secMgr: SecurityManager): String = {
     require(fileList != null, "fileList cannot be null.")
-    fileList.split(",")
+    Utils.stringToSeq(fileList)
       .map(downloadFile(_, targetDir, sparkConf, hadoopConf, secMgr))
       .mkString(",")
   }
@@ -121,6 +121,11 @@ private[deploy] object DependencyUtils {
     uri.getScheme match {
       case "file" | "local" =>
         path
+      case "http" | "https" | "ftp" if Utils.isTesting =>
+        // This is only used in SparkSubmitSuite unit tests. Instead of downloading the file
+        // remotely, return a dummy local path.
+        val file = new File(uri.getPath)
+        new File(targetDir, file.getName).toURI.toString
       case _ =>
         val fname = new Path(uri).getName()
         val localFile = Utils.doFetchFile(uri.toString(), targetDir, fname, sparkConf, secMgr,
@@ -131,7 +136,7 @@ private[deploy] object DependencyUtils {
 
   def resolveGlobPaths(paths: String, hadoopConf: Configuration): String = {
     require(paths != null, "paths cannot be null.")
-    paths.split(",").map(_.trim).filter(_.nonEmpty).flatMap { path =>
+    Utils.stringToSeq(paths).flatMap { path =>
       val uri = Utils.resolveURI(path)
       uri.getScheme match {
         case "local" | "http" | "https" | "ftp" => Array(path)
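Both call sites above now go through `Utils.stringToSeq` (added in the `Utils.scala` hunk further below), which trims each entry and drops empty ones, whereas the bare `fileList.split(",")` previously used in `downloadFileList` did neither. A minimal standalone sketch of that difference (illustrative only, not Spark's actual code path, since `Utils` is `private[spark]`):

```scala
// Illustrative sketch: mirrors the trim-and-filter behavior of Utils.stringToSeq.
object StringToSeqSketch {
  def stringToSeq(str: String): Seq[String] =
    str.split(",").map(_.trim).filter(_.nonEmpty)

  def main(args: Array[String]): Unit = {
    // A sloppy comma-separated list with stray spaces and an empty entry.
    val fileList = " a.jar, b.jar,, c.jar "
    println(stringToSeq(fileList).mkString("|"))   // a.jar|b.jar|c.jar
    println(fileList.split(",").mkString("|"))     // " a.jar| b.jar|| c.jar " (untrimmed, empty entry kept)
  }
}
```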
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index ea9c9bdaede76..286a4379d2040 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -25,11 +25,11 @@ import java.text.ParseException
 
 import scala.annotation.tailrec
 import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
-import scala.util.Properties
+import scala.util.{Properties, Try}
 
 import org.apache.commons.lang3.StringUtils
 import org.apache.hadoop.conf.{Configuration => HadoopConfiguration}
-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.security.UserGroupInformation
 import org.apache.hadoop.yarn.conf.YarnConfiguration
 import org.apache.ivy.Ivy
@@ -48,6 +48,7 @@ import org.apache.spark._
 import org.apache.spark.api.r.RUtils
 import org.apache.spark.deploy.rest._
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
 import org.apache.spark.launcher.SparkLauncher
 import org.apache.spark.util._
 
@@ -367,6 +368,52 @@ object SparkSubmit extends CommandLineUtils with Logging {
       }.orNull
     }
 
+    // When running on YARN, a remote resource is downloaded to the local disk before being
+    // added to YARN's distributed cache when either:
+    //   1. Hadoop FileSystem doesn't support its scheme, or
+    //   2. its scheme is explicitly listed in "spark.yarn.dist.forceDownloadSchemes".
+    // In yarn-client mode the resources have already been downloaded by the code above, so we
+    // only need to figure out the local path and replace the remote one with it.
+    if (clusterManager == YARN) {
+      sparkConf.setIfMissing(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused")
+      val secMgr = new SecurityManager(sparkConf)
+      val forceDownloadSchemes = sparkConf.get(FORCE_DOWNLOAD_SCHEMES)
+
+      def shouldDownload(scheme: String): Boolean = {
+        forceDownloadSchemes.contains(scheme) ||
+          Try { FileSystem.getFileSystemClass(scheme, hadoopConf) }.isFailure
+      }
+
+      def downloadResource(resource: String): String = {
+        val uri = Utils.resolveURI(resource)
+        uri.getScheme match {
+          case "local" | "file" => resource
+          case e if shouldDownload(e) =>
+            val file = new File(targetDir, new Path(uri).getName)
+            if (file.exists()) {
+              file.toURI.toString
+            } else {
+              downloadFile(resource, targetDir, sparkConf, hadoopConf, secMgr)
+            }
+          case _ => uri.toString
+        }
+      }
+
+      args.primaryResource = Option(args.primaryResource).map { downloadResource }.orNull
+      args.files = Option(args.files).map { files =>
+        Utils.stringToSeq(files).map(downloadResource).mkString(",")
+      }.orNull
+      args.pyFiles = Option(args.pyFiles).map { pyFiles =>
+        Utils.stringToSeq(pyFiles).map(downloadResource).mkString(",")
+      }.orNull
+      args.jars = Option(args.jars).map { jars =>
+        Utils.stringToSeq(jars).map(downloadResource).mkString(",")
+      }.orNull
+      args.archives = Option(args.archives).map { archives =>
+        Utils.stringToSeq(archives).map(downloadResource).mkString(",")
+      }.orNull
+    }
+
     // If we're running a python app, set the main class to our specific python runner
     if (args.isPython && deployMode == CLIENT) {
       if (args.primaryResource == PYSPARK_SHELL) {
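The decision to fall back to a local download hinges on whether Hadoop can resolve a `FileSystem` implementation for the resource's scheme. A self-contained sketch of that probe (assumed setup; the real `shouldDownload` above additionally honors `spark.yarn.dist.forceDownloadSchemes`):

```scala
import scala.util.Try

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.FileSystem

// Standalone sketch: a scheme counts as "supported" if Hadoop can resolve a FileSystem class for it.
object SchemeSupportSketch {
  def hadoopSupports(scheme: String, hadoopConf: Configuration): Boolean =
    Try(FileSystem.getFileSystemClass(scheme, hadoopConf)).isSuccess

  def main(args: Array[String]): Unit = {
    val hadoopConf = new Configuration()
    // "file" (and "hdfs", given the HDFS client jars on the classpath) normally resolve; "http"
    // usually does not unless fs.http.impl is configured, which is why such resources are
    // downloaded to local disk first.
    Seq("file", "hdfs", "http").foreach { scheme =>
      println(s"$scheme -> supported by Hadoop FileSystem: ${hadoopSupports(scheme, hadoopConf)}")
    }
  }
}
```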
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index e0f696080e566..44a2815b81a73 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -400,4 +400,14 @@ package object config {
       .doc("Memory to request as a multiple of the size that used to unroll the block.")
       .doubleConf
       .createWithDefault(1.5)
+
+  private[spark] val FORCE_DOWNLOAD_SCHEMES =
+    ConfigBuilder("spark.yarn.dist.forceDownloadSchemes")
+      .doc("Comma-separated list of schemes for which files will be downloaded to the " +
+        "local disk prior to being added to YARN's distributed cache. For use in cases " +
+        "where the YARN service does not support schemes that are supported by Spark, like http, " +
+        "https and ftp.")
+      .stringConf
+      .toSequence
+      .createWithDefault(Nil)
 }
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index bc08808a4d292..836e33c36d9a1 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -2684,6 +2684,9 @@ private[spark] object Utils extends Logging {
     redact(redactionPattern, kvs.toArray)
   }
 
+  def stringToSeq(str: String): Seq[String] = {
+    str.split(",").map(_.trim()).filter(_.nonEmpty)
+  }
 }
 
 private[util] object CallerContext extends Logging {
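For reference, a rough sketch of the user-facing behavior the new config entry is meant to provide: a comma-separated property that Spark-internal code reads back as a `Seq[String]`, empty by default. `FORCE_DOWNLOAD_SCHEMES` itself is `private[spark]`, so this sketch only shows the `SparkConf` side and re-applies the same splitting by hand:

```scala
import org.apache.spark.SparkConf

// Sketch only: approximates what .stringConf.toSequence yields for this property.
object ForceDownloadSchemesSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
      .set("spark.yarn.dist.forceDownloadSchemes", "http,https")

    val schemes = conf.get("spark.yarn.dist.forceDownloadSchemes", "")
      .split(",").map(_.trim).filter(_.nonEmpty).toSeq
    println(schemes.mkString(","))   // http,https
  }
}
```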
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 4d69ce844d2ea..ad801bf8519a6 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -897,6 +897,71 @@ class SparkSubmitSuite
     sysProps("spark.submit.pyFiles") should (startWith("/"))
   }
 
+  test("download remote resource if it is not supported by yarn service") {
+    testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = false)
+  }
+
+  test("avoid downloading remote resource if it is supported by yarn service") {
+    testRemoteResources(isHttpSchemeBlacklisted = false, supportMockHttpFs = true)
+  }
+
+  test("force download from blacklisted schemes") {
+    testRemoteResources(isHttpSchemeBlacklisted = true, supportMockHttpFs = true)
+  }
+
+  private def testRemoteResources(isHttpSchemeBlacklisted: Boolean,
+      supportMockHttpFs: Boolean): Unit = {
+    val hadoopConf = new Configuration()
+    updateConfWithFakeS3Fs(hadoopConf)
+    if (supportMockHttpFs) {
+      hadoopConf.set("fs.http.impl", classOf[TestFileSystem].getCanonicalName)
+      hadoopConf.set("fs.http.impl.disable.cache", "true")
+    }
+
+    val tmpDir = Utils.createTempDir()
+    val mainResource = File.createTempFile("tmpPy", ".py", tmpDir)
+    val tmpS3Jar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir)
+    val tmpS3JarPath = s"s3a://${new File(tmpS3Jar.toURI).getAbsolutePath}"
+    val tmpHttpJar = TestUtils.createJarWithFiles(Map("test.resource" -> "USER"), tmpDir)
+    val tmpHttpJarPath = s"http://${new File(tmpHttpJar.toURI).getAbsolutePath}"
+
+    val args = Seq(
+      "--class", UserClasspathFirstTest.getClass.getName.stripPrefix("$"),
+      "--name", "testApp",
+      "--master", "yarn",
+      "--deploy-mode", "client",
+      "--jars", s"$tmpS3JarPath,$tmpHttpJarPath",
+      s"s3a://$mainResource"
+    ) ++ (
+      if (isHttpSchemeBlacklisted) {
+        Seq("--conf", "spark.yarn.dist.forceDownloadSchemes=http,https")
+      } else {
+        Nil
+      }
+    )
+
+    val appArgs = new SparkSubmitArguments(args)
+    val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))._3
+
+    val jars = sysProps("spark.yarn.dist.jars").split(",").toSet
+
+    // The URI of the remote S3 resource should still be remote.
+    assert(jars.contains(tmpS3JarPath))
+
+    if (supportMockHttpFs) {
+      // If the http FS is supported by the yarn service, the URI of the remote http resource
+      // should still be remote.
+      assert(jars.contains(tmpHttpJarPath))
+    } else {
+      // If the http FS is not supported by the yarn service, or the http scheme is configured to
+      // be force downloaded, the URI of the remote http resource should be changed to a local one.
+      val jarName = new File(tmpHttpJar.toURI).getName
+      val localHttpJar = jars.filter(_.contains(jarName))
+      localHttpJar.size should be(1)
+      localHttpJar.head should startWith("file:")
+    }
+  }
+
   // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly.
   private def runSparkSubmit(args: Seq[String]): Unit = {
     val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index e4a74556d4f26..432639588cc2b 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -211,6 +211,15 @@ To use a custom metrics.properties for the application master and executors, upd
     Comma-separated list of jars to be placed in the working directory of each executor.
   </td>
 </tr>
+<tr>
+  <td><code>spark.yarn.dist.forceDownloadSchemes</code></td>
+  <td>(none)</td>
+  <td>
+    Comma-separated list of schemes for which files will be downloaded to the local disk prior to
+    being added to YARN's distributed cache. For use in cases where the YARN service does not
+    support schemes that are supported by Spark, like http, https and ftp.
+  </td>
+</tr>
 <tr>
   <td><code>spark.executor.instances</code></td>
   <td><code>2</code></td>
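Finally, a hypothetical end-to-end usage example matching the documentation above: submitting a YARN application programmatically while forcing `http`/`https` resources to be downloaded locally before distribution. The app jar path, main class, and repository URL below are placeholders, not part of this change.

```scala
import org.apache.spark.launcher.SparkLauncher

// Hypothetical usage sketch; all paths, class names and URLs are placeholders.
object SubmitWithForceDownloadSchemes {
  def main(args: Array[String]): Unit = {
    val app = new SparkLauncher()
      .setAppResource("hdfs:///apps/my-app.jar")            // placeholder app jar
      .setMainClass("com.example.MyApp")                    // placeholder main class
      .setMaster("yarn")
      .setDeployMode("cluster")
      .addJar("http://repo.example.com/libs/extra.jar")     // remote http dependency
      // Download http/https resources locally before adding them to YARN's distributed cache.
      .setConf("spark.yarn.dist.forceDownloadSchemes", "http,https")
      .launch()
    app.waitFor()
  }
}
```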