diff --git a/LICENSE b/LICENSE
index 66a2e8f132953..b948ccaeecea6 100644
--- a/LICENSE
+++ b/LICENSE
@@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt.
(New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf)
(The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net)
(The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net)
- (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.4 - http://py4j.sourceforge.net/)
+ (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.7 - http://py4j.sourceforge.net/)
(Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/)
(BSD licence) sbt and sbt-launch-lib.bash
(BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE)
diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R
index 9d82814211bc5..7244cc9f9e38e 100644
--- a/R/pkg/R/client.R
+++ b/R/pkg/R/client.R
@@ -19,7 +19,7 @@
# Creates a SparkR client connection object
# if one doesn't already exist
-connectBackend <- function(hostname, port, timeout) {
+connectBackend <- function(hostname, port, timeout, authSecret) {
if (exists(".sparkRcon", envir = .sparkREnv)) {
if (isOpen(.sparkREnv[[".sparkRCon"]])) {
cat("SparkRBackend client connection already exists\n")
@@ -29,7 +29,7 @@ connectBackend <- function(hostname, port, timeout) {
con <- socketConnection(host = hostname, port = port, server = FALSE,
blocking = TRUE, open = "wb", timeout = timeout)
-
+ doServerAuth(con, authSecret)
assign(".sparkRCon", con, envir = .sparkREnv)
con
}
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index 0e99b171cabeb..dc7d37e064b1d 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -60,14 +60,18 @@ readTypedObject <- function(con, type) {
stop(paste("Unsupported type for deserialization", type)))
}
-readString <- function(con) {
- stringLen <- readInt(con)
- raw <- readBin(con, raw(), stringLen, endian = "big")
+readStringData <- function(con, len) {
+ raw <- readBin(con, raw(), len, endian = "big")
string <- rawToChar(raw)
Encoding(string) <- "UTF-8"
string
}
+readString <- function(con) {
+ stringLen <- readInt(con)
+ readStringData(con, stringLen)
+}
+
readInt <- function(con) {
readBin(con, integer(), n = 1, endian = "big")
}
diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R
index 9ebd34411a1ea..daa855b2459f2 100644
--- a/R/pkg/R/sparkR.R
+++ b/R/pkg/R/sparkR.R
@@ -161,6 +161,10 @@ sparkR.sparkContext <- function(
" please use the --packages commandline instead", sep = ","))
}
backendPort <- existingPort
+ authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET")
+ if (nchar(authSecret) == 0) {
+ stop("Auth secret not provided in environment.")
+ }
} else {
path <- tempfile(pattern = "backend_port")
submitOps <- getClientModeSparkSubmitOpts(
@@ -189,16 +193,27 @@ sparkR.sparkContext <- function(
monitorPort <- readInt(f)
rLibPath <- readString(f)
connectionTimeout <- readInt(f)
+
+ # Don't use readString() so that we can provide a useful
+ # error message if the R and Java versions are mismatched.
+ authSecretLen = readInt(f)
+ if (length(authSecretLen) == 0 || authSecretLen == 0) {
+ stop("Unexpected EOF in JVM connection data. Mismatched versions?")
+ }
+ authSecret <- readStringData(f, authSecretLen)
close(f)
file.remove(path)
if (length(backendPort) == 0 || backendPort == 0 ||
length(monitorPort) == 0 || monitorPort == 0 ||
- length(rLibPath) != 1) {
+ length(rLibPath) != 1 || length(authSecret) == 0) {
stop("JVM failed to launch")
}
- assign(".monitorConn",
- socketConnection(port = monitorPort, timeout = connectionTimeout),
- envir = .sparkREnv)
+
+ monitorConn <- socketConnection(port = monitorPort, blocking = TRUE,
+ timeout = connectionTimeout, open = "wb")
+ doServerAuth(monitorConn, authSecret)
+
+ assign(".monitorConn", monitorConn, envir = .sparkREnv)
assign(".backendLaunched", 1, envir = .sparkREnv)
if (rLibPath != "") {
assign(".libPath", rLibPath, envir = .sparkREnv)
@@ -208,7 +223,7 @@ sparkR.sparkContext <- function(
.sparkREnv$backendPort <- backendPort
tryCatch({
- connectBackend("localhost", backendPort, timeout = connectionTimeout)
+ connectBackend("localhost", backendPort, timeout = connectionTimeout, authSecret = authSecret)
},
error = function(err) {
stop("Failed to connect JVM\n")
@@ -632,3 +647,17 @@ sparkCheckInstall <- function(sparkHome, master, deployMode) {
NULL
}
}
+
+# Utility function for sending auth data over a socket and checking the server's reply.
+doServerAuth <- function(con, authSecret) {
+ if (nchar(authSecret) == 0) {
+ stop("Auth secret not provided.")
+ }
+ writeString(con, authSecret)
+ flush(con)
+ reply <- readString(con)
+ if (reply != "ok") {
+ close(con)
+ stop("Unexpected reply from server.")
+ }
+}
diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R
index 3a318b71ea06d..ec9a8f1ee1c95 100644
--- a/R/pkg/inst/worker/daemon.R
+++ b/R/pkg/inst/worker/daemon.R
@@ -28,7 +28,9 @@ suppressPackageStartupMessages(library(SparkR))
port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
inputCon <- socketConnection(
- port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout)
+ port = port, open = "wb", blocking = TRUE, timeout = connectionTimeout)
+
+SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
while (TRUE) {
ready <- socketSelect(list(inputCon))
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index 03e7450147865..eb6453fc16976 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -100,9 +100,12 @@ suppressPackageStartupMessages(library(SparkR))
port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
inputCon <- socketConnection(
- port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout)
+ port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
+SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
+
outputCon <- socketConnection(
port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
+SparkR:::doServerAuth(outputCon, Sys.getenv("SPARKR_WORKER_SECRET"))
# read the index of the current partition inside the RDD
partition <- SparkR:::readInt(inputCon)
diff --git a/bin/pyspark b/bin/pyspark
index 98387c2ec5b8a..95ab62880654f 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -25,14 +25,14 @@ source "${SPARK_HOME}"/bin/load-spark-env.sh
export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]"
# In Spark 2.0, IPYTHON and IPYTHON_OPTS are removed and pyspark fails to launch if either option
-# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython
+# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython
# to use IPython and set PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver
# (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython
# and executor Python executables.
# Fail noisily if removed options are set
if [[ -n "$IPYTHON" || -n "$IPYTHON_OPTS" ]]; then
- echo "Error in pyspark startup:"
+ echo "Error in pyspark startup:"
echo "IPYTHON and IPYTHON_OPTS are removed in Spark 2.0+. Remove these from the environment and set PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS instead."
exit 1
fi
@@ -57,7 +57,7 @@ export PYSPARK_PYTHON
# Add the PySpark classes to the Python path:
export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH"
-export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:$PYTHONPATH"
+export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:$PYTHONPATH"
# Load the PySpark shell.py script when ./pyspark is used interactively:
export OLD_PYTHONSTARTUP="$PYTHONSTARTUP"
diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd
index d1ce9dabab029..15fa910c277b3 100644
--- a/bin/pyspark2.cmd
+++ b/bin/pyspark2.cmd
@@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" (
)
set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH%
-set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.4-src.zip;%PYTHONPATH%
+set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.7-src.zip;%PYTHONPATH%
set OLD_PYTHONSTARTUP=%PYTHONSTARTUP%
set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py
diff --git a/core/pom.xml b/core/pom.xml
index 0b632bca8cd96..6e50010b48f47 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -335,7 +335,7 @@
<dependency>
<groupId>net.sf.py4j</groupId>
<artifactId>py4j</artifactId>
- <version>0.10.4</version>
+ <version>0.10.7</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 2480e56b72ccf..9814f1ae028de 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -17,13 +17,11 @@
package org.apache.spark
-import java.lang.{Byte => JByte}
import java.net.{Authenticator, PasswordAuthentication}
-import java.security.{KeyStore, SecureRandom}
+import java.security.KeyStore
import java.security.cert.X509Certificate
import javax.net.ssl._
-import com.google.common.hash.HashCodes
import com.google.common.io.Files
import org.apache.hadoop.io.Text
@@ -435,12 +433,7 @@ private[spark] class SecurityManager(
val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(SECRET_LOOKUP_KEY)
if (secretKey == null || secretKey.length == 0) {
logDebug("generateSecretKey: yarn mode, secret key from credentials is null")
- val rnd = new SecureRandom()
- val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE
- val secret = new Array[Byte](length)
- rnd.nextBytes(secret)
-
- val cookie = HashCodes.fromBytes(secret).toString()
+ val cookie = Utils.createSecret(sparkConf)
SparkHadoopUtil.get.addSecretKeyToUserCredentials(SECRET_LOOKUP_KEY, cookie)
cookie
} else {
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
index 11f2432575d84..9ddc4a4910180 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
@@ -17,26 +17,39 @@
package org.apache.spark.api.python
-import java.io.DataOutputStream
-import java.net.Socket
+import java.io.{DataOutputStream, File, FileOutputStream}
+import java.net.InetAddress
+import java.nio.charset.StandardCharsets.UTF_8
+import java.nio.file.Files
import py4j.GatewayServer
+import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
/**
- * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port
- * back to its caller via a callback port specified by the caller.
+ * Process that starts a Py4J GatewayServer on an ephemeral port.
*
* This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py).
*/
private[spark] object PythonGatewayServer extends Logging {
initializeLogIfNecessary(true)
- def main(args: Array[String]): Unit = Utils.tryOrExit {
- // Start a GatewayServer on an ephemeral port
- val gatewayServer: GatewayServer = new GatewayServer(null, 0)
+ def main(args: Array[String]): Unit = {
+ val secret = Utils.createSecret(new SparkConf())
+
+ // Start a GatewayServer on an ephemeral port. Make sure the callback client is configured
+ // with the same secret, in case the app needs callbacks from the JVM to the underlying
+ // python processes.
+ val localhost = InetAddress.getLoopbackAddress()
+ val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
+ .authToken(secret)
+ .javaPort(0)
+ .javaAddress(localhost)
+ .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
+ .build()
+
gatewayServer.start()
val boundPort: Int = gatewayServer.getListeningPort
if (boundPort == -1) {
@@ -46,15 +59,24 @@ private[spark] object PythonGatewayServer extends Logging {
logDebug(s"Started PythonGatewayServer on port $boundPort")
}
- // Communicate the bound port back to the caller via the caller-specified callback port
- val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
- val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
- logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
- val callbackSocket = new Socket(callbackHost, callbackPort)
- val dos = new DataOutputStream(callbackSocket.getOutputStream)
+ // Communicate the connection information back to the python process by writing the
+ // information in the requested file. This needs to match the read side in java_gateway.py.
+ val connectionInfoPath = new File(sys.env("_PYSPARK_DRIVER_CONN_INFO_PATH"))
+ val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(),
+ "connection", ".info").toFile()
+
+ val dos = new DataOutputStream(new FileOutputStream(tmpPath))
dos.writeInt(boundPort)
+
+ val secretBytes = secret.getBytes(UTF_8)
+ dos.writeInt(secretBytes.length)
+ dos.write(secretBytes, 0, secretBytes.length)
dos.close()
- callbackSocket.close()
+
+ if (!tmpPath.renameTo(connectionInfoPath)) {
+ logError(s"Unable to write connection information to $connectionInfoPath.")
+ System.exit(1)
+ }
// Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
while (System.in.read() != -1) {
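
The connection info file written above replaces the old callback socket: java_gateway.py now waits for this file and parses it. A minimal sketch of that read side, assuming the big-endian framing DataOutputStream produces (4-byte port, 4-byte secret length, then the secret bytes); the authoritative logic lives in PySpark's java_gateway.py.

    import struct

    def read_gateway_conn_info(path):
        # Mirrors the write side above: writeInt(boundPort), writeInt(len(secret)), write(secret).
        with open(path, "rb") as f:
            (port,) = struct.unpack(">i", f.read(4))
            (secret_len,) = struct.unpack(">i", f.read(4))
            secret = f.read(secret_len).decode("utf-8")
        return port, secret
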
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 63ae705f9c97e..0662792cd1a75 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -38,6 +38,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.input.PortableDataStream
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
+import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util._
@@ -421,6 +422,12 @@ private[spark] object PythonRDD extends Logging {
// remember the broadcasts sent to each worker
private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
+ // Authentication helper used when serving iterator data.
+ private lazy val authHelper = {
+ val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
+ new SocketAuthHelper(conf)
+ }
+
def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = {
synchronized {
workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
@@ -443,12 +450,13 @@ private[spark] object PythonRDD extends Logging {
* (effectively a collect()), but allows you to run on a certain subset of partitions,
* or to enable local execution.
*
- * @return the port number of a local socket which serves the data collected from this job.
+ * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+ * data collected from this job, and the secret for authentication.
*/
def runJob(
sc: SparkContext,
rdd: JavaRDD[Array[Byte]],
- partitions: JArrayList[Int]): Int = {
+ partitions: JArrayList[Int]): Array[Any] = {
type ByteArray = Array[Byte]
type UnrolledPartition = Array[ByteArray]
val allPartitions: Array[UnrolledPartition] =
@@ -461,13 +469,14 @@ private[spark] object PythonRDD extends Logging {
/**
* A helper function to collect an RDD as an iterator, then serve it via socket.
*
- * @return the port number of a local socket which serves the data collected from this job.
+ * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+ * data collected from this job, and the secret for authentication.
*/
- def collectAndServe[T](rdd: RDD[T]): Int = {
+ def collectAndServe[T](rdd: RDD[T]): Array[Any] = {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}
- def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = {
+ def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
}
@@ -698,8 +707,11 @@ private[spark] object PythonRDD extends Logging {
* and send them into this connection.
*
* The thread will terminate after all the data are sent or any exceptions happen.
+ *
+ * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+ * data collected from this job, and the secret for authentication.
*/
- def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+ def serveIterator(items: Iterator[_], threadName: String): Array[Any] = {
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
// Close the socket if no connection in 15 seconds
serverSocket.setSoTimeout(15000)
@@ -709,11 +721,14 @@ private[spark] object PythonRDD extends Logging {
override def run() {
try {
val sock = serverSocket.accept()
+ authHelper.authClient(sock)
+
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
Utils.tryWithSafeFinally {
writeIteratorToStream(items, out)
} {
out.close()
+ sock.close()
}
} catch {
case NonFatal(e) =>
@@ -724,7 +739,7 @@ private[spark] object PythonRDD extends Logging {
}
}.start()
- serverSocket.getLocalPort
+ Array(serverSocket.getLocalPort, authHelper.secret)
}
private def getMergedConf(confAsMap: java.util.HashMap[String, String],
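
serveIterator now returns the port together with the auth secret and refuses to stream until the single expected client has authenticated. The shape of that one-shot serving pattern, sketched in Python under the same length-prefixed handshake used by SocketAuthHelper; serve_strings is a hypothetical helper, not Spark code, and the real row framing is done by writeIteratorToStream and the PySpark serializers.

    import socket
    import struct
    import threading

    def serve_strings(items, secret):
        """Return (port, secret); a daemon thread authenticates the first client, then streams."""
        server = socket.socket()
        server.bind(("127.0.0.1", 0))
        server.listen(1)
        server.settimeout(15)  # like setSoTimeout(15000): give up if no client connects

        def pump():
            try:
                conn, _ = server.accept()
            except socket.timeout:
                server.close()
                return
            try:
                f = conn.makefile("rwb", 65536)
                (n,) = struct.unpack(">i", f.read(4))
                ok = f.read(n).decode("utf-8") == secret
                reply = b"ok" if ok else b"err"
                f.write(struct.pack(">i", len(reply)) + reply)
                f.flush()
                if ok:
                    for item in items:
                        data = item.encode("utf-8")
                        f.write(struct.pack(">i", len(data)) + data)
                    f.flush()
            finally:
                conn.close()
                server.close()

        threading.Thread(target=pump, daemon=True).start()
        return server.getsockname()[1], secret
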
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index c4e55b5e89027..27a5e19f96a14 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -32,7 +32,7 @@ private[spark] object PythonUtils {
val pythonPath = new ArrayBuffer[String]
for (sparkHome <- sys.env.get("SPARK_HOME")) {
pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator)
- pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.4-src.zip").mkString(File.separator)
+ pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.7-src.zip").mkString(File.separator)
}
pythonPath ++= SparkContext.jarOfObject(this)
pythonPath.mkString(File.pathSeparator)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 6a5e6f7c5afb1..9c667847382cd 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -27,6 +27,7 @@ import scala.collection.JavaConverters._
import org.apache.spark._
import org.apache.spark.internal.Logging
+import org.apache.spark.security.SocketAuthHelper
import org.apache.spark.util.{RedirectThread, Utils}
private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
@@ -40,6 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
// also fall back to launching workers (pyspark/worker.py) directly.
val useDaemon = !System.getProperty("os.name").startsWith("Windows")
+
+ private val authHelper = new SocketAuthHelper(SparkEnv.get.conf)
+
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
@@ -80,6 +84,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
if (pid < 0) {
throw new IllegalStateException("Python daemon failed to launch worker with code " + pid)
}
+
+ authHelper.authToServer(socket)
daemonWorkers.put(socket, pid)
socket
}
@@ -117,25 +123,24 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
workerEnv.put("PYTHONPATH", pythonPath)
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
workerEnv.put("PYTHONUNBUFFERED", "YES")
+ workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString)
+ workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
val worker = pb.start()
// Redirect worker stdout and stderr
redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream)
- // Tell the worker our port
- val out = new OutputStreamWriter(worker.getOutputStream, StandardCharsets.UTF_8)
- out.write(serverSocket.getLocalPort + "\n")
- out.flush()
-
- // Wait for it to connect to our socket
+ // Wait for it to connect to our socket, and validate the auth secret.
serverSocket.setSoTimeout(10000)
+
try {
val socket = serverSocket.accept()
+ authHelper.authClient(socket)
simpleWorkers.put(socket, worker)
return socket
} catch {
case e: Exception =>
- throw new SparkException("Python worker did not connect back in time", e)
+ throw new SparkException("Python worker failed to connect back.", e)
}
} finally {
if (serverSocket != null) {
@@ -158,6 +163,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
val workerEnv = pb.environment()
workerEnv.putAll(envVars.asJava)
workerEnv.put("PYTHONPATH", pythonPath)
+ workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
workerEnv.put("PYTHONUNBUFFERED", "YES")
daemon = pb.start()
@@ -167,7 +173,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
// Redirect daemon stdout and stderr
redirectStreamsToStderr(in, daemon.getErrorStream)
-
} catch {
case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
new file mode 100644
index 0000000000000..ac6826a9ec774
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.{DataInputStream, DataOutputStream}
+import java.net.Socket
+
+import org.apache.spark.SparkConf
+import org.apache.spark.security.SocketAuthHelper
+
+private[spark] class RAuthHelper(conf: SparkConf) extends SocketAuthHelper(conf) {
+
+ override protected def readUtf8(s: Socket): String = {
+ SerDe.readString(new DataInputStream(s.getInputStream()))
+ }
+
+ override protected def writeUtf8(str: String, s: Socket): Unit = {
+ val out = s.getOutputStream()
+ SerDe.writeString(new DataOutputStream(out), str)
+ out.flush()
+ }
+
+}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
index 2d1152a036449..3b2e809408e0f 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -17,8 +17,8 @@
package org.apache.spark.api.r
-import java.io.{DataOutputStream, File, FileOutputStream, IOException}
-import java.net.{InetAddress, InetSocketAddress, ServerSocket}
+import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream, IOException}
+import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket}
import java.util.concurrent.TimeUnit
import io.netty.bootstrap.ServerBootstrap
@@ -32,6 +32,8 @@ import io.netty.handler.timeout.ReadTimeoutHandler
import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.util.Utils
/**
* Netty-based backend server that is used to communicate between R and Java.
@@ -45,7 +47,7 @@ private[spark] class RBackend {
/** Tracks JVM objects returned to R for this RBackend instance. */
private[r] val jvmObjectTracker = new JVMObjectTracker
- def init(): Int = {
+ def init(): (Int, RAuthHelper) = {
val conf = new SparkConf()
val backendConnectionTimeout = conf.getInt(
"spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
@@ -53,6 +55,7 @@ private[spark] class RBackend {
conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS))
val workerGroup = bossGroup
val handler = new RBackendHandler(this)
+ val authHelper = new RAuthHelper(conf)
bootstrap = new ServerBootstrap()
.group(bossGroup, workerGroup)
@@ -71,13 +74,16 @@ private[spark] class RBackend {
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
.addLast("decoder", new ByteArrayDecoder())
.addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout))
+ .addLast(new RBackendAuthHandler(authHelper.secret))
.addLast("handler", handler)
}
})
channelFuture = bootstrap.bind(new InetSocketAddress("localhost", 0))
channelFuture.syncUninterruptibly()
- channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort()
+
+ val port = channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort()
+ (port, authHelper)
}
def run(): Unit = {
@@ -116,7 +122,7 @@ private[spark] object RBackend extends Logging {
val sparkRBackend = new RBackend()
try {
// bind to random port
- val boundPort = sparkRBackend.init()
+ val (boundPort, authHelper) = sparkRBackend.init()
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
val listenPort = serverSocket.getLocalPort()
// Connection timeout is set by socket client. To make it configurable we will pass the
@@ -133,6 +139,7 @@ private[spark] object RBackend extends Logging {
dos.writeInt(listenPort)
SerDe.writeString(dos, RUtils.rPackages.getOrElse(""))
dos.writeInt(backendConnectionTimeout)
+ SerDe.writeString(dos, authHelper.secret)
dos.close()
f.renameTo(new File(path))
@@ -144,12 +151,35 @@ private[spark] object RBackend extends Logging {
val buf = new Array[Byte](1024)
// shutdown JVM if R does not connect back in 10 seconds
serverSocket.setSoTimeout(10000)
+
+ // Wait for the R process to connect back, ignoring any failed auth attempts. Allow
+ // a max number of connection attempts to avoid looping forever.
try {
- val inSocket = serverSocket.accept()
+ var remainingAttempts = 10
+ var inSocket: Socket = null
+ while (inSocket == null) {
+ inSocket = serverSocket.accept()
+ try {
+ authHelper.authClient(inSocket)
+ } catch {
+ case e: Exception =>
+ remainingAttempts -= 1
+ if (remainingAttempts == 0) {
+ val msg = "Too many failed authentication attempts."
+ logError(msg)
+ throw new IllegalStateException(msg)
+ }
+ logInfo("Client connection failed authentication.")
+ inSocket = null
+ }
+ }
+
serverSocket.close()
+
// wait for the end of socket, closed if R process die
inSocket.getInputStream().read(buf)
} finally {
+ serverSocket.close()
sparkRBackend.close()
System.exit(0)
}
@@ -165,4 +195,5 @@ private[spark] object RBackend extends Logging {
}
System.exit(0)
}
+
}
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala
new file mode 100644
index 0000000000000..4162e4a6c7476
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.r
+
+import java.io.{ByteArrayOutputStream, DataOutputStream}
+import java.nio.charset.StandardCharsets.UTF_8
+
+import io.netty.channel.{Channel, ChannelHandlerContext, SimpleChannelInboundHandler}
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * Authentication handler for connections from the R process.
+ */
+private class RBackendAuthHandler(secret: String)
+ extends SimpleChannelInboundHandler[Array[Byte]] with Logging {
+
+ override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = {
+ // The R code adds a null terminator to serialized strings, so ignore it here.
+ val clientSecret = new String(msg, 0, msg.length - 1, UTF_8)
+ try {
+ require(secret == clientSecret, "Auth secret mismatch.")
+ ctx.pipeline().remove(this)
+ writeReply("ok", ctx.channel())
+ } catch {
+ case e: Exception =>
+ logInfo("Authentication failure.", e)
+ writeReply("err", ctx.channel())
+ ctx.close()
+ }
+ }
+
+ private def writeReply(reply: String, chan: Channel): Unit = {
+ val out = new ByteArrayOutputStream()
+ SerDe.writeString(new DataOutputStream(out), reply)
+ chan.writeAndFlush(out.toByteArray())
+ }
+
+}
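
The msg.length - 1 above exists because SparkR's SerDe writes strings with a trailing NUL byte, and the 4-byte length prefix (which counts that NUL) has already been stripped by the LengthFieldBasedFrameDecoder in the pipeline. A small Python sketch of that framing, for illustration only:

    import struct

    def serde_write_string(value):
        # SparkR SerDe-style string: 4-byte big-endian length that includes a trailing NUL.
        payload = value.encode("utf-8") + b"\x00"
        return struct.pack(">i", len(payload)) + payload

    def handler_read_secret(frame_body):
        # What channelRead0 sees: the frame with the length prefix already stripped.
        # Drop the trailing NUL before decoding, as the handler above does.
        return frame_body[:-1].decode("utf-8")

    framed = serde_write_string("my-secret")
    assert handler_read_secret(framed[4:]) == "my-secret"
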
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
index 88118392003e8..e7fdc3963945a 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -74,14 +74,19 @@ private[spark] class RRunner[U](
// the socket used to send out the input of task
serverSocket.setSoTimeout(10000)
- val inSocket = serverSocket.accept()
- startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex)
-
- // the socket used to receive the output of task
- val outSocket = serverSocket.accept()
- val inputStream = new BufferedInputStream(outSocket.getInputStream)
- dataStream = new DataInputStream(inputStream)
- serverSocket.close()
+ dataStream = try {
+ val inSocket = serverSocket.accept()
+ RRunner.authHelper.authClient(inSocket)
+ startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex)
+
+ // the socket used to receive the output of task
+ val outSocket = serverSocket.accept()
+ RRunner.authHelper.authClient(outSocket)
+ val inputStream = new BufferedInputStream(outSocket.getInputStream)
+ new DataInputStream(inputStream)
+ } finally {
+ serverSocket.close()
+ }
try {
return new Iterator[U] {
@@ -315,6 +320,11 @@ private[r] object RRunner {
private[this] var errThread: BufferedStreamThread = _
private[this] var daemonChannel: DataOutputStream = _
+ private lazy val authHelper = {
+ val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
+ new RAuthHelper(conf)
+ }
+
/**
* Start a thread to print the process's stderr to ours
*/
@@ -349,6 +359,7 @@ private[r] object RRunner {
pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString)
pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory())
pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE")
+ pb.environment().put("SPARKR_WORKER_SECRET", authHelper.secret)
pb.redirectErrorStream(true) // redirect stderr into stdout
val proc = pb.start()
val errThread = startStdoutThread(proc)
@@ -370,8 +381,12 @@ private[r] object RRunner {
// the socket used to send out the input of task
serverSocket.setSoTimeout(10000)
val sock = serverSocket.accept()
- daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
- serverSocket.close()
+ try {
+ authHelper.authClient(sock)
+ daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+ } finally {
+ serverSocket.close()
+ }
}
try {
daemonChannel.writeInt(port)
diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
index a8f732b11f6cf..b0f214a44ea42 100644
--- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -18,7 +18,7 @@
package org.apache.spark.deploy
import java.io.File
-import java.net.URI
+import java.net.{InetAddress, URI}
import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConverters._
@@ -39,6 +39,7 @@ object PythonRunner {
val pyFiles = args(1)
val otherArgs = args.slice(2, args.length)
val sparkConf = new SparkConf()
+ val secret = Utils.createSecret(sparkConf)
val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON)
.orElse(sparkConf.get(PYSPARK_PYTHON))
.orElse(sys.env.get("PYSPARK_DRIVER_PYTHON"))
@@ -51,7 +52,13 @@ object PythonRunner {
// Launch a Py4J gateway server for the process to connect to; this will let it see our
// Java system properties and such
- val gatewayServer = new py4j.GatewayServer(null, 0)
+ val localhost = InetAddress.getLoopbackAddress()
+ val gatewayServer = new py4j.GatewayServer.GatewayServerBuilder()
+ .authToken(secret)
+ .javaPort(0)
+ .javaAddress(localhost)
+ .callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
+ .build()
val thread = new Thread(new Runnable() {
override def run(): Unit = Utils.logUncaughtExceptions {
gatewayServer.start()
@@ -82,6 +89,7 @@ object PythonRunner {
// This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
+ env.put("PYSPARK_GATEWAY_SECRET", secret)
// pass conf spark.pyspark.python to python process, the only way to pass info to
// python process is through environment variable.
sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
index 6eb53a8252205..e86b362639e57 100644
--- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala
@@ -68,10 +68,13 @@ object RRunner {
// Java system properties etc.
val sparkRBackend = new RBackend()
@volatile var sparkRBackendPort = 0
+ @volatile var sparkRBackendSecret: String = null
val initialized = new Semaphore(0)
val sparkRBackendThread = new Thread("SparkR backend") {
override def run() {
- sparkRBackendPort = sparkRBackend.init()
+ val (port, authHelper) = sparkRBackend.init()
+ sparkRBackendPort = port
+ sparkRBackendSecret = authHelper.secret
initialized.release()
sparkRBackend.run()
}
@@ -91,6 +94,7 @@ object RRunner {
env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(","))
env.put("R_PROFILE_USER",
Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator))
+ env.put("SPARKR_BACKEND_AUTH_SECRET", sparkRBackendSecret)
builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize
val process = builder.start()
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index da3cc04275154..db4c9f9d07e08 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -257,6 +257,11 @@ package object config {
.regexConf
.createOptional
+ private[spark] val AUTH_SECRET_BIT_LENGTH =
+ ConfigBuilder("spark.authenticate.secretBitLength")
+ .intConf
+ .createWithDefault(256)
+
private[spark] val NETWORK_AUTH_ENABLED =
ConfigBuilder("spark.authenticate")
.booleanConf
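
spark.authenticate.secretBitLength was previously read ad hoc in SecurityManager; registering it here lets Utils.createSecret (added later in this patch) pick it up through the typed config API. It remains an ordinary Spark configuration entry, so tuning it is just a conf setting, e.g. (illustrative; the default is 256):

    from pyspark import SparkConf

    # Illustrative only: the 256-bit default is normally fine.
    conf = SparkConf().set("spark.authenticate.secretBitLength", "256")
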
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 1b6bc9139f9c9..df6407b84195c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -693,6 +693,20 @@ private[spark] class TaskSchedulerImpl private[scheduler](
}
}
+ /**
+ * Marks the task as completed in all TaskSetManagers for the given stage.
+ *
+ * After stage failure and retry, there may be multiple TaskSetManagers for the stage.
+ * If an earlier attempt of a stage completes a task, we should ensure that the later attempts
+ * do not also submit those same tasks. That also means that a task completion from an earlier
+ * attempt can lead to the entire stage getting marked as successful.
+ */
+ private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = {
+ taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm =>
+ tsm.markPartitionCompleted(partitionId)
+ }
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 2f4e46c7ec8f1..d9515fb27229e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -74,6 +74,8 @@ private[spark] class TaskSetManager(
val ser = env.closureSerializer.newInstance()
val tasks = taskSet.tasks
+ private[scheduler] val partitionToIndex = tasks.zipWithIndex
+ .map { case (t, idx) => t.partitionId -> idx }.toMap
val numTasks = tasks.length
val copiesRunning = new Array[Int](numTasks)
@@ -149,7 +151,7 @@ private[spark] class TaskSetManager(
private[scheduler] val speculatableTasks = new HashSet[Int]
// Task index, start and finish time for each task attempt (indexed by task ID)
- private val taskInfos = new HashMap[Long, TaskInfo]
+ private[scheduler] val taskInfos = new HashMap[Long, TaskInfo]
// Use a MedianHeap to record durations of successful tasks so we know when to launch
// speculative tasks. This is only used when speculation is enabled, to avoid the overhead
@@ -744,6 +746,9 @@ private[spark] class TaskSetManager(
logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id +
" because task " + index + " has already completed successfully")
}
+ // There may be multiple tasksets for this stage -- we let all of them know that the partition
+ // was completed. This may result in some of the tasksets getting completed.
+ sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId)
// This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
// "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
// "deserialize" the value when holding a lock to avoid blocking other threads. So we call
@@ -754,6 +759,19 @@ private[spark] class TaskSetManager(
maybeFinishTaskSet()
}
+ private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
+ partitionToIndex.get(partitionId).foreach { index =>
+ if (!successful(index)) {
+ tasksSuccessful += 1
+ successful(index) = true
+ if (tasksSuccessful == numTasks) {
+ isZombie = true
+ }
+ maybeFinishTaskSet()
+ }
+ }
+ }
+
/**
* Marks the task as failed, re-adds it to the list of pending tasks, and notifies the
* DAG Scheduler.
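
The interplay between markPartitionCompletedInAllTaskSets and markPartitionCompleted is easier to see in miniature. A toy Python model of the bookkeeping (not Spark code): several attempt-level managers share a stage, and a partition finished by any attempt is marked successful in all of them, so no attempt re-launches it and each attempt can finish once every partition is covered.

    class ToyTaskSetManager:
        def __init__(self, partitions):
            self.partition_to_index = {p: i for i, p in enumerate(partitions)}
            self.successful = [False] * len(partitions)
            self.is_zombie = False

        def mark_partition_completed(self, partition):
            idx = self.partition_to_index.get(partition)
            if idx is not None and not self.successful[idx]:
                self.successful[idx] = True
                if all(self.successful):
                    self.is_zombie = True  # nothing left for this attempt to run

    class ToyScheduler:
        def __init__(self):
            self.tasksets_by_stage = {}  # stageId -> list of attempt managers

        def submit(self, stage_id, tsm):
            self.tasksets_by_stage.setdefault(stage_id, []).append(tsm)

        def mark_partition_completed_in_all_tasksets(self, stage_id, partition):
            for tsm in self.tasksets_by_stage.get(stage_id, []):
                tsm.mark_partition_completed(partition)

    # A zombie attempt finishing partition 3 also retires it in the retry attempt.
    sched = ToyScheduler()
    zombie, retry = ToyTaskSetManager(range(4)), ToyTaskSetManager(range(4))
    sched.submit(0, zombie)
    sched.submit(0, retry)
    sched.mark_partition_completed_in_all_tasksets(0, 3)
    assert retry.successful[3]
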
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
new file mode 100644
index 0000000000000..d15e7937b0523
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.security
+
+import java.io.{DataInputStream, DataOutputStream, InputStream}
+import java.net.Socket
+import java.nio.charset.StandardCharsets.UTF_8
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network.util.JavaUtils
+import org.apache.spark.util.Utils
+
+/**
+ * A class that can be used to add a simple authentication protocol to socket-based communication.
+ *
+ * The protocol is simple: an auth secret is written to the socket, and the other side checks the
+ * secret and writes either "ok" or "err" to the output. If authentication fails, the socket is
+ * not expected to be valid anymore.
+ *
+ * There's no secrecy, so this relies on the sockets being either local or somehow encrypted.
+ */
+private[spark] class SocketAuthHelper(conf: SparkConf) {
+
+ val secret = Utils.createSecret(conf)
+
+ /**
+ * Read the auth secret from the socket and compare to the expected value. Write the reply back
+ * to the socket.
+ *
+ * If authentication fails, this method will close the socket.
+ *
+ * @param s The client socket.
+ * @throws IllegalArgumentException If authentication fails.
+ */
+ def authClient(s: Socket): Unit = {
+ // Set the socket timeout while checking the auth secret. Reset it before returning.
+ val currentTimeout = s.getSoTimeout()
+ try {
+ s.setSoTimeout(10000)
+ val clientSecret = readUtf8(s)
+ if (secret == clientSecret) {
+ writeUtf8("ok", s)
+ } else {
+ writeUtf8("err", s)
+ JavaUtils.closeQuietly(s)
+ }
+ } finally {
+ s.setSoTimeout(currentTimeout)
+ }
+ }
+
+ /**
+ * Authenticate with a server by writing the auth secret and checking the server's reply.
+ *
+ * If authentication fails, this method will close the socket.
+ *
+ * @param s The socket connected to the server.
+ * @throws IllegalArgumentException If authentication fails.
+ */
+ def authToServer(s: Socket): Unit = {
+ writeUtf8(secret, s)
+
+ val reply = readUtf8(s)
+ if (reply != "ok") {
+ JavaUtils.closeQuietly(s)
+ throw new IllegalArgumentException("Authentication failed.")
+ }
+ }
+
+ protected def readUtf8(s: Socket): String = {
+ val din = new DataInputStream(s.getInputStream())
+ val len = din.readInt()
+ val bytes = new Array[Byte](len)
+ din.readFully(bytes)
+ new String(bytes, UTF_8)
+ }
+
+ protected def writeUtf8(str: String, s: Socket): Unit = {
+ val bytes = str.getBytes(UTF_8)
+ val dout = new DataOutputStream(s.getOutputStream())
+ dout.writeInt(bytes.length)
+ dout.write(bytes, 0, bytes.length)
+ dout.flush()
+ }
+
+}
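
For readers following along from the Python or R side, here is the same handshake written out as a standalone Python sketch: each message is a 4-byte big-endian length followed by UTF-8 bytes, the client sends the secret first, and the server answers "ok" or "err". This mirrors the default readUtf8/writeUtf8 above (SparkR overrides them with its own SerDe framing); it is an illustration, not code shipped with Spark.

    import struct

    def write_utf8(sock, text):
        # 4-byte big-endian length prefix, then UTF-8 bytes (like writeInt + write above).
        data = text.encode("utf-8")
        sock.sendall(struct.pack(">i", len(data)) + data)

    def read_utf8(sock):
        def read_exact(n):
            buf = b""
            while len(buf) < n:
                chunk = sock.recv(n - len(buf))
                if not chunk:
                    raise EOFError("connection closed during auth handshake")
                buf += chunk
            return buf
        (length,) = struct.unpack(">i", read_exact(4))
        return read_exact(length).decode("utf-8")

    def auth_to_server(sock, secret):
        """Client side: send the secret, expect the literal reply 'ok'."""
        write_utf8(sock, secret)
        if read_utf8(sock) != "ok":
            sock.close()
            raise ValueError("Authentication failed.")

    def auth_client(sock, secret):
        """Server side: read the client's secret, reply 'ok' or 'err' and close on mismatch."""
        if read_utf8(sock) == secret:
            write_utf8(sock, "ok")
        else:
            write_utf8(sock, "err")
            sock.close()
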
diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
index 603c23abb6895..5df17ccb627a3 100644
--- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
+++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala
@@ -482,7 +482,9 @@ class LegacyAccumulatorWrapper[R, T](
param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] {
private[spark] var _value = initialValue // Current value on driver
- override def isZero: Boolean = _value == param.zero(initialValue)
+ @transient private lazy val _zero = param.zero(initialValue)
+
+ override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef])
override def copy(): LegacyAccumulatorWrapper[R, T] = {
val acc = new LegacyAccumulatorWrapper(initialValue, param)
@@ -491,7 +493,7 @@ class LegacyAccumulatorWrapper[R, T](
}
override def reset(): Unit = {
- _value = param.zero(initialValue)
+ _value = _zero
}
override def add(v: T): Unit = _value = param.addAccumulator(_value, v)
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index a5f44bd746f14..6bcaf102d9680 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util
import java.io._
+import java.lang.{Byte => JByte}
import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo}
import java.math.{MathContext, RoundingMode}
import java.net._
@@ -25,6 +26,7 @@ import java.nio.ByteBuffer
import java.nio.channels.{Channels, FileChannel}
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}
+import java.security.SecureRandom
import java.util.{Locale, Properties, Random, UUID}
import java.util.concurrent._
import java.util.concurrent.atomic.AtomicBoolean
@@ -43,6 +45,7 @@ import scala.util.matching.Regex
import _root_.io.netty.channel.unix.Errors.NativeIoException
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
+import com.google.common.hash.HashCodes
import com.google.common.io.{ByteStreams, Files => GFiles}
import com.google.common.net.InetAddresses
import org.apache.commons.lang3.SystemUtils
@@ -2658,6 +2661,14 @@ private[spark] object Utils extends Logging {
redact(redactionPattern, kvs.toArray)
}
+ def createSecret(conf: SparkConf): String = {
+ val bits = conf.get(AUTH_SECRET_BIT_LENGTH)
+ val rnd = new SecureRandom()
+ val secretBytes = new Array[Byte](bits / JByte.SIZE)
+ rnd.nextBytes(secretBytes)
+ HashCodes.fromBytes(secretBytes).toString()
+ }
+
}
private[util] object CallerContext extends Logging {
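
createSecret draws bits/8 bytes from SecureRandom and renders them through Guava's HashCodes, which amounts to a hex encoding of cryptographically random bytes. A rough Python analogue, purely for intuition about the shape of the secret:

    import secrets

    def create_secret(bit_length=256):
        # bit_length / 8 random bytes, rendered as lowercase hex -- roughly what
        # HashCodes.fromBytes(...).toString() produces for the random byte array.
        return secrets.token_hex(bit_length // 8)

    print(len(create_secret()))  # 64 hex characters for a 256-bit secret
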
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 8b9d45f734cda..38a4f4087873a 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -910,4 +910,108 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
taskScheduler.initialize(new FakeSchedulerBackend)
}
}
+
+ test("Completions in zombie tasksets update status of non-zombie taskset") {
+ val taskScheduler = setupSchedulerWithMockTaskSetBlacklist()
+ val valueSer = SparkEnv.get.serializer.newInstance()
+
+ def completeTaskSuccessfully(tsm: TaskSetManager, partition: Int): Unit = {
+ val indexInTsm = tsm.partitionToIndex(partition)
+ val matchingTaskInfo = tsm.taskAttempts.flatten.filter(_.index == indexInTsm).head
+ val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq())
+ tsm.handleSuccessfulTask(matchingTaskInfo.taskId, result)
+ }
+
+ // Submit a task set, have it fail with a fetch failed, and then re-submit the task attempt,
+ // two times, so we have three active task sets for one stage. (For this to really happen,
+ // you'd need the previous stage to also get restarted, and then succeed, in between each
+ // attempt, but that happens outside what we're mocking here.)
+ val zombieAttempts = (0 until 2).map { stageAttempt =>
+ val attempt = FakeTask.createTaskSet(10, stageAttemptId = stageAttempt)
+ taskScheduler.submitTasks(attempt)
+ val tsm = taskScheduler.taskSetManagerForAttempt(0, stageAttempt).get
+ val offers = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
+ taskScheduler.resourceOffers(offers)
+ assert(tsm.runningTasks === 10)
+ // fail attempt
+ tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED,
+ FetchFailed(null, 0, 0, 0, "fetch failed"))
+ // the attempt is a zombie, but the tasks are still running (this could be true even if
+ // we actively killed those tasks, as killing is best-effort)
+ assert(tsm.isZombie)
+ assert(tsm.runningTasks === 9)
+ tsm
+ }
+
+ // we've now got 2 zombie attempts, each with 9 tasks still active. Submit the 3rd attempt for
+ // the stage, but this time with insufficient resources so not all tasks are active.
+
+ val finalAttempt = FakeTask.createTaskSet(10, stageAttemptId = 2)
+ taskScheduler.submitTasks(finalAttempt)
+ val finalTsm = taskScheduler.taskSetManagerForAttempt(0, 2).get
+ val offers = (0 until 5).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) }
+ val finalAttemptLaunchedPartitions = taskScheduler.resourceOffers(offers).flatten.map { task =>
+ finalAttempt.tasks(task.index).partitionId
+ }.toSet
+ assert(finalTsm.runningTasks === 5)
+ assert(!finalTsm.isZombie)
+
+ // We simulate late completions from our zombie tasksets, corresponding to all the pending
+ // partitions in our final attempt. This means we're only waiting on the tasks we've already
+ // launched.
+ val finalAttemptPendingPartitions = (0 until 10).toSet.diff(finalAttemptLaunchedPartitions)
+ finalAttemptPendingPartitions.foreach { partition =>
+ completeTaskSuccessfully(zombieAttempts(0), partition)
+ }
+
+ // If there is another resource offer, we shouldn't run anything. Though our final attempt
+ // used to have pending tasks, now those tasks have been completed by zombie attempts. The
+ // remaining tasks to compute are already active in the non-zombie attempt.
+ assert(
+ taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec-1", "host-1", 1))).flatten.isEmpty)
+
+ val remainingTasks = finalAttemptLaunchedPartitions.toIndexedSeq.sorted
+
+ // finally, if we finish the remaining partitions from a mix of tasksets, all attempts should be
+ // marked as zombie.
+ // for each of the remaining tasks, find the tasksets with an active copy of the task, and
+ // finish the task.
+ remainingTasks.foreach { partition =>
+ val tsm = if (partition == 0) {
+ // we failed this task on both zombie attempts, this one is only present in the latest
+ // taskset
+ finalTsm
+ } else {
+ // should be active in every taskset. We choose a zombie taskset just to make sure that
+ // we transition the active taskset correctly even if the final completion comes
+ // from a zombie.
+ zombieAttempts(partition % 2)
+ }
+ completeTaskSuccessfully(tsm, partition)
+ }
+
+ assert(finalTsm.isZombie)
+
+ // no taskset has completed all of its tasks, so no updates to the blacklist tracker yet
+ verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), anyObject())
+
+ // finally, let's complete all the tasks. We simulate failures in attempt 1, but everything
+ // else succeeds, to make sure we get the right updates to the blacklist in all cases.
+ (zombieAttempts ++ Seq(finalTsm)).foreach { tsm =>
+ val stageAttempt = tsm.taskSet.stageAttemptId
+ tsm.runningTasksSet.foreach { index =>
+ if (stageAttempt == 1) {
+ tsm.handleFailedTask(tsm.taskInfos(index).taskId, TaskState.FAILED, TaskResultLost)
+ } else {
+ val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq())
+ tsm.handleSuccessfulTask(tsm.taskInfos(index).taskId, result)
+ }
+ }
+
+ // we update the blacklist for the stage attempts with all successful tasks. Even though
+ // some tasksets had failures, we still consider them all successful from a blacklisting
+ // perspective, as the failures weren't from a problem w/ the tasks themselves.
+ verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject())
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
new file mode 100644
index 0000000000000..e57cb701b6284
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.security
+
+import java.io.Closeable
+import java.net._
+
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.internal.config._
+import org.apache.spark.util.Utils
+
+class SocketAuthHelperSuite extends SparkFunSuite {
+
+ private val conf = new SparkConf()
+ private val authHelper = new SocketAuthHelper(conf)
+
+ test("successful auth") {
+ Utils.tryWithResource(new ServerThread()) { server =>
+ Utils.tryWithResource(server.createClient()) { client =>
+ authHelper.authToServer(client)
+ server.close()
+ server.join()
+ assert(server.error == null)
+ assert(server.authenticated)
+ }
+ }
+ }
+
+ test("failed auth") {
+ Utils.tryWithResource(new ServerThread()) { server =>
+ Utils.tryWithResource(server.createClient()) { client =>
+ val badHelper = new SocketAuthHelper(new SparkConf().set(AUTH_SECRET_BIT_LENGTH, 128))
+ intercept[IllegalArgumentException] {
+ badHelper.authToServer(client)
+ }
+ server.close()
+ server.join()
+ assert(server.error != null)
+ assert(!server.authenticated)
+ }
+ }
+ }
+
+ private class ServerThread extends Thread with Closeable {
+
+ private val ss = new ServerSocket()
+ ss.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0))
+
+ @volatile var error: Exception = _
+ @volatile var authenticated = false
+
+ setDaemon(true)
+ start()
+
+ def createClient(): Socket = {
+ new Socket(InetAddress.getLoopbackAddress(), ss.getLocalPort())
+ }
+
+ override def run(): Unit = {
+ var clientConn: Socket = null
+ try {
+ clientConn = ss.accept()
+ authHelper.authClient(clientConn)
+ authenticated = true
+ } catch {
+ case e: Exception =>
+ error = e
+ } finally {
+ Option(clientConn).foreach(_.close())
+ }
+ }
+
+ override def close(): Unit = {
+ try {
+ ss.close()
+ } finally {
+ interrupt()
+ }
+ }
+
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
index a04644d57ed88..fe0a9a471a651 100644
--- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
+++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.util
import org.apache.spark._
+import org.apache.spark.serializer.JavaSerializer
class AccumulatorV2Suite extends SparkFunSuite {
@@ -162,4 +163,22 @@ class AccumulatorV2Suite extends SparkFunSuite {
assert(acc3.isZero)
assert(acc3.value === "")
}
+
+ test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") {
+ class MyData(val i: Int) extends Serializable
+ val param = new AccumulatorParam[MyData] {
+ override def zero(initialValue: MyData): MyData = new MyData(0)
+ override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i)
+ }
+
+ val acc = new LegacyAccumulatorWrapper(new MyData(0), param)
+ acc.metadata = AccumulatorMetadata(
+ AccumulatorContext.newId(),
+ Some("test"),
+ countFailedValues = false)
+ AccumulatorContext.register(acc)
+
+ val ser = new JavaSerializer(new SparkConf).newInstance()
+ ser.serialize(acc)
+ }
}
diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6
index 3ed8a0ec0579e..37feb62a9dc28 100644
--- a/dev/deps/spark-deps-hadoop-2.6
+++ b/dev/deps/spark-deps-hadoop-2.6
@@ -156,7 +156,7 @@ parquet-jackson-1.8.2.jar
pmml-model-1.2.15.jar
pmml-schema-1.2.15.jar
protobuf-java-2.5.0.jar
-py4j-0.10.4.jar
+py4j-0.10.7.jar
pyrolite-4.13.jar
scala-compiler-2.11.8.jar
scala-library-2.11.8.jar
diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7
index a1137385c0064..d90fb0c762bfd 100644
--- a/dev/deps/spark-deps-hadoop-2.7
+++ b/dev/deps/spark-deps-hadoop-2.7
@@ -157,7 +157,7 @@ parquet-jackson-1.8.2.jar
pmml-model-1.2.15.jar
pmml-schema-1.2.15.jar
protobuf-java-2.5.0.jar
-py4j-0.10.4.jar
+py4j-0.10.7.jar
pyrolite-4.13.jar
scala-compiler-2.11.8.jar
scala-library-2.11.8.jar
diff --git a/dev/run-pip-tests b/dev/run-pip-tests
index d51dde12a03c5..03fc83298dc2f 100755
--- a/dev/run-pip-tests
+++ b/dev/run-pip-tests
@@ -89,7 +89,7 @@ for python in "${PYTHON_EXECS[@]}"; do
source "$VIRTUALENV_PATH"/bin/activate
fi
# Upgrade pip & friends if using virutal env
- if [ ! -n "USE_CONDA" ]; then
+ if [ ! -n "$USE_CONDA" ]; then
pip install --upgrade pip pypandoc wheel numpy
fi
diff --git a/python/README.md b/python/README.md
index 0a5c8010b8486..d1f15681d7ebc 100644
--- a/python/README.md
+++ b/python/README.md
@@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c
## Python Requirements
-At its core PySpark depends on Py4J (currently version 0.10.4), but additional sub-packages have their own requirements (including numpy and pandas).
\ No newline at end of file
+At its core PySpark depends on Py4J (currently version 0.10.7), but additional sub-packages have their own requirements (including numpy and pandas).
diff --git a/python/docs/Makefile b/python/docs/Makefile
index 5e4cfb8ab6fe3..b8e079483c90c 100644
--- a/python/docs/Makefile
+++ b/python/docs/Makefile
@@ -7,7 +7,7 @@ SPHINXBUILD ?= sphinx-build
PAPER ?=
BUILDDIR ?= _build
-export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.4-src.zip)
+export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.7-src.zip)
# User-friendly check for sphinx-build
ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1)
diff --git a/python/lib/py4j-0.10.4-src.zip b/python/lib/py4j-0.10.4-src.zip
deleted file mode 100644
index 8c3829e328726..0000000000000
Binary files a/python/lib/py4j-0.10.4-src.zip and /dev/null differ
diff --git a/python/lib/py4j-0.10.7-src.zip b/python/lib/py4j-0.10.7-src.zip
new file mode 100644
index 0000000000000..128e321078793
Binary files /dev/null and b/python/lib/py4j-0.10.7-src.zip differ
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index ea58b3a93899e..47ba56b2cb108 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -989,8 +989,8 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
# by runJob() in order to avoid having to pass a Python lambda into
# SparkContext#runJob.
mappedRDD = rdd.mapPartitions(partitionFunc)
- port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions)
- return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
+ sock_info = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions)
+ return list(_load_from_socket(sock_info, mappedRDD._jrdd_deserializer))
def show_profiles(self):
""" Print the profile stats to stdout """
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 7f06d4288c872..e7d1e718c934a 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -29,7 +29,7 @@
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
from pyspark.worker import main as worker_main
-from pyspark.serializers import read_int, write_int
+from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer
def compute_real_exit_code(exit_code):
@@ -40,7 +40,7 @@ def compute_real_exit_code(exit_code):
return 1
-def worker(sock):
+def worker(sock, authenticated):
"""
Called by a worker process after the fork().
"""
@@ -56,6 +56,18 @@ def worker(sock):
# otherwise writes also cause a seek that makes us miss data on the read side.
infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536)
outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536)
+
+ if not authenticated:
+ client_secret = UTF8Deserializer().loads(infile)
+ if os.environ["PYTHON_WORKER_FACTORY_SECRET"] == client_secret:
+ write_with_length("ok".encode("utf-8"), outfile)
+ outfile.flush()
+ else:
+ write_with_length("err".encode("utf-8"), outfile)
+ outfile.flush()
+ sock.close()
+ return 1
+
exit_code = 0
try:
worker_main(infile, outfile)
@@ -153,8 +165,11 @@ def handle_sigterm(*args):
write_int(os.getpid(), outfile)
outfile.flush()
outfile.close()
+ authenticated = False
while True:
- code = worker(sock)
+ code = worker(sock, authenticated)
+ if code == 0:
+ authenticated = True
if not reuse or code:
# wait for closing
try:
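
Note that the daemon never receives the expected secret over the wire from Spark itself; it only compares what the client sends against PYTHON_WORKER_FACTORY_SECRET inherited from its own environment, so the JVM and the daemon must be launched with the same value. A minimal sketch of handing a per-session token to a child process through the environment; the variable name comes from this patch, while the launcher itself and the use of Python 3's secrets module are purely illustrative.

    import os
    import secrets
    import subprocess
    import sys

    # Generate a random per-session token; the JVM side derives its own, so
    # this generator is only a stand-in.
    token = secrets.token_hex(32)

    child_env = dict(os.environ)
    child_env["PYTHON_WORKER_FACTORY_SECRET"] = token

    # The child recovers the secret from its environment, as pyspark/daemon.py
    # and pyspark/worker.py do after this change.
    probe = "import os; assert 'PYTHON_WORKER_FACTORY_SECRET' in os.environ"
    subprocess.check_call([sys.executable, "-c", probe], env=child_env)
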
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 3c783ae541a1f..7abf2c1c25e72 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -21,16 +21,19 @@
import select
import signal
import shlex
+import shutil
import socket
import platform
+import tempfile
+import time
from subprocess import Popen, PIPE
if sys.version >= '3':
xrange = range
-from py4j.java_gateway import java_import, JavaGateway, GatewayClient
+from py4j.java_gateway import java_import, JavaGateway, GatewayParameters
from pyspark.find_spark_home import _find_spark_home
-from pyspark.serializers import read_int
+from pyspark.serializers import read_int, write_with_length, UTF8Deserializer
def launch_gateway(conf=None):
@@ -41,6 +44,7 @@ def launch_gateway(conf=None):
"""
if "PYSPARK_GATEWAY_PORT" in os.environ:
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
+ gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"]
else:
SPARK_HOME = _find_spark_home()
# Launch the Py4j gateway using Spark's run command so that we pick up the
@@ -59,40 +63,40 @@ def launch_gateway(conf=None):
])
command = command + shlex.split(submit_args)
- # Start a socket that will be used by PythonGatewayServer to communicate its port to us
- callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- callback_socket.bind(('127.0.0.1', 0))
- callback_socket.listen(1)
- callback_host, callback_port = callback_socket.getsockname()
- env = dict(os.environ)
- env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host
- env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port)
-
- # Launch the Java gateway.
- # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
- if not on_windows:
- # Don't send ctrl-c / SIGINT to the Java gateway:
- def preexec_func():
- signal.signal(signal.SIGINT, signal.SIG_IGN)
- proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
- else:
- # preexec_fn not supported on Windows
- proc = Popen(command, stdin=PIPE, env=env)
-
- gateway_port = None
- # We use select() here in order to avoid blocking indefinitely if the subprocess dies
- # before connecting
- while gateway_port is None and proc.poll() is None:
- timeout = 1 # (seconds)
- readable, _, _ = select.select([callback_socket], [], [], timeout)
- if callback_socket in readable:
- gateway_connection = callback_socket.accept()[0]
- # Determine which ephemeral port the server started on:
- gateway_port = read_int(gateway_connection.makefile(mode="rb"))
- gateway_connection.close()
- callback_socket.close()
- if gateway_port is None:
- raise Exception("Java gateway process exited before sending the driver its port number")
+ # Create a temporary directory where the gateway server should write the connection
+ # information.
+ conn_info_dir = tempfile.mkdtemp()
+ try:
+ fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir)
+ os.close(fd)
+ os.unlink(conn_info_file)
+
+ env = dict(os.environ)
+ env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file
+
+ # Launch the Java gateway.
+ # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
+ if not on_windows:
+ # Don't send ctrl-c / SIGINT to the Java gateway:
+ def preexec_func():
+ signal.signal(signal.SIGINT, signal.SIG_IGN)
+ proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
+ else:
+ # preexec_fn not supported on Windows
+ proc = Popen(command, stdin=PIPE, env=env)
+
+ # Wait for the file to appear, or for the process to exit, whichever happens first.
+ while not proc.poll() and not os.path.isfile(conn_info_file):
+ time.sleep(0.1)
+
+ if not os.path.isfile(conn_info_file):
+ raise Exception("Java gateway process exited before sending its port number")
+
+ with open(conn_info_file, "rb") as info:
+ gateway_port = read_int(info)
+ gateway_secret = UTF8Deserializer().loads(info)
+ finally:
+ shutil.rmtree(conn_info_dir)
# In Windows, ensure the Java child processes do not linger after Python has exited.
# In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
@@ -111,7 +115,9 @@ def killChild():
atexit.register(killChild)
# Connect to the gateway
- gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True)
+ gateway = JavaGateway(
+ gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
+ auto_convert=True))
# Import the classes used by PySpark
java_import(gateway.jvm, "org.apache.spark.SparkConf")
@@ -125,3 +131,16 @@ def killChild():
java_import(gateway.jvm, "scala.Tuple2")
return gateway
+
+
+def do_server_auth(conn, auth_secret):
+ """
+ Performs the authentication protocol defined by the SocketAuthHelper class on the given
+ file-like object 'conn'.
+ """
+ write_with_length(auth_secret.encode("utf-8"), conn)
+ conn.flush()
+ reply = UTF8Deserializer().loads(conn)
+ if reply != "ok":
+ conn.close()
+ raise Exception("Unexpected reply from iterator server.")
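
The file-based handshake above replaces the old callback socket: the JVM writes the chosen port and the auth secret into the path given by _PYSPARK_DRIVER_CONN_INFO_PATH, and Python polls until the file exists. Below is a small sketch of both ends of that file format, assuming the framing that read_int and UTF8Deserializer expect (an int port followed by a length-prefixed UTF-8 secret); the writer here stands in for the JVM gateway, and the temp-file-plus-rename step is an illustrative detail, not something this patch shows.

    import os
    import struct
    import tempfile

    def write_conn_info(path, port, secret):
        # Writer side: int port, then a length-prefixed UTF-8 secret, renamed
        # into place so a poller never sees a partially written file.
        data = secret.encode("utf-8")
        tmp = path + ".tmp"
        with open(tmp, "wb") as f:
            f.write(struct.pack("!i", port))
            f.write(struct.pack("!i", len(data)))
            f.write(data)
        os.rename(tmp, path)

    def read_conn_info(path):
        # Reader side, mirroring what launch_gateway() does once the file appears.
        with open(path, "rb") as f:
            port = struct.unpack("!i", f.read(4))[0]
            n = struct.unpack("!i", f.read(4))[0]
            secret = f.read(n).decode("utf-8")
        return port, secret

    conn_dir = tempfile.mkdtemp()
    info_path = os.path.join(conn_dir, "conn_info")
    write_conn_info(info_path, 25333, "gateway-secret")
    assert read_conn_info(info_path) == (25333, "gateway-secret")
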
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index aca00bc3b26fc..864cebb9517d9 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -39,9 +39,11 @@
else:
from itertools import imap as map, ifilter as filter
+from pyspark.java_gateway import do_server_auth
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
- PickleSerializer, pack_long, AutoBatchedSerializer
+ PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \
+ UTF8Deserializer
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_full_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@@ -119,7 +121,8 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])
-def _load_from_socket(port, serializer):
+def _load_from_socket(sock_info, serializer):
+ port, auth_secret = sock_info
sock = None
# Support for both IPv4 and IPv6.
# On most of IPv6-ready systems, IPv6 will take precedence.
@@ -139,8 +142,12 @@ def _load_from_socket(port, serializer):
# The RDD materialization time is unpredicable, if we set a timeout for socket reading
# operation, it will very possibly fail. See SPARK-18281.
sock.settimeout(None)
+
+ sockfile = sock.makefile("rwb", 65536)
+ do_server_auth(sockfile, auth_secret)
+
# The socket will be automatically closed when garbage-collected.
- return serializer.load_stream(sock.makefile("rb", 65536))
+ return serializer.load_stream(sockfile)
def ignore_unicode_prefix(f):
@@ -806,8 +813,8 @@ def collect(self):
to be small, as all the data is loaded into the driver's memory.
"""
with SCCallSiteSync(self.context) as css:
- port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
- return list(_load_from_socket(port, self._jrdd_deserializer))
+ sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
+ return list(_load_from_socket(sock_info, self._jrdd_deserializer))
def reduce(self, f):
"""
@@ -2364,8 +2371,8 @@ def toLocalIterator(self):
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
with SCCallSiteSync(self.context) as css:
- port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
- return _load_from_socket(port, self._jrdd_deserializer)
+ sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
+ return _load_from_socket(sock_info, self._jrdd_deserializer)
def _prepare_for_python_RDD(sc, command):
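
One detail worth noting in _load_from_socket() is the ordering: the same buffered "rwb" file object first carries the auth exchange and is then handed to the serializer, so the "ok" reply and the result stream share one connection. The toy sketch below replays that ordering with pickle standing in for PySpark's serializers; the framing mirrors do_server_auth, everything else is illustrative.

    import pickle
    import socket
    import struct
    import threading

    def put(f, payload):
        f.write(struct.pack("!i", len(payload)))
        f.write(payload)
        f.flush()

    def get(f):
        n = struct.unpack("!i", f.read(4))[0]
        return f.read(n)

    def serve(conn, rows, secret):
        # Server side: authenticate first, then stream pickled rows on the same file.
        f = conn.makefile("rwb", 65536)
        assert get(f).decode("utf-8") == secret
        put(f, b"ok")
        for row in rows:
            put(f, pickle.dumps(row))
        f.close()
        conn.close()

    def load(conn, secret):
        # Client side: same ordering as _load_from_socket() -- auth, then deserialize.
        f = conn.makefile("rwb", 65536)
        put(f, secret.encode("utf-8"))
        assert get(f) == b"ok"
        while True:
            try:
                yield pickle.loads(get(f))
            except struct.error:        # EOF: the server closed after the last row
                return

    server_sock, client_sock = socket.socketpair()
    t = threading.Thread(target=serve, args=(server_sock, [1, 2, 3], "token"))
    t.start()
    assert list(load(client_sock, "token")) == [1, 2, 3]
    t.join()
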
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index d1b336df40f91..1747d3635fe4a 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -435,8 +435,8 @@ def collect(self):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
- port = self._jdf.collectToPython()
- return list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
+ sock_info = self._jdf.collectToPython()
+ return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
@ignore_unicode_prefix
@since(2.0)
@@ -449,8 +449,8 @@ def toLocalIterator(self):
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
"""
with SCCallSiteSync(self._sc) as css:
- port = self._jdf.toPythonIterator()
- return _load_from_socket(port, BatchedSerializer(PickleSerializer()))
+ sock_info = self._jdf.toPythonIterator()
+ return _load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))
@ignore_unicode_prefix
@since(1.3)
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index baaa3fe074e9a..0c8996e21ee6c 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -27,6 +27,7 @@
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
+from pyspark.java_gateway import do_server_auth
from pyspark.taskcontext import TaskContext
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
@@ -208,9 +209,11 @@ def process():
if __name__ == '__main__':
- # Read a local port to connect to from stdin
- java_port = int(sys.stdin.readline())
+ # Read information about how to connect back to the JVM from the environment.
+ java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
+ auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(("127.0.0.1", java_port))
sock_file = sock.makefile("rwb", 65536)
+ do_server_auth(sock_file, auth_secret)
main(sock_file, sock_file)
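
Compared to the old stdin-based handoff, the worker now gets everything it needs to connect back from its environment. A defensive variant of that connect-back (not the patch's code, which assumes the variables are always set) might look like the following sketch:

    import os
    import socket

    def connect_back_to_jvm():
        # Variable names as used in this patch; the KeyError guard is an
        # illustrative addition for clearer failures outside of Spark.
        try:
            port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
            secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
        except KeyError as missing:
            raise RuntimeError("worker launched without %s in its environment" % missing)
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        sock.connect(("127.0.0.1", port))
        return sock.makefile("rwb", 65536), secret
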
diff --git a/python/setup.py b/python/setup.py
index 7e63461d289b2..fa8dc12c92caf 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -196,7 +196,7 @@ def _supports_symlinks():
'pyspark.examples.src.main.python': ['*.py', '*/*.py']},
scripts=scripts,
license='http://www.apache.org/licenses/LICENSE-2.0',
- install_requires=['py4j==0.10.4'],
+ install_requires=['py4j==0.10.7'],
setup_requires=['pypandoc'],
extras_require={
'ml': ['numpy>=1.7'],
diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
index 9354a3bef7ba5..41f351a5f5542 100644
--- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
+++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala
@@ -485,9 +485,9 @@ private[spark] class MesosClusterScheduler(
.filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) }
.toMap
(defaultConf ++ driverConf).foreach { case (key, value) =>
- options ++= Seq("--conf", s""""$key=${shellEscape(value)}"""".stripMargin) }
+ options ++= Seq("--conf", s"${key}=${value}") }
- options
+ options.map(shellEscape)
}
/**
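
The Mesos change shifts from quoting and escaping each value inline to escaping every finished "--conf key=value" token exactly once, which is what keeps values containing quotes or spaces intact. A rough Python analogy, with shlex.quote standing in for the Scala shellEscape helper (the option names and values are placeholders):

    import shlex

    key = "spark.executor.extraJavaOptions"
    value = "-Dgreeting='hello world'"

    # Old style (illustrative): bake literal double quotes into the token itself,
    # which later collides with values that already contain quotes.
    old_token = '"%s=%s"' % (key, value)

    # New style: build the plain token first, then escape the whole thing once
    # when rendering the command line.
    new_token = "%s=%s" % (key, value)
    print("old:", old_token)
    print("new:", " ".join(["--conf", shlex.quote(new_token)]))
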
diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 638ee0f86d06d..a94a8ab515459 100644
--- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -1139,7 +1139,7 @@ private[spark] class Client(
val pyArchivesFile = new File(pyLibPath, "pyspark.zip")
require(pyArchivesFile.exists(),
s"$pyArchivesFile not found; cannot run pyspark application in YARN mode.")
- val py4jFile = new File(pyLibPath, "py4j-0.10.4-src.zip")
+ val py4jFile = new File(pyLibPath, "py4j-0.10.7-src.zip")
require(py4jFile.exists(),
s"$py4jFile not found; cannot run pyspark application in YARN mode.")
Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath())
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 59adb7e22d185..f18890c842ed6 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -249,7 +249,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
// needed locations.
val sparkHome = sys.props("spark.test.home")
val pythonPath = Seq(
- s"$sparkHome/python/lib/py4j-0.10.4-src.zip",
+ s"$sparkHome/python/lib/py4j-0.10.7-src.zip",
s"$sparkHome/python")
val extraEnvVars = Map(
"PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator),
diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh
index f2d9e6b568a9b..bf3da18c3706e 100755
--- a/sbin/spark-config.sh
+++ b/sbin/spark-config.sh
@@ -28,6 +28,6 @@ export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}"
# Add the PySpark classes to the PYTHONPATH:
if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then
export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}"
- export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:${PYTHONPATH}"
+ export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:${PYTHONPATH}"
export PYSPARK_PYTHONPATH_SET=1
fi
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 4c8b177237d23..371eddaa3705d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -17,8 +17,12 @@
package org.apache.spark.sql.catalyst
+import java.util.Locale
+
import com.google.common.collect.Maps
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{StructField, StructType}
@@ -131,6 +135,88 @@ package object expressions {
def indexOf(exprId: ExprId): Int = {
Option(exprIdToOrdinal.get(exprId)).getOrElse(-1)
}
+
+ private def unique[T](m: Map[T, Seq[Attribute]]): Map[T, Seq[Attribute]] = {
+ m.mapValues(_.distinct).map(identity)
+ }
+
+ /** Map to use for direct case insensitive attribute lookups. */
+ @transient private lazy val direct: Map[String, Seq[Attribute]] = {
+ unique(attrs.groupBy(_.name.toLowerCase(Locale.ROOT)))
+ }
+
+ /** Map to use for qualified case insensitive attribute lookups. */
+ @transient private val qualified: Map[(String, String), Seq[Attribute]] = {
+ val grouped = attrs.filter(_.qualifier.isDefined).groupBy { a =>
+ (a.qualifier.get.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT))
+ }
+ unique(grouped)
+ }
+
+ /** Perform attribute resolution given a name and a resolver. */
+ def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = {
+ // Collect matching attributes given a name and a lookup.
+ def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = {
+ candidates.toSeq.flatMap(_.collect {
+ case a if resolver(a.name, name) => a.withName(name)
+ })
+ }
+
+ // Find matches for the given name assuming that the 1st part is a qualifier (i.e. table name,
+ // alias, or subquery alias) and the 2nd part is the actual name. This returns a tuple of
+ // matched attributes and a list of parts that are to be resolved.
+ //
+ // For example, consider "a.b.c" where "a" is the table name, "b" is the column name,
+ // and "c" is the struct field name. In this case the matched Attribute will be "a.b",
+ // and the second element will be List("c").
+ val matches = nameParts match {
+ case qualifier +: name +: nestedFields =>
+ val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT))
+ val attributes = collectMatches(name, qualified.get(key)).filter { a =>
+ resolver(qualifier, a.qualifier.get)
+ }
+ (attributes, nestedFields)
+ case all =>
+ (Nil, all)
+ }
+
+ // If none of the attributes match the `table.column` pattern, try to resolve the name as a column.
+ val (candidates, nestedFields) = matches match {
+ case (Seq(), _) =>
+ val name = nameParts.head
+ val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT)))
+ (attributes, nameParts.tail)
+ case _ => matches
+ }
+
+ def name = UnresolvedAttribute(nameParts).name
+ candidates match {
+ case Seq(a) if nestedFields.nonEmpty =>
+ // One match, but we also need to extract the requested nested field.
+ // The foldLeft adds an ExtractValue for every remaining part of the identifier
+ // and aliases the result with the last part of the name.
+ // For example, consider "a.b.c", where "a" is resolved to an existing attribute.
+ // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final
+ // expression as "c".
+ val fieldExprs = nestedFields.foldLeft(a: Expression) { (e, name) =>
+ ExtractValue(e, Literal(name), resolver)
+ }
+ Some(Alias(fieldExprs, nestedFields.last)())
+
+ case Seq(a) =>
+ // One match, no nested fields, use it.
+ Some(a)
+
+ case Seq() =>
+ // No matches.
+ None
+
+ case ambiguousReferences =>
+ // More than one match.
+ val referenceNames = ambiguousReferences.mkString(", ")
+ throw new AnalysisException(s"Reference '$name' is ambiguous, could be: $referenceNames.")
+ }
+ }
}
/**
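
To make the new resolve() path easier to follow: a name is first looked up in the `qualified` map as "qualifier.column" and only then in the `direct` map as a bare column, with any remaining parts treated as nested fields and multiple matches raised as an ambiguity error. A compact Python sketch of that precedence, assuming plain lower() case folding in place of Locale.ROOT and ignoring the Resolver and withName details:

    from collections import defaultdict

    # Toy attributes as (qualifier, name) pairs; the real code keys both maps
    # case-insensitively and keeps full Attribute objects as values.
    attrs = [("t1", "id"), ("t1", "value"), ("t2", "id")]

    direct = defaultdict(list)      # column name         -> attributes
    qualified = defaultdict(list)   # (qualifier, column) -> attributes
    for qualifier, name in attrs:
        direct[name.lower()].append((qualifier, name))
        qualified[(qualifier.lower(), name.lower())].append((qualifier, name))

    def resolve(name_parts):
        # Try "qualifier.column" first, then fall back to a bare column lookup.
        candidates, nested = [], []
        if len(name_parts) > 1:
            key = (name_parts[0].lower(), name_parts[1].lower())
            candidates, nested = qualified.get(key, []), list(name_parts[2:])
        if not candidates:
            candidates, nested = direct.get(name_parts[0].lower(), []), list(name_parts[1:])
        if not candidates:
            return None
        if len(candidates) > 1:
            raise ValueError("Reference %r is ambiguous" % ".".join(name_parts))
        # Remaining parts would become a chain of ExtractValue expressions.
        return candidates[0], nested

    assert resolve(["t1", "id"]) == (("t1", "id"), [])
    assert resolve(["value"]) == (("t1", "value"), [])
    assert resolve(["t1", "id", "field"]) == (("t1", "id"), ["field"])
    try:
        resolve(["id"])                 # matches both t1.id and t2.id
    except ValueError:
        pass
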
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 2ebb2ff323c6b..5c65bbd16d305 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -160,6 +160,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
}
}
+ private[this] lazy val childAttributes = AttributeSeq(children.flatMap(_.output))
+
+ private[this] lazy val outputAttributes = AttributeSeq(output)
+
/**
* Optionally resolves the given strings to a [[NamedExpression]] using the input from all child
* nodes of this LogicalPlan. The attribute is expressed as
@@ -168,7 +172,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
def resolveChildren(
nameParts: Seq[String],
resolver: Resolver): Option[NamedExpression] =
- resolve(nameParts, children.flatMap(_.output), resolver)
+ childAttributes.resolve(nameParts, resolver)
/**
* Optionally resolves the given strings to a [[NamedExpression]] based on the output of this
@@ -178,7 +182,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
def resolve(
nameParts: Seq[String],
resolver: Resolver): Option[NamedExpression] =
- resolve(nameParts, output, resolver)
+ outputAttributes.resolve(nameParts, resolver)
/**
* Given an attribute name, split it to name parts by dot, but
@@ -188,105 +192,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
def resolveQuoted(
name: String,
resolver: Resolver): Option[NamedExpression] = {
- resolve(UnresolvedAttribute.parseAttributeName(name), output, resolver)
- }
-
- /**
- * Resolve the given `name` string against the given attribute, returning either 0 or 1 match.
- *
- * This assumes `name` has multiple parts, where the 1st part is a qualifier
- * (i.e. table name, alias, or subquery alias).
- * See the comment above `candidates` variable in resolve() for semantics the returned data.
- */
- private def resolveAsTableColumn(
- nameParts: Seq[String],
- resolver: Resolver,
- attribute: Attribute): Option[(Attribute, List[String])] = {
- assert(nameParts.length > 1)
- if (attribute.qualifier.exists(resolver(_, nameParts.head))) {
- // At least one qualifier matches. See if remaining parts match.
- val remainingParts = nameParts.tail
- resolveAsColumn(remainingParts, resolver, attribute)
- } else {
- None
- }
- }
-
- /**
- * Resolve the given `name` string against the given attribute, returning either 0 or 1 match.
- *
- * Different from resolveAsTableColumn, this assumes `name` does NOT start with a qualifier.
- * See the comment above `candidates` variable in resolve() for semantics the returned data.
- */
- private def resolveAsColumn(
- nameParts: Seq[String],
- resolver: Resolver,
- attribute: Attribute): Option[(Attribute, List[String])] = {
- if (!attribute.isGenerated && resolver(attribute.name, nameParts.head)) {
- Option((attribute.withName(nameParts.head), nameParts.tail.toList))
- } else {
- None
- }
- }
-
- /** Performs attribute resolution given a name and a sequence of possible attributes. */
- protected def resolve(
- nameParts: Seq[String],
- input: Seq[Attribute],
- resolver: Resolver): Option[NamedExpression] = {
-
- // A sequence of possible candidate matches.
- // Each candidate is a tuple. The first element is a resolved attribute, followed by a list
- // of parts that are to be resolved.
- // For example, consider an example where "a" is the table name, "b" is the column name,
- // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b",
- // and the second element will be List("c").
- var candidates: Seq[(Attribute, List[String])] = {
- // If the name has 2 or more parts, try to resolve it as `table.column` first.
- if (nameParts.length > 1) {
- input.flatMap { option =>
- resolveAsTableColumn(nameParts, resolver, option)
- }
- } else {
- Seq.empty
- }
- }
-
- // If none of attributes match `table.column` pattern, we try to resolve it as a column.
- if (candidates.isEmpty) {
- candidates = input.flatMap { candidate =>
- resolveAsColumn(nameParts, resolver, candidate)
- }
- }
-
- def name = UnresolvedAttribute(nameParts).name
-
- candidates.distinct match {
- // One match, no nested fields, use it.
- case Seq((a, Nil)) => Some(a)
-
- // One match, but we also need to extract the requested nested field.
- case Seq((a, nestedFields)) =>
- // The foldLeft adds ExtractValues for every remaining parts of the identifier,
- // and aliased it with the last part of the name.
- // For example, consider "a.b.c", where "a" is resolved to an existing attribute.
- // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final
- // expression as "c".
- val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
- ExtractValue(expr, Literal(fieldName), resolver))
- Some(Alias(fieldExprs, nestedFields.last)())
-
- // No matches.
- case Seq() =>
- logTrace(s"Could not find $name in ${input.mkString(", ")}")
- None
-
- // More than one match.
- case ambiguousReferences =>
- val referenceNames = ambiguousReferences.map(_._1).mkString(", ")
- throw new AnalysisException(
- s"Reference '$name' is ambiguous, could be: $referenceNames.")
- }
+ outputAttributes.resolve(UnresolvedAttribute.parseAttributeName(name), resolver)
}
/**
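
The LogicalPlan side of this refactor is mostly about caching: the two AttributeSeq wrappers are lazy vals, so the lookup maps are built once per plan node instead of on every resolve call. A small sketch of the same idea, using Python 3.8's functools.cached_property as an illustrative stand-in for the Scala lazy vals:

    import functools

    class PlanNode:
        """Illustrative stand-in for a plan node that caches its resolution index."""

        def __init__(self, output_attrs):
            self._output_attrs = output_attrs   # list of (qualifier, name) pairs

        @functools.cached_property
        def _output_index(self):
            # Built on the first resolve() and reused afterwards, mirroring
            # `private[this] lazy val outputAttributes = AttributeSeq(output)`.
            index = {}
            for qualifier, name in self._output_attrs:
                index.setdefault(name.lower(), []).append((qualifier, name))
            return index

        def resolve(self, name):
            return self._output_index.get(name.lower(), [])

    plan = PlanNode([("t", "id"), ("t", "value")])
    assert plan.resolve("ID") == [("t", "id")]
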
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index fb700a489c763..4c13daa386a8f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2804,7 +2804,7 @@ class Dataset[T] private[sql](
EvaluatePython.javaToPython(rdd)
}
- private[sql] def collectToPython(): Int = {
+ private[sql] def collectToPython(): Array[Any] = {
EvaluatePython.registerPicklers()
withNewExecutionId {
val toJava: (Any) => Any = EvaluatePython.toJava(_, schema)
@@ -2814,7 +2814,7 @@ class Dataset[T] private[sql](
}
}
- private[sql] def toPythonIterator(): Int = {
+ private[sql] def toPythonIterator(): Array[Any] = {
withNewExecutionId {
PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala
index a3d5b941a6761..2b37047612dfe 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala
@@ -57,30 +57,31 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils {
for (i <- 0 until 3) {
val preferredMirror =
Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim
- val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz"
+ val filename = s"spark-$version-bin-hadoop2.7.tgz"
+ val url = s"$preferredMirror/spark/spark-$version/$filename"
logInfo(s"Downloading Spark $version from $url")
if (Seq("wget", url, "-q", "-P", path).! == 0) {
- return
+ val downloaded = new File(sparkTestingDir, filename).getCanonicalPath
+ val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath
+
+ Seq("mkdir", targetDir).!
+ val exitCode = Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").!
+ Seq("rm", downloaded).!
+
+ // For a corrupted file, `tar` returns a non-zero value. However, we also need to check
+ // the extracted files because `tar` returns 0 for an empty file.
+ val sparkSubmit = new File(sparkTestingDir, s"spark-$version/bin/spark-submit")
+ if (exitCode == 0 && sparkSubmit.exists()) {
+ return
+ } else {
+ Seq("rm", "-rf", targetDir).!
+ }
}
logWarning(s"Failed to download Spark $version from $url")
}
fail(s"Unable to download Spark $version")
}
-
- private def downloadSpark(version: String): Unit = {
- tryDownloadSpark(version, sparkTestingDir.getCanonicalPath)
-
- val downloaded = new File(sparkTestingDir, s"spark-$version-bin-hadoop2.7.tgz").getCanonicalPath
- val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath
-
- Seq("mkdir", targetDir).!
-
- Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").!
-
- Seq("rm", downloaded).!
- }
-
private def genDataDir(name: String): String = {
new File(tmpDataDir, name).getCanonicalPath
}
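
The reworked tryDownloadSpark folds extraction into the retry loop and only accepts an archive if tar exits 0 and the expected spark-submit binary is actually present afterwards. A generic sketch of that download-and-verify pattern; the URL, marker path, and helper names here are placeholders, not the suite's API.

    import os
    import shutil
    import subprocess
    import tempfile
    import urllib.request

    def fetch_and_verify(url, marker_relpath, dest_dir, attempts=3):
        """Download a tarball and keep it only if extraction verifiably succeeded."""
        for _ in range(attempts):
            workdir = tempfile.mkdtemp()
            tarball = os.path.join(workdir, "archive.tgz")
            extracted = os.path.join(workdir, "extracted")
            os.mkdir(extracted)
            try:
                urllib.request.urlretrieve(url, tarball)
                code = subprocess.call(
                    ["tar", "-xzf", tarball, "-C", extracted, "--strip-components=1"])
                # `tar` exits 0 for an empty file, so also require a file that a
                # healthy archive must contain (e.g. bin/spark-submit).
                if code == 0 and os.path.isfile(os.path.join(extracted, marker_relpath)):
                    shutil.move(extracted, dest_dir)   # dest_dir must not exist yet
                    return dest_dir
            except OSError:
                pass                                    # network or filesystem hiccup; retry
            finally:
                shutil.rmtree(workdir, ignore_errors=True)
        raise RuntimeError("unable to download and verify %s" % url)
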
@@ -125,7 +126,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils {
PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) =>
val sparkHome = new File(sparkTestingDir, s"spark-$version")
if (!sparkHome.exists()) {
- downloadSpark(version)
+ tryDownloadSpark(version, sparkTestingDir.getCanonicalPath)
}
val args = Seq(