diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 324f6f8894d3..d471bb214390 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -23,6 +23,7 @@ import java.net.URL
 import java.security.PrivilegedExceptionAction
 import java.text.ParseException
 import java.util.UUID
+import java.util.Collections
 
 import scala.annotation.tailrec
 import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
@@ -1000,33 +1001,81 @@ private[spark] object SparkSubmitUtils {
     "tags_", "unsafe_")
 
   /**
-   * Represents a Maven Coordinate
-   * @param groupId the groupId of the coordinate
-   * @param artifactId the artifactId of the coordinate
-   * @param version the version of the coordinate
+   * Represents an Artifact Coordinate to resolve
+   * @param groupId the groupId of the artifact
+   * @param artifactId the artifactId of the artifact
+   * @param version the version of the artifact
+   * @param extraParams the extra params used to resolve the artifact
    */
-  private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) {
-    override def toString: String = s"$groupId:$artifactId:$version"
+  private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String,
+      extraParams: Map[String, String] = Map()) {
+
+    def params: String = if (extraParams.isEmpty) {
+      ""
+    } else {
+      "?" + extraParams.map { case (k, v) => k + "=" + v }.mkString("&")
+    }
+
+    override def toString: String = s"$groupId:$artifactId:$version$params"
   }
 
   /**
-   * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided
-   * in the format `groupId:artifactId:version` or `groupId/artifactId:version`.
-   * @param coordinates Comma-delimited string of maven coordinates
-   * @return Sequence of Maven coordinates
+   * Extracts artifact coordinates from a comma-delimited string. Coordinates should be provided
+   * in the format `groupId:artifactId:version?param1=value1\&param2=value2` or
+   * `groupId/artifactId:version?param1=value1\&param2=value2`.
+   *
+   * The param separator `&` is the background-process character in the CLI, so when multiple
+   * params are used, either `&` must be escaped or the value of --packages must be enclosed
+   * in double quotes.
+   *
+   * Optional params are 'classifier', 'transitive', 'exclude' and 'conf':
+   *   classifier: the classifier of the artifact
+   *   transitive: whether to resolve transitive dependencies for the artifact
+   *   exclude: a '#'-separated list of transitive dependencies to exclude for this artifact
+   *            (e.g. "a#b#c")
+   *   conf: the ivy conf of the artifact
+   *
+   * @param coordinates Comma-delimited string of artifact coordinates
+   * @return Sequence of Artifact coordinates
    */
   def extractMavenCoordinates(coordinates: String): Seq[MavenCoordinate] = {
     coordinates.split(",").map { p =>
-      val splits = p.replace("/", ":").split(":")
-      require(splits.length == 3, s"Provided Maven Coordinates must be in the form " +
-        s"'groupId:artifactId:version'. The coordinate provided is: $p")
-      require(splits(0) != null && splits(0).trim.nonEmpty, s"The groupId cannot be null or " +
-        s"be whitespace. The groupId provided is: ${splits(0)}")
-      require(splits(1) != null && splits(1).trim.nonEmpty, s"The artifactId cannot be null or " +
-        s"be whitespace. The artifactId provided is: ${splits(1)}")
-      require(splits(2) != null && splits(2).trim.nonEmpty, s"The version cannot be null or " +
-        s"be whitespace. The version provided is: ${splits(2)}")
-      new MavenCoordinate(splits(0), splits(1), splits(2))
+      val errMsg = s"Provided Artifact Coordinates must be in the form " +
+        s"'groupId:artifactId:version?param1=a\\&param2=b\\&param3=c'. Optional params are " +
+        s"'classifier', 'transitive', 'exclude', 'conf'. The coordinate provided is: $p"
+      // Split artifact coordinate and params
+      val parts = p.replace("/", ":").split("\\?")
+      require(parts.length == 1 || parts.length == 2, errMsg)
+      // Parse coordinate 'groupId:artifactId:version'
+      val coords = parts(0).split(":")
+      require(coords.length == 3, errMsg)
+      require(coords(0) != null && coords(0).trim.nonEmpty, s"The groupId cannot be null or " +
+        s"be whitespace. The groupId provided is: ${coords(0)}. ${errMsg}")
+      require(coords(1) != null && coords(1).trim.nonEmpty, s"The artifactId cannot be null or " +
+        s"be whitespace. The artifactId provided is: ${coords(1)}. ${errMsg}")
+      require(coords(2) != null && coords(2).trim.nonEmpty, s"The version cannot be null or " +
+        s"be whitespace. The version provided is: ${coords(2)}. ${errMsg}")
+      if (parts.length == 1) {
+        new MavenCoordinate(coords(0), coords(1), coords(2))
+      } else {
+        // Parse params 'param1=a\&param2=b\&param3=c'
+        val params = parts(1).split("\\&")
+        var paramMap = Map[String, String]()
+        for (i <- 0 until params.length) {
+          require(params(i) != null && params(i).trim.nonEmpty, errMsg)
+          val param = params(i).split("=")
+          require(param.length == 2, errMsg)
+          require(param(0) != null && param(0).trim.nonEmpty, s"The param key " +
+            s"cannot be null or be whitespace. The key provided is: ${param(0)}. ${errMsg}")
+          require(param(1) != null && param(1).trim.nonEmpty, s"The param value " +
+            s"cannot be null or be whitespace. The value provided is: ${param(1)}. ${errMsg}")
+
+          if (Set("classifier", "transitive", "conf", "exclude").contains(param(0))) {
+            paramMap += (param(0) -> param(1))
+          } else {
+            throw new RuntimeException(errMsg)
+          }
+        }
+        new MavenCoordinate(coords(0), coords(1), coords(2), paramMap)
+      }
     }
   }
 
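
For reference, a minimal sketch of what the new parsing accepts (the coordinates below are made-up examples, not taken from the patch):

    val coords = SparkSubmitUtils.extractMavenCoordinates(
      "my.group:mylib:1.0?transitive=false&exclude=dep1#dep2,other.group:otherlib:2.0")
    // coords(0): MavenCoordinate("my.group", "mylib", "1.0",
    //   Map("transitive" -> "false", "exclude" -> "dep1#dep2"))
    // coords(1): MavenCoordinate("other.group", "otherlib", "2.0")
    // An unknown param key (e.g. "foo=bar") fails with a RuntimeException carrying errMsg.
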
@@ -1098,20 +1147,51 @@ private[spark] object SparkSubmitUtils {
       cacheDirectory: File): String = {
     artifacts.map { artifactInfo =>
       val artifact = artifactInfo.asInstanceOf[Artifact].getModuleRevisionId
+      val classifier = artifactInfo.asInstanceOf[Artifact].getId.getExtraAttribute("classifier")
+      val suffix = if (classifier == null) "" else s"-${classifier}"
       cacheDirectory.getAbsolutePath + File.separator +
-        s"${artifact.getOrganisation}_${artifact.getName}-${artifact.getRevision}.jar"
+        s"${artifact.getOrganisation}_${artifact.getName}-${artifact.getRevision}${suffix}.jar"
     }.mkString(",")
   }
 
-  /** Adds the given maven coordinates to Ivy's module descriptor. */
+  /** Adds the given artifact coordinates to Ivy's module descriptor. */
   def addDependenciesToIvy(
       md: DefaultModuleDescriptor,
       artifacts: Seq[MavenCoordinate],
+      ivySettings: IvySettings,
       ivyConfName: String): Unit = {
-    artifacts.foreach { mvn =>
-      val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version)
-      val dd = new DefaultDependencyDescriptor(ri, false, false)
+    artifacts.foreach { art =>
+      val ri = ModuleRevisionId.newInstance(art.groupId, art.artifactId, art.version)
+      val dd = art.extraParams.get("transitive") match {
+        case Some(t) => new DefaultDependencyDescriptor(null, ri, false, false, t.toBoolean)
+        case None => new DefaultDependencyDescriptor(ri, false, false)
+      }
       dd.addDependencyConfiguration(ivyConfName, ivyConfName + "(runtime)")
+
+      art.extraParams.foreach { case (param, pvalue) =>
+        param match {
+          // Exclude dependencies (names separated by '#') for this artifact
+          case "exclude" => pvalue.split("#").foreach { ex =>
+            dd.addExcludeRule(ivyConfName,
+              createExclusion("*:*" + ex + "*:*", ivySettings, ivyConfName))
+          }
+
+          // Map the artifact's ivy conf onto the default conf, so it can be resolved
+          // through the default conf (e.g. a 'default->master' conf mapping)
+          case "conf" => dd.addDependencyConfiguration(ivyConfName, pvalue)
+
+          // If this artifact has a classifier, add an artifact descriptor carrying the
+          // classifier as an extra attribute, so the classified jar can be resolved
+          case "classifier" =>
+            val dad = new DefaultDependencyArtifactDescriptor(dd, art.artifactId, "jar", "jar",
+              null, Collections.singletonMap(param, pvalue))
+            dad.addConfiguration(art.extraParams.getOrElse("conf", ""))
+            dd.addDependencyArtifact(ivyConfName, dad)
+
+          // Already handled above when creating the dependency descriptor, ignore
+          case "transitive" =>
+        }
+      }
+
       // scalastyle:off println
       printStream.println(s"${dd.getDependencyId} added as a dependency")
       // scalastyle:on println
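
A short sketch of how the extra params flow into Ivy, mirroring the way the updated tests drive this code (the coordinate is a made-up example; "default" is the ivy conf name the tests use):

    val settings = SparkSubmitUtils.buildIvySettings(None, None)
    val md = SparkSubmitUtils.getModuleDescriptor
    val artifacts = SparkSubmitUtils.extractMavenCoordinates(
      "my.group:mylib:1.0?transitive=false&classifier=tests")
    SparkSubmitUtils.addDependenciesToIvy(md, artifacts, settings, "default")
    // transitive=false -> the dependency descriptor is created as non-transitive
    // classifier=tests -> an artifact descriptor with extra attribute classifier=tests is
    //                     attached, so resolution fetches the classified jar and
    //                     getJarPathsFromArtifacts names it my.group_mylib-1.0-tests.jar
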
@@ -1245,11 +1325,11 @@ private[spark] object SparkSubmitUtils {
   }
 
   /**
-   * Resolves any dependencies that were supplied through maven coordinates
-   * @param coordinates Comma-delimited string of maven coordinates
+   * Resolves any dependencies that were supplied through artifact coordinates
+   * @param coordinates Comma-delimited string of artifact coordinates
    * @param ivySettings An IvySettings containing resolvers to use
    * @param exclusions Exclusions to apply when resolving transitive dependencies
-   * @return The comma-delimited path to the jars of the given maven artifacts including their
+   * @return The comma-delimited path to the jars of the given artifacts including their
    *         transitive dependencies
    */
   def resolveMavenCoordinates(
@@ -1299,7 +1379,7 @@ private[spark] object SparkSubmitUtils {
       // Add exclusion rules for Spark and Scala Library
       addExclusionRules(ivySettings, ivyConfName, md)
       // add all supplied maven artifacts as dependencies
-      addDependenciesToIvy(md, artifacts, ivyConfName)
+      addDependenciesToIvy(md, artifacts, ivySettings, ivyConfName)
       exclusions.foreach { e =>
         md.addExcludeRule(createExclusion(e + ":*", ivySettings, ivyConfName))
       }
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
index a0f09891787e..c069331248b1 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.deploy
 import java.io.{File, OutputStream, PrintStream}
 import java.nio.charset.StandardCharsets
 
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, Map}
 
 import com.google.common.io.Files
 import org.apache.ivy.core.module.descriptor.MDArtifact
@@ -40,6 +40,28 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll {
     def write(b: Int) = {}
   }
 
+  private def loadIvySettings(dummyIvyLocal: File): IvySettings = {
+    val settingsText =
+      s"""
+         |<ivysettings>
+         |  <caches defaultCacheDir="$tempIvyPath/cache"/>
+         |  <settings defaultResolver="local-ivy-settings-file-test"/>
+         |  <resolvers>
+         |    <filesystem name="local-ivy-settings-file-test">
+         |      <ivy pattern=
+         |        "$dummyIvyLocal/[organisation]/[module]/[revision]/[type]s/[artifact].[ext]"/>
+         |      <artifact pattern=
+         |        "$dummyIvyLocal/[organisation]/[module]/[revision]/[type]s/[artifact].[ext]"/>
+         |    </filesystem>
+         |  </resolvers>
+         |</ivysettings>
+         |""".stripMargin
+
+    val settingsFile = new File(tempIvyPath, "ivysettings.xml")
+    Files.write(settingsText, settingsFile, StandardCharsets.UTF_8)
+    SparkSubmitUtils.loadIvySettings(settingsFile.toString, None, None)
+  }
+
   /** Simple PrintStream that reads data into a buffer */
   private class BufferPrintStream extends PrintStream(noOpOutputStream) {
     var lineBuffer = ArrayBuffer[String]()
@@ -93,11 +115,13 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll {
   }
 
   test("add dependencies works correctly") {
+    val repos = "a/1,b/2,c/3"
+    val settings = SparkSubmitUtils.buildIvySettings(Option(repos), None)
     val md = SparkSubmitUtils.getModuleDescriptor
     val artifacts = SparkSubmitUtils.extractMavenCoordinates("com.databricks:spark-csv_2.11:0.1," +
       "com.databricks:spark-avro_2.11:0.1")
 
-    SparkSubmitUtils.addDependenciesToIvy(md, artifacts, "default")
+    SparkSubmitUtils.addDependenciesToIvy(md, artifacts, settings, "default")
     assert(md.getDependencies.length === 2)
   }
 
@@ -225,25 +249,7 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll {
     val main = new MavenCoordinate("my.great.lib", "mylib", "0.1")
     val dep = "my.great.dep:mydep:0.5"
     val dummyIvyLocal = new File(tempIvyPath, "local" + File.separator)
-    val settingsText =
-      s"""
-         |<ivysettings>
-         |  <caches defaultCacheDir="$tempIvyPath/cache"/>
-         |  <settings defaultResolver="local-ivy-settings-file-test"/>
-         |  <resolvers>
-         |    <filesystem name="local-ivy-settings-file-test">
-         |      <ivy pattern=
-         |        "$dummyIvyLocal/[organisation]/[module]/[revision]/[type]s/[artifact].[ext]"/>
-         |      <artifact pattern=
-         |        "$dummyIvyLocal/[organisation]/[module]/[revision]/[type]s/[artifact].[ext]"/>
-         |    </filesystem>
-         |  </resolvers>
-         |</ivysettings>
-         |""".stripMargin
-
-    val settingsFile = new File(tempIvyPath, "ivysettings.xml")
-    Files.write(settingsText, settingsFile, StandardCharsets.UTF_8)
-    val settings = SparkSubmitUtils.loadIvySettings(settingsFile.toString, None, None)
+    val settings = loadIvySettings(dummyIvyLocal)
     settings.setDefaultIvyUserDir(new File(tempIvyPath))  // NOTE - can't set this through file
 
     val testUtilSettings = new IvySettings
@@ -271,4 +277,59 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll {
        .exists(r.findFirstIn(_).isDefined), "resolution files should be cleaned")
     }
   }
+
+  test("test artifact with no transitive dependencies") {
+    val main = new MavenCoordinate("my.great.lib", "mylib", "0.1", Map("transitive" -> "false"))
+    val dep = "my.great.dep:mydep:0.5"
+
+    IvyTestUtils.withRepository(main, Some(dep), None) { repo =>
+      val jarPath = SparkSubmitUtils.resolveMavenCoordinates(
+        main.toString,
+        SparkSubmitUtils.buildIvySettings(Some(repo), None),
+        isTest = true)
+      assert(jarPath.indexOf("mylib") >= 0, "should find artifact")
+      assert(jarPath.indexOf("mydep") < 0, "should not find transitive dependency")
+    }
+  }
+
+  test("test artifact excluding specific dependencies") {
+    val main = new MavenCoordinate("my.great.lib", "mylib", "0.1", Map("exclude" -> "mydep"))
+    val dep = "my.great.dep:mydep:0.5"
+
+    IvyTestUtils.withRepository(main, Some(dep), None) { repo =>
+      val jarPath = SparkSubmitUtils.resolveMavenCoordinates(
+        main.toString,
+        SparkSubmitUtils.buildIvySettings(Some(repo), None),
+        isTest = true)
+      assert(jarPath.indexOf("mylib") >= 0, "should find artifact")
+      assert(jarPath.indexOf("mydep") < 0, "should not find excluded dependency")
+    }
+  }
+
+  test("test artifact with classifier") {
+    val main = new MavenCoordinate("my.great.lib", "mylib", "0.1", Map("classifier" -> "test"))
+
+    IvyTestUtils.withRepository(main, None, None) { repo =>
+      val jarPath = SparkSubmitUtils.resolveMavenCoordinates(
+        main.toString,
+        SparkSubmitUtils.buildIvySettings(Some(repo), None),
+        isTest = true)
+      assert(jarPath.indexOf("mylib-0.1-test") >= 0, "should find artifact with classifier")
+    }
+  }
+
+  test("test artifact with conf") {
+    val main = new MavenCoordinate("my.great.lib", "mylib", "0.1", Map("conf" -> "master"))
+    val badMain = new MavenCoordinate("my.great.lib", "mylib", "0.1", Map("conf" -> "badconf"))
+
+    IvyTestUtils.withRepository(main, None, None) { repo =>
+      val settings = SparkSubmitUtils.buildIvySettings(Some(repo), None)
+      val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, settings, isTest = true)
+      assert(jarPath.indexOf("mylib") >= 0, "should find artifact with good conf")
+      // An artifact with a bad conf should fail with a RuntimeException:
+      // configuration 'badconf' not found
+      intercept[RuntimeException] {
+        SparkSubmitUtils.resolveMavenCoordinates(badMain.toString, settings, isTest = true)
+      }
+    }
+  }
 }
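
Putting the pieces together, the end-to-end path the new tests exercise looks roughly as follows; the repository URL and coordinate are placeholders, so this is an illustrative sketch rather than something resolvable as written. From the command line the same coordinate would be passed as --packages "my.group:mylib:1.0?classifier=tests&transitive=false" (quoted, or with the & escaped, as noted in the scaladoc above).

    val settings = SparkSubmitUtils.buildIvySettings(Some("https://repo.example.com/maven"), None)
    val jarPaths = SparkSubmitUtils.resolveMavenCoordinates(
      "my.group:mylib:1.0?classifier=tests&transitive=false", settings, isTest = true)
    // jarPaths is the comma-delimited list built by getJarPathsFromArtifacts; with the params
    // above it ends in .../my.group_mylib-1.0-tests.jar and, because transitive=false, it
    // contains no entries for mylib's own dependencies.
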