diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 324f6f8894d3..d471bb214390 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -23,6 +23,7 @@ import java.net.URL
import java.security.PrivilegedExceptionAction
import java.text.ParseException
import java.util.UUID
+import java.util.Collections
import scala.annotation.tailrec
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
@@ -1000,33 +1001,81 @@ private[spark] object SparkSubmitUtils {
"tags_", "unsafe_")
/**
- * Represents a Maven Coordinate
- * @param groupId the groupId of the coordinate
- * @param artifactId the artifactId of the coordinate
- * @param version the version of the coordinate
+ * Represents an artifact coordinate to resolve
+ * @param groupId the groupId of the artifact
+ * @param artifactId the artifactId of the artifact
+ * @param version the version of the artifact
+ * @param extraParams extra params that control how the artifact is resolved
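+ *
+ * Example (hypothetical coordinate, for illustration only):
+ * {{{
+ *   MavenCoordinate("com.example", "lib", "1.0", Map("classifier" -> "tests")).toString
+ *   // => "com.example:lib:1.0?classifier=tests"
+ * }}}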
*/
- private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String) {
- override def toString: String = s"$groupId:$artifactId:$version"
+ private[deploy] case class MavenCoordinate(groupId: String, artifactId: String, version: String,
+ extraParams: Map[String, String] = Map()) {
+
+ def params: String = if (extraParams.isEmpty) {
+ ""
+ } else {
+ "?" + extraParams.map{ case (k, v) => (k + "=" + v) }.mkString(":")
+ }
+
+ override def toString: String = s"$groupId:$artifactId:$version$params"
}
/**
- * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided
- * in the format `groupId:artifactId:version` or `groupId/artifactId:version`.
- * @param coordinates Comma-delimited string of maven coordinates
- * @return Sequence of Maven coordinates
+ * Extracts artifact coordinates from a comma-delimited string. Coordinates should be provided
+ * in the format `groupId:artifactId:version?param1=value1\&param2=value2` or
+ * `groupId/artifactId:version?param1=value1\&param2=value2`.
+ *
+ * The param separator `&` is the background-process character in CLI shells, so when
+ * multiple params are used, either `&` must be escaped or the value of --packages must be
+ * enclosed in double quotes, as shown below.
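+ *
+ * For example, with a hypothetical coordinate:
+ * {{{
+ *   --packages "com.example:lib:1.0?classifier=tests&transitive=false"   // quoted
+ *   --packages com.example:lib:1.0?classifier=tests\&transitive=false    // escaped
+ * }}}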
+ *
+ * Optional params are 'classifier', 'transitive', 'exclude' and 'conf':
+ *   classifier: the classifier of the artifact
+ *   transitive: whether to resolve transitive dependencies for the artifact
+ *   exclude: a '#'-separated list of transitive artifacts to exclude for this artifact
+ *     (e.g. "a#b#c")
+ *   conf: the Ivy configuration of the artifact
+ *
+ * @param coordinates Comma-delimited string of artifact coordinates
+ * @return Sequence of Artifact coordinates
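+ *
+ * Example (hypothetical coordinates, for illustration only):
+ * {{{
+ *   extractMavenCoordinates("com.example:lib:1.0?transitive=false&exclude=dep1#dep2")
+ *   // => Seq(MavenCoordinate("com.example", "lib", "1.0",
+ *   //      Map("transitive" -> "false", "exclude" -> "dep1#dep2")))
+ * }}}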
*/
def extractMavenCoordinates(coordinates: String): Seq[MavenCoordinate] = {
coordinates.split(",").map { p =>
- val splits = p.replace("/", ":").split(":")
- require(splits.length == 3, s"Provided Maven Coordinates must be in the form " +
- s"'groupId:artifactId:version'. The coordinate provided is: $p")
- require(splits(0) != null && splits(0).trim.nonEmpty, s"The groupId cannot be null or " +
- s"be whitespace. The groupId provided is: ${splits(0)}")
- require(splits(1) != null && splits(1).trim.nonEmpty, s"The artifactId cannot be null or " +
- s"be whitespace. The artifactId provided is: ${splits(1)}")
- require(splits(2) != null && splits(2).trim.nonEmpty, s"The version cannot be null or " +
- s"be whitespace. The version provided is: ${splits(2)}")
- new MavenCoordinate(splits(0), splits(1), splits(2))
+ val errMsg = s"Provided Artifact Coordinates must be in the form " +
+ s"'groupId:artifactId:version?param1=a\\¶m2=b\\¶m3=c'. Optional params are" +
+ s"'classifier', 'transitive', 'exclude', 'conf'. The coordinate provided is: $p"
+ // Split artifact coordinate and params
+ val parts = p.replace("/", ":").split("\\?")
+ require(parts.length == 1 || parts.length == 2, errMsg)
+ // Parse coordinate 'groupId:artifactId:version'
+ val coords = parts(0).split(":")
+ require(coords.length == 3, errMsg)
+ require(coords(0) != null && coords(0).trim.nonEmpty, s"The groupId cannot be null or " +
+ s"be whitespace. The groupId provided is: ${coords(0)}. ${errMsg}")
+ require(coords(1) != null && coords(1).trim.nonEmpty, s"The artifactId cannot be null or " +
+ s"be whitespace. The artifactId provided is: ${coords(1)}. ${errMsg}")
+ require(coords(2) != null && coords(2).trim.nonEmpty, s"The version cannot be null or " +
+ s"be whitespace. The version provided is: ${coords(2)}. ${errMsg}")
+ if (parts.length == 1) {
+ new MavenCoordinate(coords(0), coords(1), coords(2))
+ } else {
+ // Parse params 'param1=a\&param2=b\&param3=c'
+ var paramMap = Map[String, String]()
+ parts(1).split("&").foreach { kv =>
+ require(kv != null && kv.trim.nonEmpty, errMsg)
+ val param = kv.split("=")
+ require(param.length == 2, errMsg)
+ require(param(0) != null && param(0).trim.nonEmpty, s"The param key " +
+ s"cannot be null or be whitespace. The key provided is: ${param(0)}. ${errMsg}")
+ require(param(1) != null && param(1).trim.nonEmpty, s"The param value " +
+ s"cannot be null or be whitespace. The value provided is: ${param(1)}. ${errMsg}")
+
+ if (Set("classifier", "transitive", "conf", "exclude").contains(param(0))) {
+ paramMap += (param(0) -> param(1))
+ } else {
+ throw new RuntimeException(errMsg)
+ }
+ }
+ new MavenCoordinate(coords(0), coords(1), coords(2), paramMap)
+ }
}
}
@@ -1098,20 +1147,51 @@ private[spark] object SparkSubmitUtils {
cacheDirectory: File): String = {
artifacts.map { artifactInfo =>
val artifact = artifactInfo.asInstanceOf[Artifact].getModuleRevisionId
+ val classifier = artifactInfo.asInstanceOf[Artifact].getId.getExtraAttribute("classifier")
+ val suffix = if (classifier == null) "" else s"-${classifier}"
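+ // e.g. "com.example_lib-1.0-tests.jar" for a hypothetical artifact with classifier "tests"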
cacheDirectory.getAbsolutePath + File.separator +
- s"${artifact.getOrganisation}_${artifact.getName}-${artifact.getRevision}.jar"
+ s"${artifact.getOrganisation}_${artifact.getName}-${artifact.getRevision}${suffix}.jar"
}.mkString(",")
}
- /** Adds the given maven coordinates to Ivy's module descriptor. */
+ /** Adds the given artifact coordinates to Ivy's module descriptor. */
def addDependenciesToIvy(
md: DefaultModuleDescriptor,
artifacts: Seq[MavenCoordinate],
+ ivySettings: IvySettings,
ivyConfName: String): Unit = {
- artifacts.foreach { mvn =>
- val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version)
- val dd = new DefaultDependencyDescriptor(ri, false, false)
+ artifacts.foreach { art =>
+ val ri = ModuleRevisionId.newInstance(art.groupId, art.artifactId, art.version)
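+ // Honor an explicit transitive=true/false param; otherwise use Ivy's default
+ // (transitive) dependency descriptor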
+ val dd = art.extraParams.get("transitive") match {
+ case Some(t) => new DefaultDependencyDescriptor(null, ri, false, false, t.toBoolean)
+ case None => new DefaultDependencyDescriptor(ri, false, false)
+ }
dd.addDependencyConfiguration(ivyConfName, ivyConfName + "(runtime)")
+
+ art.extraParams.foreach { case (param, pvalue) =>
+ param match {
+ // Exclude dependencies (names separated by '#') for this artifact
+ case "exclude" => pvalue.split("#").foreach { ex =>
+ dd.addExcludeRule(ivyConfName,
+ createExclusion("*:*" + ex + "*:*", ivySettings, ivyConfName))}
+
+ // Add the artifact's ivy conf to the default conf, so it can be resolved with the
+ // default conf, e.g. ivy: <dependency org="..." name="..." conf="default->master"/>
+ case "conf" => dd.addDependencyConfiguration(ivyConfName, pvalue)
+
+ // If this artifact has a classifier, add a descriptor with the classifier so it can be
+ // resolved, e.g. ivy: <artifact name="..." type="jar" ext="jar" m:classifier="tests"/>
+ case "classifier" =>
+ val dad = new DefaultDependencyArtifactDescriptor(dd, art.artifactId, "jar", "jar",
+ null, Collections.singletonMap(param, pvalue))
+ dad.addConfiguration(art.extraParams.getOrElse("conf", ""))
+ dd.addDependencyArtifact(ivyConfName, dad)
+
+ // Already handled when creating the dependency descriptor above, ignore here
+ case "transitive" =>
+ }
+ }
+
// scalastyle:off println
printStream.println(s"${dd.getDependencyId} added as a dependency")
// scalastyle:on println
@@ -1245,11 +1325,11 @@ private[spark] object SparkSubmitUtils {
}
/**
- * Resolves any dependencies that were supplied through maven coordinates
- * @param coordinates Comma-delimited string of maven coordinates
+ * Resolves any dependencies that were supplied through artifact coordinates
+ * @param coordinates Comma-delimited string of artifact coordinates
* @param ivySettings An IvySettings containing resolvers to use
* @param exclusions Exclusions to apply when resolving transitive dependencies
- * @return The comma-delimited path to the jars of the given maven artifacts including their
+ * @return The comma-delimited path to the jars of the given artifacts including their
* transitive dependencies
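+ *
+ * Example (hypothetical coordinate and path, for illustration only):
+ * {{{
+ *   resolveMavenCoordinates("com.example:lib:1.0?classifier=tests", ivySettings)
+ *   // => "/path/to/cache/com.example_lib-1.0-tests.jar"
+ * }}}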
*/
def resolveMavenCoordinates(
@@ -1299,7 +1379,7 @@ private[spark] object SparkSubmitUtils {
// Add exclusion rules for Spark and Scala Library
addExclusionRules(ivySettings, ivyConfName, md)
// add all supplied maven artifacts as dependencies
- addDependenciesToIvy(md, artifacts, ivyConfName)
+ addDependenciesToIvy(md, artifacts, ivySettings, ivyConfName)
exclusions.foreach { e =>
md.addExcludeRule(createExclusion(e + ":*", ivySettings, ivyConfName))
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
index a0f09891787e..c069331248b1 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.deploy
import java.io.{File, OutputStream, PrintStream}
import java.nio.charset.StandardCharsets
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.{ArrayBuffer, Map}
import com.google.common.io.Files
import org.apache.ivy.core.module.descriptor.MDArtifact
@@ -40,6 +40,28 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll {
def write(b: Int) = {}
}
+ private def loadIvySettings(dummyIvyLocal: File): IvySettings = {
+ val settingsText =
+ s"""
+ |<ivysettings>
+ |  <caches defaultCacheDir="$tempIvyPath/cache"/>
+ |  <settings defaultResolver="local-ivy-settings-file-test"/>
+ |  <resolvers>
+ |    <filesystem name="local-ivy-settings-file-test">
+ |      <ivy pattern=
+ |        "$dummyIvyLocal/[organisation]/[module]/[revision]/[type]s/[artifact].[ext]"/>
+ |      <artifact pattern=
+ |        "$dummyIvyLocal/[organisation]/[module]/[revision]/[type]s/[artifact].[ext]"/>
+ |    </filesystem>
+ |  </resolvers>
+ |</ivysettings>
+ |""".stripMargin
+
+ val settingsFile = new File(tempIvyPath, "ivysettings.xml")
+ Files.write(settingsText, settingsFile, StandardCharsets.UTF_8)
+ SparkSubmitUtils.loadIvySettings(settingsFile.toString, None, None)
+ }
+
/** Simple PrintStream that reads data into a buffer */
private class BufferPrintStream extends PrintStream(noOpOutputStream) {
var lineBuffer = ArrayBuffer[String]()
@@ -93,11 +115,13 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll {
}
test("add dependencies works correctly") {
+ val repos = "a/1,b/2,c/3"
+ val settings = SparkSubmitUtils.buildIvySettings(Option(repos), None)
val md = SparkSubmitUtils.getModuleDescriptor
val artifacts = SparkSubmitUtils.extractMavenCoordinates("com.databricks:spark-csv_2.11:0.1," +
"com.databricks:spark-avro_2.11:0.1")
- SparkSubmitUtils.addDependenciesToIvy(md, artifacts, "default")
+ SparkSubmitUtils.addDependenciesToIvy(md, artifacts, settings, "default")
assert(md.getDependencies.length === 2)
}
@@ -225,25 +249,7 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll {
val main = new MavenCoordinate("my.great.lib", "mylib", "0.1")
val dep = "my.great.dep:mydep:0.5"
val dummyIvyLocal = new File(tempIvyPath, "local" + File.separator)
- val settingsText =
- s"""
- |<ivysettings>
- |  <caches defaultCacheDir="$tempIvyPath/cache"/>
- |  <settings defaultResolver="local-ivy-settings-file-test"/>
- |  <resolvers>
- |    <filesystem name="local-ivy-settings-file-test">
- |      <ivy pattern=
- |        "$dummyIvyLocal/[organisation]/[module]/[revision]/[type]s/[artifact].[ext]"/>
- |      <artifact pattern=
- |        "$dummyIvyLocal/[organisation]/[module]/[revision]/[type]s/[artifact].[ext]"/>
- |    </filesystem>
- |  </resolvers>
- |</ivysettings>
- |""".stripMargin
-
- val settingsFile = new File(tempIvyPath, "ivysettings.xml")
- Files.write(settingsText, settingsFile, StandardCharsets.UTF_8)
- val settings = SparkSubmitUtils.loadIvySettings(settingsFile.toString, None, None)
+ val settings = loadIvySettings(dummyIvyLocal)
settings.setDefaultIvyUserDir(new File(tempIvyPath)) // NOTE - can't set this through file
val testUtilSettings = new IvySettings
@@ -271,4 +277,59 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll {
.exists(r.findFirstIn(_).isDefined), "resolution files should be cleaned")
}
}
+
+ test("test artifact with no transitive dependencies") {
+ val main = new MavenCoordinate("my.great.lib", "mylib", "0.1", Map("transitive" -> "false"))
+ val dep = "my.great.dep:mydep:0.5"
+
+ IvyTestUtils.withRepository(main, Some(dep), None) { repo =>
+ val jarPath = SparkSubmitUtils.resolveMavenCoordinates(
+ main.toString,
+ SparkSubmitUtils.buildIvySettings(Some(repo), None),
+ isTest = true)
+ assert(jarPath.indexOf("mylib") >= 0, "should find artifact")
+ assert(jarPath.indexOf("mydep") < 0, "should not transitive dependency")
+ }
+ }
+
+ test("test artifact excluding specific dependencies") {
+ val main = new MavenCoordinate("my.great.lib", "mylib", "0.1", Map("exclude" -> "mydep"))
+ val dep = "my.great.dep:mydep:0.5"
+
+ IvyTestUtils.withRepository(main, Some(dep), None) { repo =>
+ val jarPath = SparkSubmitUtils.resolveMavenCoordinates(
+ main.toString,
+ SparkSubmitUtils.buildIvySettings(Some(repo), None),
+ isTest = true)
+ assert(jarPath.indexOf("mylib") >= 0, "should find artifact")
+ assert(jarPath.indexOf("mydep") < 0, "should not find excluded dependency")
+ }
+ }
+
+ test("test artifact with classifier") {
+ val main = new MavenCoordinate("my.great.lib", "mylib", "0.1", Map("classifier" -> "test"))
+
+ IvyTestUtils.withRepository(main, None, None) { repo =>
+ val jarPath = SparkSubmitUtils.resolveMavenCoordinates(
+ main.toString,
+ SparkSubmitUtils.buildIvySettings(Some(repo), None),
+ isTest = true)
+ assert(jarPath.indexOf("mylib-0.1-test") >= 0, "should find artifact with classifier")
+ }
+ }
+
+ test("test artifact with conf") {
+ val main = new MavenCoordinate("my.great.lib", "mylib", "0.1", Map("conf" -> "master"))
+ val badMain = new MavenCoordinate("my.great.lib", "mylib", "0.1", Map("conf" -> "badconf"))
+
+ IvyTestUtils.withRepository(main, None, None) { repo =>
+ val settings = SparkSubmitUtils.buildIvySettings(Some(repo), None)
+ val jarPath = SparkSubmitUtils.resolveMavenCoordinates(main.toString, settings, isTest = true)
+ assert(jarPath.indexOf("mylib") >= 0, "should find artifact with good conf")
+ // resolving an artifact with a bad conf should fail with a
+ // RuntimeException: configuration 'badconf' not found
+ intercept[RuntimeException] {
+ SparkSubmitUtils.resolveMavenCoordinates(badMain.toString, settings, isTest = true)
+ }
+ }
+ }
}