diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index a3d5b941a6761..ae4aeb7b4ce4a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -18,11 +18,14 @@ package org.apache.spark.sql.hive import java.io.File -import java.nio.file.Files +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} import scala.sys.process._ -import org.apache.spark.TestUtils +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SecurityManager, SparkConf, TestUtils} import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTableType @@ -55,14 +58,19 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private def tryDownloadSpark(version: String, path: String): Unit = { // Try mirrors a few times until one succeeds for (i <- 0 until 3) { + // we don't retry on a failure to get mirror url. If we can't get a mirror url, + // the test fails (getStringFromUrl will throw an exception) val preferredMirror = - Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim - val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" + getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true") + val filename = s"spark-$version-bin-hadoop2.7.tgz" + val url = s"$preferredMirror/spark/spark-$version/$filename" logInfo(s"Downloading Spark $version from $url") - if (Seq("wget", url, "-q", "-P", path).! == 0) { + try { + getFileFromUrl(url, path, filename) return + } catch { + case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) } - logWarning(s"Failed to download Spark $version from $url") } fail(s"Unable to download Spark $version") } @@ -85,6 +93,34 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { new File(tmpDataDir, name).getCanonicalPath } + private def getFileFromUrl(urlString: String, targetDir: String, filename: String): Unit = { + val conf = new SparkConf + // if the caller passes the name of an existing file, we want doFetchFile to write over it with + // the contents from the specified url. + conf.set("spark.files.overwrite", "true") + val securityManager = new SecurityManager(conf) + val hadoopConf = new Configuration + + val outDir = new File(targetDir) + if (!outDir.exists()) { + outDir.mkdirs() + } + + // propagate exceptions up to the caller of getFileFromUrl + Utils.doFetchFile(urlString, outDir, filename, conf, securityManager, hadoopConf) + } + + private def getStringFromUrl(urlString: String): String = { + val contentFile = File.createTempFile("string-", ".txt") + contentFile.deleteOnExit() + + // exceptions will propagate to the caller of getStringFromUrl + getFileFromUrl(urlString, contentFile.getParent, contentFile.getName) + + val contentPath = Paths.get(contentFile.toURI) + new String(Files.readAllBytes(contentPath), StandardCharsets.UTF_8) + } + override def beforeAll(): Unit = { super.beforeAll()