diff --git a/LICENSE b/LICENSE index 66a2e8f132953..b948ccaeecea6 100644 --- a/LICENSE +++ b/LICENSE @@ -263,7 +263,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (New BSD license) Protocol Buffer Java API (org.spark-project.protobuf:protobuf-java:2.4.1-shaded - http://code.google.com/p/protobuf) (The BSD License) Fortran to Java ARPACK (net.sourceforge.f2j:arpack_combined_all:0.1 - http://f2j.sourceforge.net) (The BSD License) xmlenc Library (xmlenc:xmlenc:0.52 - http://xmlenc.sourceforge.net) - (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.4 - http://py4j.sourceforge.net/) + (The New BSD License) Py4J (net.sf.py4j:py4j:0.10.7 - http://py4j.sourceforge.net/) (Two-clause BSD-style license) JUnit-Interface (com.novocode:junit-interface:0.10 - http://github.com/szeiger/junit-interface/) (BSD licence) sbt and sbt-launch-lib.bash (BSD 3 Clause) d3.min.js (https://github.com/mbostock/d3/blob/master/LICENSE) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index 9d82814211bc5..7244cc9f9e38e 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -19,7 +19,7 @@ # Creates a SparkR client connection object # if one doesn't already exist -connectBackend <- function(hostname, port, timeout) { +connectBackend <- function(hostname, port, timeout, authSecret) { if (exists(".sparkRcon", envir = .sparkREnv)) { if (isOpen(.sparkREnv[[".sparkRCon"]])) { cat("SparkRBackend client connection already exists\n") @@ -29,7 +29,7 @@ connectBackend <- function(hostname, port, timeout) { con <- socketConnection(host = hostname, port = port, server = FALSE, blocking = TRUE, open = "wb", timeout = timeout) - + doServerAuth(con, authSecret) assign(".sparkRCon", con, envir = .sparkREnv) con } diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 0e99b171cabeb..dc7d37e064b1d 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -60,14 +60,18 @@ readTypedObject <- function(con, type) { stop(paste("Unsupported type for deserialization", type))) } -readString <- function(con) { - stringLen <- readInt(con) - raw <- readBin(con, raw(), stringLen, endian = "big") +readStringData <- function(con, len) { + raw <- readBin(con, raw(), len, endian = "big") string <- rawToChar(raw) Encoding(string) <- "UTF-8" string } +readString <- function(con) { + stringLen <- readInt(con) + readStringData(con, stringLen) +} + readInt <- function(con) { readBin(con, integer(), n = 1, endian = "big") } diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 9ebd34411a1ea..daa855b2459f2 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -161,6 +161,10 @@ sparkR.sparkContext <- function( " please use the --packages commandline instead", sep = ",")) } backendPort <- existingPort + authSecret <- Sys.getenv("SPARKR_BACKEND_AUTH_SECRET") + if (nchar(authSecret) == 0) { + stop("Auth secret not provided in environment.") + } } else { path <- tempfile(pattern = "backend_port") submitOps <- getClientModeSparkSubmitOpts( @@ -189,16 +193,27 @@ sparkR.sparkContext <- function( monitorPort <- readInt(f) rLibPath <- readString(f) connectionTimeout <- readInt(f) + + # Don't use readString() so that we can provide a useful + # error message if the R and Java versions are mismatched. + authSecretLen = readInt(f) + if (length(authSecretLen) == 0 || authSecretLen == 0) { + stop("Unexpected EOF in JVM connection data. Mismatched versions?") + } + authSecret <- readStringData(f, authSecretLen) close(f) file.remove(path) if (length(backendPort) == 0 || backendPort == 0 || length(monitorPort) == 0 || monitorPort == 0 || - length(rLibPath) != 1) { + length(rLibPath) != 1 || length(authSecret) == 0) { stop("JVM failed to launch") } - assign(".monitorConn", - socketConnection(port = monitorPort, timeout = connectionTimeout), - envir = .sparkREnv) + + monitorConn <- socketConnection(port = monitorPort, blocking = TRUE, + timeout = connectionTimeout, open = "wb") + doServerAuth(monitorConn, authSecret) + + assign(".monitorConn", monitorConn, envir = .sparkREnv) assign(".backendLaunched", 1, envir = .sparkREnv) if (rLibPath != "") { assign(".libPath", rLibPath, envir = .sparkREnv) @@ -208,7 +223,7 @@ sparkR.sparkContext <- function( .sparkREnv$backendPort <- backendPort tryCatch({ - connectBackend("localhost", backendPort, timeout = connectionTimeout) + connectBackend("localhost", backendPort, timeout = connectionTimeout, authSecret = authSecret) }, error = function(err) { stop("Failed to connect JVM\n") @@ -632,3 +647,17 @@ sparkCheckInstall <- function(sparkHome, master, deployMode) { NULL } } + +# Utility function for sending auth data over a socket and checking the server's reply. +doServerAuth <- function(con, authSecret) { + if (nchar(authSecret) == 0) { + stop("Auth secret not provided.") + } + writeString(con, authSecret) + flush(con) + reply <- readString(con) + if (reply != "ok") { + close(con) + stop("Unexpected reply from server.") + } +} diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R index 3a318b71ea06d..ec9a8f1ee1c95 100644 --- a/R/pkg/inst/worker/daemon.R +++ b/R/pkg/inst/worker/daemon.R @@ -28,7 +28,9 @@ suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( - port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout) + port = port, open = "wb", blocking = TRUE, timeout = connectionTimeout) + +SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET")) while (TRUE) { ready <- socketSelect(list(inputCon)) diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 03e7450147865..eb6453fc16976 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -100,9 +100,12 @@ suppressPackageStartupMessages(library(SparkR)) port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT")) inputCon <- socketConnection( - port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout) + port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) +SparkR:::doServerAuth(inputCon, Sys.getenv("SPARKR_WORKER_SECRET")) + outputCon <- socketConnection( port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout) +SparkR:::doServerAuth(outputCon, Sys.getenv("SPARKR_WORKER_SECRET")) # read the index of the current partition inside the RDD partition <- SparkR:::readInt(inputCon) diff --git a/bin/pyspark b/bin/pyspark index 98387c2ec5b8a..95ab62880654f 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -25,14 +25,14 @@ source "${SPARK_HOME}"/bin/load-spark-env.sh export _SPARK_CMD_USAGE="Usage: ./bin/pyspark [options]" # In Spark 2.0, IPYTHON and IPYTHON_OPTS are removed and pyspark fails to launch if either option -# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython +# is set in the user's environment. Instead, users should set PYSPARK_DRIVER_PYTHON=ipython # to use IPython and set PYSPARK_DRIVER_PYTHON_OPTS to pass options when starting the Python driver # (e.g. PYSPARK_DRIVER_PYTHON_OPTS='notebook'). This supports full customization of the IPython # and executor Python executables. # Fail noisily if removed options are set if [[ -n "$IPYTHON" || -n "$IPYTHON_OPTS" ]]; then - echo "Error in pyspark startup:" + echo "Error in pyspark startup:" echo "IPYTHON and IPYTHON_OPTS are removed in Spark 2.0+. Remove these from the environment and set PYSPARK_DRIVER_PYTHON and PYSPARK_DRIVER_PYTHON_OPTS instead." exit 1 fi @@ -57,7 +57,7 @@ export PYSPARK_PYTHON # Add the PySpark classes to the Python path: export PYTHONPATH="${SPARK_HOME}/python/:$PYTHONPATH" -export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:$PYTHONPATH" +export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:$PYTHONPATH" # Load the PySpark shell.py script when ./pyspark is used interactively: export OLD_PYTHONSTARTUP="$PYTHONSTARTUP" diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd index d1ce9dabab029..15fa910c277b3 100644 --- a/bin/pyspark2.cmd +++ b/bin/pyspark2.cmd @@ -30,7 +30,7 @@ if "x%PYSPARK_DRIVER_PYTHON%"=="x" ( ) set PYTHONPATH=%SPARK_HOME%\python;%PYTHONPATH% -set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.4-src.zip;%PYTHONPATH% +set PYTHONPATH=%SPARK_HOME%\python\lib\py4j-0.10.7-src.zip;%PYTHONPATH% set OLD_PYTHONSTARTUP=%PYTHONSTARTUP% set PYTHONSTARTUP=%SPARK_HOME%\python\pyspark\shell.py diff --git a/core/pom.xml b/core/pom.xml index 0b632bca8cd96..6e50010b48f47 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -335,7 +335,7 @@ net.sf.py4j py4j - 0.10.4 + 0.10.7 org.apache.spark diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 2480e56b72ccf..9814f1ae028de 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -17,13 +17,11 @@ package org.apache.spark -import java.lang.{Byte => JByte} import java.net.{Authenticator, PasswordAuthentication} -import java.security.{KeyStore, SecureRandom} +import java.security.KeyStore import java.security.cert.X509Certificate import javax.net.ssl._ -import com.google.common.hash.HashCodes import com.google.common.io.Files import org.apache.hadoop.io.Text @@ -435,12 +433,7 @@ private[spark] class SecurityManager( val secretKey = SparkHadoopUtil.get.getSecretKeyFromUserCredentials(SECRET_LOOKUP_KEY) if (secretKey == null || secretKey.length == 0) { logDebug("generateSecretKey: yarn mode, secret key from credentials is null") - val rnd = new SecureRandom() - val length = sparkConf.getInt("spark.authenticate.secretBitLength", 256) / JByte.SIZE - val secret = new Array[Byte](length) - rnd.nextBytes(secret) - - val cookie = HashCodes.fromBytes(secret).toString() + val cookie = Utils.createSecret(sparkConf) SparkHadoopUtil.get.addSecretKeyToUserCredentials(SECRET_LOOKUP_KEY, cookie) cookie } else { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala index 11f2432575d84..9ddc4a4910180 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala @@ -17,26 +17,39 @@ package org.apache.spark.api.python -import java.io.DataOutputStream -import java.net.Socket +import java.io.{DataOutputStream, File, FileOutputStream} +import java.net.InetAddress +import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.Files import py4j.GatewayServer +import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.util.Utils /** - * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port - * back to its caller via a callback port specified by the caller. + * Process that starts a Py4J GatewayServer on an ephemeral port. * * This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py). */ private[spark] object PythonGatewayServer extends Logging { initializeLogIfNecessary(true) - def main(args: Array[String]): Unit = Utils.tryOrExit { - // Start a GatewayServer on an ephemeral port - val gatewayServer: GatewayServer = new GatewayServer(null, 0) + def main(args: Array[String]): Unit = { + val secret = Utils.createSecret(new SparkConf()) + + // Start a GatewayServer on an ephemeral port. Make sure the callback client is configured + // with the same secret, in case the app needs callbacks from the JVM to the underlying + // python processes. + val localhost = InetAddress.getLoopbackAddress() + val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder() + .authToken(secret) + .javaPort(0) + .javaAddress(localhost) + .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret) + .build() + gatewayServer.start() val boundPort: Int = gatewayServer.getListeningPort if (boundPort == -1) { @@ -46,15 +59,24 @@ private[spark] object PythonGatewayServer extends Logging { logDebug(s"Started PythonGatewayServer on port $boundPort") } - // Communicate the bound port back to the caller via the caller-specified callback port - val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST") - val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt - logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort") - val callbackSocket = new Socket(callbackHost, callbackPort) - val dos = new DataOutputStream(callbackSocket.getOutputStream) + // Communicate the connection information back to the python process by writing the + // information in the requested file. This needs to match the read side in java_gateway.py. + val connectionInfoPath = new File(sys.env("_PYSPARK_DRIVER_CONN_INFO_PATH")) + val tmpPath = Files.createTempFile(connectionInfoPath.getParentFile().toPath(), + "connection", ".info").toFile() + + val dos = new DataOutputStream(new FileOutputStream(tmpPath)) dos.writeInt(boundPort) + + val secretBytes = secret.getBytes(UTF_8) + dos.writeInt(secretBytes.length) + dos.write(secretBytes, 0, secretBytes.length) dos.close() - callbackSocket.close() + + if (!tmpPath.renameTo(connectionInfoPath)) { + logError(s"Unable to write connection information to $connectionInfoPath.") + System.exit(1) + } // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies: while (System.in.read() != -1) { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 63ae705f9c97e..0662792cd1a75 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -38,6 +38,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.input.PortableDataStream import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util._ @@ -421,6 +422,12 @@ private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() + // Authentication helper used when serving iterator data. + private lazy val authHelper = { + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + new SocketAuthHelper(conf) + } + def getWorkerBroadcasts(worker: Socket): mutable.Set[Long] = { synchronized { workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) @@ -443,12 +450,13 @@ private[spark] object PythonRDD extends Logging { * (effectively a collect()), but allows you to run on a certain subset of partitions, * or to enable local execution. * - * @return the port number of a local socket which serves the data collected from this job. + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ def runJob( sc: SparkContext, rdd: JavaRDD[Array[Byte]], - partitions: JArrayList[Int]): Int = { + partitions: JArrayList[Int]): Array[Any] = { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = @@ -461,13 +469,14 @@ private[spark] object PythonRDD extends Logging { /** * A helper function to collect an RDD as an iterator, then serve it via socket. * - * @return the port number of a local socket which serves the data collected from this job. + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ - def collectAndServe[T](rdd: RDD[T]): Int = { + def collectAndServe[T](rdd: RDD[T]): Array[Any] = { serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}") } - def toLocalIteratorAndServe[T](rdd: RDD[T]): Int = { + def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = { serveIterator(rdd.toLocalIterator, s"serve toLocalIterator") } @@ -698,8 +707,11 @@ private[spark] object PythonRDD extends Logging { * and send them into this connection. * * The thread will terminate after all the data are sent or any exceptions happen. + * + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from this job, and the secret for authentication. */ - def serveIterator[T](items: Iterator[T], threadName: String): Int = { + def serveIterator(items: Iterator[_], threadName: String): Array[Any] = { val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) // Close the socket if no connection in 15 seconds serverSocket.setSoTimeout(15000) @@ -709,11 +721,14 @@ private[spark] object PythonRDD extends Logging { override def run() { try { val sock = serverSocket.accept() + authHelper.authClient(sock) + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) Utils.tryWithSafeFinally { writeIteratorToStream(items, out) } { out.close() + sock.close() } } catch { case NonFatal(e) => @@ -724,7 +739,7 @@ private[spark] object PythonRDD extends Logging { } }.start() - serverSocket.getLocalPort + Array(serverSocket.getLocalPort, authHelper.secret) } private def getMergedConf(confAsMap: java.util.HashMap[String, String], diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index c4e55b5e89027..27a5e19f96a14 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -32,7 +32,7 @@ private[spark] object PythonUtils { val pythonPath = new ArrayBuffer[String] for (sparkHome <- sys.env.get("SPARK_HOME")) { pythonPath += Seq(sparkHome, "python", "lib", "pyspark.zip").mkString(File.separator) - pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.4-src.zip").mkString(File.separator) + pythonPath += Seq(sparkHome, "python", "lib", "py4j-0.10.7-src.zip").mkString(File.separator) } pythonPath ++= SparkContext.jarOfObject(this) pythonPath.mkString(File.pathSeparator) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 6a5e6f7c5afb1..9c667847382cd 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -27,6 +27,7 @@ import scala.collection.JavaConverters._ import org.apache.spark._ import org.apache.spark.internal.Logging +import org.apache.spark.security.SocketAuthHelper import org.apache.spark.util.{RedirectThread, Utils} private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) @@ -40,6 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // also fall back to launching workers (pyspark/worker.py) directly. val useDaemon = !System.getProperty("os.name").startsWith("Windows") + + private val authHelper = new SocketAuthHelper(SparkEnv.get.conf) + var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 @@ -80,6 +84,8 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String if (pid < 0) { throw new IllegalStateException("Python daemon failed to launch worker with code " + pid) } + + authHelper.authToServer(socket) daemonWorkers.put(socket, pid) socket } @@ -117,25 +123,24 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String workerEnv.put("PYTHONPATH", pythonPath) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") + workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString) + workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) val worker = pb.start() // Redirect worker stdout and stderr redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream) - // Tell the worker our port - val out = new OutputStreamWriter(worker.getOutputStream, StandardCharsets.UTF_8) - out.write(serverSocket.getLocalPort + "\n") - out.flush() - - // Wait for it to connect to our socket + // Wait for it to connect to our socket, and validate the auth secret. serverSocket.setSoTimeout(10000) + try { val socket = serverSocket.accept() + authHelper.authClient(socket) simpleWorkers.put(socket, worker) return socket } catch { case e: Exception => - throw new SparkException("Python worker did not connect back in time", e) + throw new SparkException("Python worker failed to connect back.", e) } } finally { if (serverSocket != null) { @@ -158,6 +163,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String val workerEnv = pb.environment() workerEnv.putAll(envVars.asJava) workerEnv.put("PYTHONPATH", pythonPath) + workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: workerEnv.put("PYTHONUNBUFFERED", "YES") daemon = pb.start() @@ -167,7 +173,6 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String // Redirect daemon stdout and stderr redirectStreamsToStderr(in, daemon.getErrorStream) - } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala new file mode 100644 index 0000000000000..ac6826a9ec774 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{DataInputStream, DataOutputStream} +import java.net.Socket + +import org.apache.spark.SparkConf +import org.apache.spark.security.SocketAuthHelper + +private[spark] class RAuthHelper(conf: SparkConf) extends SocketAuthHelper(conf) { + + override protected def readUtf8(s: Socket): String = { + SerDe.readString(new DataInputStream(s.getInputStream())) + } + + override protected def writeUtf8(str: String, s: Socket): Unit = { + val out = s.getOutputStream() + SerDe.writeString(new DataOutputStream(out), str) + out.flush() + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 2d1152a036449..3b2e809408e0f 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -17,8 +17,8 @@ package org.apache.spark.api.r -import java.io.{DataOutputStream, File, FileOutputStream, IOException} -import java.net.{InetAddress, InetSocketAddress, ServerSocket} +import java.io.{DataInputStream, DataOutputStream, File, FileOutputStream, IOException} +import java.net.{InetAddress, InetSocketAddress, ServerSocket, Socket} import java.util.concurrent.TimeUnit import io.netty.bootstrap.ServerBootstrap @@ -32,6 +32,8 @@ import io.netty.handler.timeout.ReadTimeoutHandler import org.apache.spark.SparkConf import org.apache.spark.internal.Logging +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils /** * Netty-based backend server that is used to communicate between R and Java. @@ -45,7 +47,7 @@ private[spark] class RBackend { /** Tracks JVM objects returned to R for this RBackend instance. */ private[r] val jvmObjectTracker = new JVMObjectTracker - def init(): Int = { + def init(): (Int, RAuthHelper) = { val conf = new SparkConf() val backendConnectionTimeout = conf.getInt( "spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT) @@ -53,6 +55,7 @@ private[spark] class RBackend { conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS)) val workerGroup = bossGroup val handler = new RBackendHandler(this) + val authHelper = new RAuthHelper(conf) bootstrap = new ServerBootstrap() .group(bossGroup, workerGroup) @@ -71,13 +74,16 @@ private[spark] class RBackend { new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4)) .addLast("decoder", new ByteArrayDecoder()) .addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout)) + .addLast(new RBackendAuthHandler(authHelper.secret)) .addLast("handler", handler) } }) channelFuture = bootstrap.bind(new InetSocketAddress("localhost", 0)) channelFuture.syncUninterruptibly() - channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + + val port = channelFuture.channel().localAddress().asInstanceOf[InetSocketAddress].getPort() + (port, authHelper) } def run(): Unit = { @@ -116,7 +122,7 @@ private[spark] object RBackend extends Logging { val sparkRBackend = new RBackend() try { // bind to random port - val boundPort = sparkRBackend.init() + val (boundPort, authHelper) = sparkRBackend.init() val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost")) val listenPort = serverSocket.getLocalPort() // Connection timeout is set by socket client. To make it configurable we will pass the @@ -133,6 +139,7 @@ private[spark] object RBackend extends Logging { dos.writeInt(listenPort) SerDe.writeString(dos, RUtils.rPackages.getOrElse("")) dos.writeInt(backendConnectionTimeout) + SerDe.writeString(dos, authHelper.secret) dos.close() f.renameTo(new File(path)) @@ -144,12 +151,35 @@ private[spark] object RBackend extends Logging { val buf = new Array[Byte](1024) // shutdown JVM if R does not connect back in 10 seconds serverSocket.setSoTimeout(10000) + + // Wait for the R process to connect back, ignoring any failed auth attempts. Allow + // a max number of connection attempts to avoid looping forever. try { - val inSocket = serverSocket.accept() + var remainingAttempts = 10 + var inSocket: Socket = null + while (inSocket == null) { + inSocket = serverSocket.accept() + try { + authHelper.authClient(inSocket) + } catch { + case e: Exception => + remainingAttempts -= 1 + if (remainingAttempts == 0) { + val msg = "Too many failed authentication attempts." + logError(msg) + throw new IllegalStateException(msg) + } + logInfo("Client connection failed authentication.") + inSocket = null + } + } + serverSocket.close() + // wait for the end of socket, closed if R process die inSocket.getInputStream().read(buf) } finally { + serverSocket.close() sparkRBackend.close() System.exit(0) } @@ -165,4 +195,5 @@ private[spark] object RBackend extends Logging { } System.exit(0) } + } diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala new file mode 100644 index 0000000000000..4162e4a6c7476 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendAuthHandler.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.r + +import java.io.{ByteArrayOutputStream, DataOutputStream} +import java.nio.charset.StandardCharsets.UTF_8 + +import io.netty.channel.{Channel, ChannelHandlerContext, SimpleChannelInboundHandler} + +import org.apache.spark.internal.Logging +import org.apache.spark.util.Utils + +/** + * Authentication handler for connections from the R process. + */ +private class RBackendAuthHandler(secret: String) + extends SimpleChannelInboundHandler[Array[Byte]] with Logging { + + override def channelRead0(ctx: ChannelHandlerContext, msg: Array[Byte]): Unit = { + // The R code adds a null terminator to serialized strings, so ignore it here. + val clientSecret = new String(msg, 0, msg.length - 1, UTF_8) + try { + require(secret == clientSecret, "Auth secret mismatch.") + ctx.pipeline().remove(this) + writeReply("ok", ctx.channel()) + } catch { + case e: Exception => + logInfo("Authentication failure.", e) + writeReply("err", ctx.channel()) + ctx.close() + } + } + + private def writeReply(reply: String, chan: Channel): Unit = { + val out = new ByteArrayOutputStream() + SerDe.writeString(new DataOutputStream(out), reply) + chan.writeAndFlush(out.toByteArray()) + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index 88118392003e8..e7fdc3963945a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -74,14 +74,19 @@ private[spark] class RRunner[U]( // the socket used to send out the input of task serverSocket.setSoTimeout(10000) - val inSocket = serverSocket.accept() - startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) - - // the socket used to receive the output of task - val outSocket = serverSocket.accept() - val inputStream = new BufferedInputStream(outSocket.getInputStream) - dataStream = new DataInputStream(inputStream) - serverSocket.close() + dataStream = try { + val inSocket = serverSocket.accept() + RRunner.authHelper.authClient(inSocket) + startStdinThread(inSocket.getOutputStream(), inputIterator, partitionIndex) + + // the socket used to receive the output of task + val outSocket = serverSocket.accept() + RRunner.authHelper.authClient(outSocket) + val inputStream = new BufferedInputStream(outSocket.getInputStream) + new DataInputStream(inputStream) + } finally { + serverSocket.close() + } try { return new Iterator[U] { @@ -315,6 +320,11 @@ private[r] object RRunner { private[this] var errThread: BufferedStreamThread = _ private[this] var daemonChannel: DataOutputStream = _ + private lazy val authHelper = { + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) + new RAuthHelper(conf) + } + /** * Start a thread to print the process's stderr to ours */ @@ -349,6 +359,7 @@ private[r] object RRunner { pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString) pb.environment().put("SPARKR_SPARKFILES_ROOT_DIR", SparkFiles.getRootDirectory()) pb.environment().put("SPARKR_IS_RUNNING_ON_WORKER", "TRUE") + pb.environment().put("SPARKR_WORKER_SECRET", authHelper.secret) pb.redirectErrorStream(true) // redirect stderr into stdout val proc = pb.start() val errThread = startStdoutThread(proc) @@ -370,8 +381,12 @@ private[r] object RRunner { // the socket used to send out the input of task serverSocket.setSoTimeout(10000) val sock = serverSocket.accept() - daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - serverSocket.close() + try { + authHelper.authClient(sock) + daemonChannel = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + } finally { + serverSocket.close() + } } try { daemonChannel.writeInt(port) diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index a8f732b11f6cf..b0f214a44ea42 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy import java.io.File -import java.net.URI +import java.net.{InetAddress, URI} import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ @@ -39,6 +39,7 @@ object PythonRunner { val pyFiles = args(1) val otherArgs = args.slice(2, args.length) val sparkConf = new SparkConf() + val secret = Utils.createSecret(sparkConf) val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON) .orElse(sparkConf.get(PYSPARK_PYTHON)) .orElse(sys.env.get("PYSPARK_DRIVER_PYTHON")) @@ -51,7 +52,13 @@ object PythonRunner { // Launch a Py4J gateway server for the process to connect to; this will let it see our // Java system properties and such - val gatewayServer = new py4j.GatewayServer(null, 0) + val localhost = InetAddress.getLoopbackAddress() + val gatewayServer = new py4j.GatewayServer.GatewayServerBuilder() + .authToken(secret) + .javaPort(0) + .javaAddress(localhost) + .callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret) + .build() val thread = new Thread(new Runnable() { override def run(): Unit = Utils.logUncaughtExceptions { gatewayServer.start() @@ -82,6 +89,7 @@ object PythonRunner { // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort) + env.put("PYSPARK_GATEWAY_SECRET", secret) // pass conf spark.pyspark.python to python process, the only way to pass info to // python process is through environment variable. sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _)) diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index 6eb53a8252205..e86b362639e57 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -68,10 +68,13 @@ object RRunner { // Java system properties etc. val sparkRBackend = new RBackend() @volatile var sparkRBackendPort = 0 + @volatile var sparkRBackendSecret: String = null val initialized = new Semaphore(0) val sparkRBackendThread = new Thread("SparkR backend") { override def run() { - sparkRBackendPort = sparkRBackend.init() + val (port, authHelper) = sparkRBackend.init() + sparkRBackendPort = port + sparkRBackendSecret = authHelper.secret initialized.release() sparkRBackend.run() } @@ -91,6 +94,7 @@ object RRunner { env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(",")) env.put("R_PROFILE_USER", Seq(rPackageDir(0), "SparkR", "profile", "general.R").mkString(File.separator)) + env.put("SPARKR_BACKEND_AUTH_SECRET", sparkRBackendSecret) builder.redirectErrorStream(true) // Ugly but needed for stdout and stderr to synchronize val process = builder.start() diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index da3cc04275154..db4c9f9d07e08 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -257,6 +257,11 @@ package object config { .regexConf .createOptional + private[spark] val AUTH_SECRET_BIT_LENGTH = + ConfigBuilder("spark.authenticate.secretBitLength") + .intConf + .createWithDefault(256) + private[spark] val NETWORK_AUTH_ENABLED = ConfigBuilder("spark.authenticate") .booleanConf diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 1b6bc9139f9c9..df6407b84195c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -693,6 +693,20 @@ private[spark] class TaskSchedulerImpl private[scheduler]( } } + /** + * Marks the task has completed in all TaskSetManagers for the given stage. + * + * After stage failure and retry, there may be multiple TaskSetManagers for the stage. + * If an earlier attempt of a stage completes a task, we should ensure that the later attempts + * do not also submit those same tasks. That also means that a task completion from an earlier + * attempt can lead to the entire stage getting marked as successful. + */ + private[scheduler] def markPartitionCompletedInAllTaskSets(stageId: Int, partitionId: Int) = { + taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm => + tsm.markPartitionCompleted(partitionId) + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 2f4e46c7ec8f1..d9515fb27229e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -74,6 +74,8 @@ private[spark] class TaskSetManager( val ser = env.closureSerializer.newInstance() val tasks = taskSet.tasks + private[scheduler] val partitionToIndex = tasks.zipWithIndex + .map { case (t, idx) => t.partitionId -> idx }.toMap val numTasks = tasks.length val copiesRunning = new Array[Int](numTasks) @@ -149,7 +151,7 @@ private[spark] class TaskSetManager( private[scheduler] val speculatableTasks = new HashSet[Int] // Task index, start and finish time for each task attempt (indexed by task ID) - private val taskInfos = new HashMap[Long, TaskInfo] + private[scheduler] val taskInfos = new HashMap[Long, TaskInfo] // Use a MedianHeap to record durations of successful tasks so we know when to launch // speculative tasks. This is only used when speculation is enabled, to avoid the overhead @@ -744,6 +746,9 @@ private[spark] class TaskSetManager( logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + " because task " + index + " has already completed successfully") } + // There may be multiple tasksets for this stage -- we let all of them know that the partition + // was completed. This may result in some of the tasksets getting completed. + sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not // "deserialize" the value when holding a lock to avoid blocking other threads. So we call @@ -754,6 +759,19 @@ private[spark] class TaskSetManager( maybeFinishTaskSet() } + private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = { + partitionToIndex.get(partitionId).foreach { index => + if (!successful(index)) { + tasksSuccessful += 1 + successful(index) = true + if (tasksSuccessful == numTasks) { + isZombie = true + } + maybeFinishTaskSet() + } + } + } + /** * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the * DAG Scheduler. diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala new file mode 100644 index 0000000000000..d15e7937b0523 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.security + +import java.io.{DataInputStream, DataOutputStream, InputStream} +import java.net.Socket +import java.nio.charset.StandardCharsets.UTF_8 + +import org.apache.spark.SparkConf +import org.apache.spark.network.util.JavaUtils +import org.apache.spark.util.Utils + +/** + * A class that can be used to add a simple authentication protocol to socket-based communication. + * + * The protocol is simple: an auth secret is written to the socket, and the other side checks the + * secret and writes either "ok" or "err" to the output. If authentication fails, the socket is + * not expected to be valid anymore. + * + * There's no secrecy, so this relies on the sockets being either local or somehow encrypted. + */ +private[spark] class SocketAuthHelper(conf: SparkConf) { + + val secret = Utils.createSecret(conf) + + /** + * Read the auth secret from the socket and compare to the expected value. Write the reply back + * to the socket. + * + * If authentication fails, this method will close the socket. + * + * @param s The client socket. + * @throws IllegalArgumentException If authentication fails. + */ + def authClient(s: Socket): Unit = { + // Set the socket timeout while checking the auth secret. Reset it before returning. + val currentTimeout = s.getSoTimeout() + try { + s.setSoTimeout(10000) + val clientSecret = readUtf8(s) + if (secret == clientSecret) { + writeUtf8("ok", s) + } else { + writeUtf8("err", s) + JavaUtils.closeQuietly(s) + } + } finally { + s.setSoTimeout(currentTimeout) + } + } + + /** + * Authenticate with a server by writing the auth secret and checking the server's reply. + * + * If authentication fails, this method will close the socket. + * + * @param s The socket connected to the server. + * @throws IllegalArgumentException If authentication fails. + */ + def authToServer(s: Socket): Unit = { + writeUtf8(secret, s) + + val reply = readUtf8(s) + if (reply != "ok") { + JavaUtils.closeQuietly(s) + throw new IllegalArgumentException("Authentication failed.") + } + } + + protected def readUtf8(s: Socket): String = { + val din = new DataInputStream(s.getInputStream()) + val len = din.readInt() + val bytes = new Array[Byte](len) + din.readFully(bytes) + new String(bytes, UTF_8) + } + + protected def writeUtf8(str: String, s: Socket): Unit = { + val bytes = str.getBytes(UTF_8) + val dout = new DataOutputStream(s.getOutputStream()) + dout.writeInt(bytes.length) + dout.write(bytes, 0, bytes.length) + dout.flush() + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 603c23abb6895..5df17ccb627a3 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -482,7 +482,9 @@ class LegacyAccumulatorWrapper[R, T]( param: org.apache.spark.AccumulableParam[R, T]) extends AccumulatorV2[T, R] { private[spark] var _value = initialValue // Current value on driver - override def isZero: Boolean = _value == param.zero(initialValue) + @transient private lazy val _zero = param.zero(initialValue) + + override def isZero: Boolean = _value.asInstanceOf[AnyRef].eq(_zero.asInstanceOf[AnyRef]) override def copy(): LegacyAccumulatorWrapper[R, T] = { val acc = new LegacyAccumulatorWrapper(initialValue, param) @@ -491,7 +493,7 @@ class LegacyAccumulatorWrapper[R, T]( } override def reset(): Unit = { - _value = param.zero(initialValue) + _value = _zero } override def add(v: T): Unit = _value = param.addAccumulator(_value, v) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index a5f44bd746f14..6bcaf102d9680 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import java.io._ +import java.lang.{Byte => JByte} import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} import java.math.{MathContext, RoundingMode} import java.net._ @@ -25,6 +26,7 @@ import java.nio.ByteBuffer import java.nio.channels.{Channels, FileChannel} import java.nio.charset.StandardCharsets import java.nio.file.{Files, Paths} +import java.security.SecureRandom import java.util.{Locale, Properties, Random, UUID} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean @@ -43,6 +45,7 @@ import scala.util.matching.Regex import _root_.io.netty.channel.unix.Errors.NativeIoException import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} +import com.google.common.hash.HashCodes import com.google.common.io.{ByteStreams, Files => GFiles} import com.google.common.net.InetAddresses import org.apache.commons.lang3.SystemUtils @@ -2658,6 +2661,14 @@ private[spark] object Utils extends Logging { redact(redactionPattern, kvs.toArray) } + def createSecret(conf: SparkConf): String = { + val bits = conf.get(AUTH_SECRET_BIT_LENGTH) + val rnd = new SecureRandom() + val secretBytes = new Array[Byte](bits / JByte.SIZE) + rnd.nextBytes(secretBytes) + HashCodes.fromBytes(secretBytes).toString() + } + } private[util] object CallerContext extends Logging { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index 8b9d45f734cda..38a4f4087873a 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -910,4 +910,108 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B taskScheduler.initialize(new FakeSchedulerBackend) } } + + test("Completions in zombie tasksets update status of non-zombie taskset") { + val taskScheduler = setupSchedulerWithMockTaskSetBlacklist() + val valueSer = SparkEnv.get.serializer.newInstance() + + def completeTaskSuccessfully(tsm: TaskSetManager, partition: Int): Unit = { + val indexInTsm = tsm.partitionToIndex(partition) + val matchingTaskInfo = tsm.taskAttempts.flatten.filter(_.index == indexInTsm).head + val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) + tsm.handleSuccessfulTask(matchingTaskInfo.taskId, result) + } + + // Submit a task set, have it fail with a fetch failed, and then re-submit the task attempt, + // two times, so we have three active task sets for one stage. (For this to really happen, + // you'd need the previous stage to also get restarted, and then succeed, in between each + // attempt, but that happens outside what we're mocking here.) + val zombieAttempts = (0 until 2).map { stageAttempt => + val attempt = FakeTask.createTaskSet(10, stageAttemptId = stageAttempt) + taskScheduler.submitTasks(attempt) + val tsm = taskScheduler.taskSetManagerForAttempt(0, stageAttempt).get + val offers = (0 until 10).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } + taskScheduler.resourceOffers(offers) + assert(tsm.runningTasks === 10) + // fail attempt + tsm.handleFailedTask(tsm.taskAttempts.head.head.taskId, TaskState.FAILED, + FetchFailed(null, 0, 0, 0, "fetch failed")) + // the attempt is a zombie, but the tasks are still running (this could be true even if + // we actively killed those tasks, as killing is best-effort) + assert(tsm.isZombie) + assert(tsm.runningTasks === 9) + tsm + } + + // we've now got 2 zombie attempts, each with 9 tasks still active. Submit the 3rd attempt for + // the stage, but this time with insufficient resources so not all tasks are active. + + val finalAttempt = FakeTask.createTaskSet(10, stageAttemptId = 2) + taskScheduler.submitTasks(finalAttempt) + val finalTsm = taskScheduler.taskSetManagerForAttempt(0, 2).get + val offers = (0 until 5).map{ idx => WorkerOffer(s"exec-$idx", s"host-$idx", 1) } + val finalAttemptLaunchedPartitions = taskScheduler.resourceOffers(offers).flatten.map { task => + finalAttempt.tasks(task.index).partitionId + }.toSet + assert(finalTsm.runningTasks === 5) + assert(!finalTsm.isZombie) + + // We simulate late completions from our zombie tasksets, corresponding to all the pending + // partitions in our final attempt. This means we're only waiting on the tasks we've already + // launched. + val finalAttemptPendingPartitions = (0 until 10).toSet.diff(finalAttemptLaunchedPartitions) + finalAttemptPendingPartitions.foreach { partition => + completeTaskSuccessfully(zombieAttempts(0), partition) + } + + // If there is another resource offer, we shouldn't run anything. Though our final attempt + // used to have pending tasks, now those tasks have been completed by zombie attempts. The + // remaining tasks to compute are already active in the non-zombie attempt. + assert( + taskScheduler.resourceOffers(IndexedSeq(WorkerOffer("exec-1", "host-1", 1))).flatten.isEmpty) + + val remainingTasks = finalAttemptLaunchedPartitions.toIndexedSeq.sorted + + // finally, if we finish the remaining partitions from a mix of tasksets, all attempts should be + // marked as zombie. + // for each of the remaining tasks, find the tasksets with an active copy of the task, and + // finish the task. + remainingTasks.foreach { partition => + val tsm = if (partition == 0) { + // we failed this task on both zombie attempts, this one is only present in the latest + // taskset + finalTsm + } else { + // should be active in every taskset. We choose a zombie taskset just to make sure that + // we transition the active taskset correctly even if the final completion comes + // from a zombie. + zombieAttempts(partition % 2) + } + completeTaskSuccessfully(tsm, partition) + } + + assert(finalTsm.isZombie) + + // no taskset has completed all of its tasks, so no updates to the blacklist tracker yet + verify(blacklist, never).updateBlacklistForSuccessfulTaskSet(anyInt(), anyInt(), anyObject()) + + // finally, lets complete all the tasks. We simulate failures in attempt 1, but everything + // else succeeds, to make sure we get the right updates to the blacklist in all cases. + (zombieAttempts ++ Seq(finalTsm)).foreach { tsm => + val stageAttempt = tsm.taskSet.stageAttemptId + tsm.runningTasksSet.foreach { index => + if (stageAttempt == 1) { + tsm.handleFailedTask(tsm.taskInfos(index).taskId, TaskState.FAILED, TaskResultLost) + } else { + val result = new DirectTaskResult[Int](valueSer.serialize(1), Seq()) + tsm.handleSuccessfulTask(tsm.taskInfos(index).taskId, result) + } + } + + // we update the blacklist for the stage attempts with all successful tasks. Even though + // some tasksets had failures, we still consider them all successful from a blacklisting + // perspective, as the failures weren't from a problem w/ the tasks themselves. + verify(blacklist).updateBlacklistForSuccessfulTaskSet(meq(0), meq(stageAttempt), anyObject()) + } + } } diff --git a/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala new file mode 100644 index 0000000000000..e57cb701b6284 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.security + +import java.io.Closeable +import java.net._ + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.internal.config._ +import org.apache.spark.util.Utils + +class SocketAuthHelperSuite extends SparkFunSuite { + + private val conf = new SparkConf() + private val authHelper = new SocketAuthHelper(conf) + + test("successful auth") { + Utils.tryWithResource(new ServerThread()) { server => + Utils.tryWithResource(server.createClient()) { client => + authHelper.authToServer(client) + server.close() + server.join() + assert(server.error == null) + assert(server.authenticated) + } + } + } + + test("failed auth") { + Utils.tryWithResource(new ServerThread()) { server => + Utils.tryWithResource(server.createClient()) { client => + val badHelper = new SocketAuthHelper(new SparkConf().set(AUTH_SECRET_BIT_LENGTH, 128)) + intercept[IllegalArgumentException] { + badHelper.authToServer(client) + } + server.close() + server.join() + assert(server.error != null) + assert(!server.authenticated) + } + } + } + + private class ServerThread extends Thread with Closeable { + + private val ss = new ServerSocket() + ss.bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0)) + + @volatile var error: Exception = _ + @volatile var authenticated = false + + setDaemon(true) + start() + + def createClient(): Socket = { + new Socket(InetAddress.getLoopbackAddress(), ss.getLocalPort()) + } + + override def run(): Unit = { + var clientConn: Socket = null + try { + clientConn = ss.accept() + authHelper.authClient(clientConn) + authenticated = true + } catch { + case e: Exception => + error = e + } finally { + Option(clientConn).foreach(_.close()) + } + } + + override def close(): Unit = { + try { + ss.close() + } finally { + interrupt() + } + } + + } + +} diff --git a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala index a04644d57ed88..fe0a9a471a651 100644 --- a/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala +++ b/core/src/test/scala/org/apache/spark/util/AccumulatorV2Suite.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import org.apache.spark._ +import org.apache.spark.serializer.JavaSerializer class AccumulatorV2Suite extends SparkFunSuite { @@ -162,4 +163,22 @@ class AccumulatorV2Suite extends SparkFunSuite { assert(acc3.isZero) assert(acc3.value === "") } + + test("LegacyAccumulatorWrapper with AccumulatorParam that has no equals/hashCode") { + class MyData(val i: Int) extends Serializable + val param = new AccumulatorParam[MyData] { + override def zero(initialValue: MyData): MyData = new MyData(0) + override def addInPlace(r1: MyData, r2: MyData): MyData = new MyData(r1.i + r2.i) + } + + val acc = new LegacyAccumulatorWrapper(new MyData(0), param) + acc.metadata = AccumulatorMetadata( + AccumulatorContext.newId(), + Some("test"), + countFailedValues = false) + AccumulatorContext.register(acc) + + val ser = new JavaSerializer(new SparkConf).newInstance() + ser.serialize(acc) + } } diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 3ed8a0ec0579e..37feb62a9dc28 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -156,7 +156,7 @@ parquet-jackson-1.8.2.jar pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar -py4j-0.10.4.jar +py4j-0.10.7.jar pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index a1137385c0064..d90fb0c762bfd 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -157,7 +157,7 @@ parquet-jackson-1.8.2.jar pmml-model-1.2.15.jar pmml-schema-1.2.15.jar protobuf-java-2.5.0.jar -py4j-0.10.4.jar +py4j-0.10.7.jar pyrolite-4.13.jar scala-compiler-2.11.8.jar scala-library-2.11.8.jar diff --git a/dev/run-pip-tests b/dev/run-pip-tests index d51dde12a03c5..03fc83298dc2f 100755 --- a/dev/run-pip-tests +++ b/dev/run-pip-tests @@ -89,7 +89,7 @@ for python in "${PYTHON_EXECS[@]}"; do source "$VIRTUALENV_PATH"/bin/activate fi # Upgrade pip & friends if using virutal env - if [ ! -n "USE_CONDA" ]; then + if [ ! -n "$USE_CONDA" ]; then pip install --upgrade pip pypandoc wheel numpy fi diff --git a/python/README.md b/python/README.md index 0a5c8010b8486..d1f15681d7ebc 100644 --- a/python/README.md +++ b/python/README.md @@ -29,4 +29,4 @@ The Python packaging for Spark is not intended to replace all of the other use c ## Python Requirements -At its core PySpark depends on Py4J (currently version 0.10.4), but additional sub-packages have their own requirements (including numpy and pandas). \ No newline at end of file +At its core PySpark depends on Py4J (currently version 0.10.7), but additional sub-packages have their own requirements (including numpy and pandas). diff --git a/python/docs/Makefile b/python/docs/Makefile index 5e4cfb8ab6fe3..b8e079483c90c 100644 --- a/python/docs/Makefile +++ b/python/docs/Makefile @@ -7,7 +7,7 @@ SPHINXBUILD ?= sphinx-build PAPER ?= BUILDDIR ?= _build -export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.4-src.zip) +export PYTHONPATH=$(realpath ..):$(realpath ../lib/py4j-0.10.7-src.zip) # User-friendly check for sphinx-build ifeq ($(shell which $(SPHINXBUILD) >/dev/null 2>&1; echo $$?), 1) diff --git a/python/lib/py4j-0.10.4-src.zip b/python/lib/py4j-0.10.4-src.zip deleted file mode 100644 index 8c3829e328726..0000000000000 Binary files a/python/lib/py4j-0.10.4-src.zip and /dev/null differ diff --git a/python/lib/py4j-0.10.7-src.zip b/python/lib/py4j-0.10.7-src.zip new file mode 100644 index 0000000000000..128e321078793 Binary files /dev/null and b/python/lib/py4j-0.10.7-src.zip differ diff --git a/python/pyspark/context.py b/python/pyspark/context.py index ea58b3a93899e..47ba56b2cb108 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -989,8 +989,8 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. mappedRDD = rdd.mapPartitions(partitionFunc) - port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions) - return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) + sock_info = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, partitions) + return list(_load_from_socket(sock_info, mappedRDD._jrdd_deserializer)) def show_profiles(self): """ Print the profile stats to stdout """ diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 7f06d4288c872..e7d1e718c934a 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -29,7 +29,7 @@ from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT from pyspark.worker import main as worker_main -from pyspark.serializers import read_int, write_int +from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer def compute_real_exit_code(exit_code): @@ -40,7 +40,7 @@ def compute_real_exit_code(exit_code): return 1 -def worker(sock): +def worker(sock, authenticated): """ Called by a worker process after the fork(). """ @@ -56,6 +56,18 @@ def worker(sock): # otherwise writes also cause a seek that makes us miss data on the read side. infile = os.fdopen(os.dup(sock.fileno()), "rb", 65536) outfile = os.fdopen(os.dup(sock.fileno()), "wb", 65536) + + if not authenticated: + client_secret = UTF8Deserializer().loads(infile) + if os.environ["PYTHON_WORKER_FACTORY_SECRET"] == client_secret: + write_with_length("ok".encode("utf-8"), outfile) + outfile.flush() + else: + write_with_length("err".encode("utf-8"), outfile) + outfile.flush() + sock.close() + return 1 + exit_code = 0 try: worker_main(infile, outfile) @@ -153,8 +165,11 @@ def handle_sigterm(*args): write_int(os.getpid(), outfile) outfile.flush() outfile.close() + authenticated = False while True: - code = worker(sock) + code = worker(sock, authenticated) + if code == 0: + authenticated = True if not reuse or code: # wait for closing try: diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3c783ae541a1f..7abf2c1c25e72 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -21,16 +21,19 @@ import select import signal import shlex +import shutil import socket import platform +import tempfile +import time from subprocess import Popen, PIPE if sys.version >= '3': xrange = range -from py4j.java_gateway import java_import, JavaGateway, GatewayClient +from py4j.java_gateway import java_import, JavaGateway, GatewayParameters from pyspark.find_spark_home import _find_spark_home -from pyspark.serializers import read_int +from pyspark.serializers import read_int, write_with_length, UTF8Deserializer def launch_gateway(conf=None): @@ -41,6 +44,7 @@ def launch_gateway(conf=None): """ if "PYSPARK_GATEWAY_PORT" in os.environ: gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"]) + gateway_secret = os.environ["PYSPARK_GATEWAY_SECRET"] else: SPARK_HOME = _find_spark_home() # Launch the Py4j gateway using Spark's run command so that we pick up the @@ -59,40 +63,40 @@ def launch_gateway(conf=None): ]) command = command + shlex.split(submit_args) - # Start a socket that will be used by PythonGatewayServer to communicate its port to us - callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - callback_socket.bind(('127.0.0.1', 0)) - callback_socket.listen(1) - callback_host, callback_port = callback_socket.getsockname() - env = dict(os.environ) - env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host - env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port) - - # Launch the Java gateway. - # We open a pipe to stdin so that the Java gateway can die when the pipe is broken - if not on_windows: - # Don't send ctrl-c / SIGINT to the Java gateway: - def preexec_func(): - signal.signal(signal.SIGINT, signal.SIG_IGN) - proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) - else: - # preexec_fn not supported on Windows - proc = Popen(command, stdin=PIPE, env=env) - - gateway_port = None - # We use select() here in order to avoid blocking indefinitely if the subprocess dies - # before connecting - while gateway_port is None and proc.poll() is None: - timeout = 1 # (seconds) - readable, _, _ = select.select([callback_socket], [], [], timeout) - if callback_socket in readable: - gateway_connection = callback_socket.accept()[0] - # Determine which ephemeral port the server started on: - gateway_port = read_int(gateway_connection.makefile(mode="rb")) - gateway_connection.close() - callback_socket.close() - if gateway_port is None: - raise Exception("Java gateway process exited before sending the driver its port number") + # Create a temporary directory where the gateway server should write the connection + # information. + conn_info_dir = tempfile.mkdtemp() + try: + fd, conn_info_file = tempfile.mkstemp(dir=conn_info_dir) + os.close(fd) + os.unlink(conn_info_file) + + env = dict(os.environ) + env["_PYSPARK_DRIVER_CONN_INFO_PATH"] = conn_info_file + + # Launch the Java gateway. + # We open a pipe to stdin so that the Java gateway can die when the pipe is broken + if not on_windows: + # Don't send ctrl-c / SIGINT to the Java gateway: + def preexec_func(): + signal.signal(signal.SIGINT, signal.SIG_IGN) + proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env) + else: + # preexec_fn not supported on Windows + proc = Popen(command, stdin=PIPE, env=env) + + # Wait for the file to appear, or for the process to exit, whichever happens first. + while not proc.poll() and not os.path.isfile(conn_info_file): + time.sleep(0.1) + + if not os.path.isfile(conn_info_file): + raise Exception("Java gateway process exited before sending its port number") + + with open(conn_info_file, "rb") as info: + gateway_port = read_int(info) + gateway_secret = UTF8Deserializer().loads(info) + finally: + shutil.rmtree(conn_info_dir) # In Windows, ensure the Java child processes do not linger after Python has exited. # In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when @@ -111,7 +115,9 @@ def killChild(): atexit.register(killChild) # Connect to the gateway - gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=True) + gateway = JavaGateway( + gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret, + auto_convert=True)) # Import the classes used by PySpark java_import(gateway.jvm, "org.apache.spark.SparkConf") @@ -125,3 +131,16 @@ def killChild(): java_import(gateway.jvm, "scala.Tuple2") return gateway + + +def do_server_auth(conn, auth_secret): + """ + Performs the authentication protocol defined by the SocketAuthHelper class on the given + file-like object 'conn'. + """ + write_with_length(auth_secret.encode("utf-8"), conn) + conn.flush() + reply = UTF8Deserializer().loads(conn) + if reply != "ok": + conn.close() + raise Exception("Unexpected reply from iterator server.") diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index aca00bc3b26fc..864cebb9517d9 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -39,9 +39,11 @@ else: from itertools import imap as map, ifilter as filter +from pyspark.java_gateway import do_server_auth from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ - PickleSerializer, pack_long, AutoBatchedSerializer + PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \ + UTF8Deserializer from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_full_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -119,7 +121,8 @@ def _parse_memory(s): return int(float(s[:-1]) * units[s[-1].lower()]) -def _load_from_socket(port, serializer): +def _load_from_socket(sock_info, serializer): + port, auth_secret = sock_info sock = None # Support for both IPv4 and IPv6. # On most of IPv6-ready systems, IPv6 will take precedence. @@ -139,8 +142,12 @@ def _load_from_socket(port, serializer): # The RDD materialization time is unpredicable, if we set a timeout for socket reading # operation, it will very possibly fail. See SPARK-18281. sock.settimeout(None) + + sockfile = sock.makefile("rwb", 65536) + do_server_auth(sockfile, auth_secret) + # The socket will be automatically closed when garbage-collected. - return serializer.load_stream(sock.makefile("rb", 65536)) + return serializer.load_stream(sockfile) def ignore_unicode_prefix(f): @@ -806,8 +813,8 @@ def collect(self): to be small, as all the data is loaded into the driver's memory. """ with SCCallSiteSync(self.context) as css: - port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) - return list(_load_from_socket(port, self._jrdd_deserializer)) + sock_info = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) + return list(_load_from_socket(sock_info, self._jrdd_deserializer)) def reduce(self, f): """ @@ -2364,8 +2371,8 @@ def toLocalIterator(self): [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] """ with SCCallSiteSync(self.context) as css: - port = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) - return _load_from_socket(port, self._jrdd_deserializer) + sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd()) + return _load_from_socket(sock_info, self._jrdd_deserializer) def _prepare_for_python_RDD(sc, command): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index d1b336df40f91..1747d3635fe4a 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -435,8 +435,8 @@ def collect(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.collectToPython() - return list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) + sock_info = self._jdf.collectToPython() + return list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))) @ignore_unicode_prefix @since(2.0) @@ -449,8 +449,8 @@ def toLocalIterator(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - port = self._jdf.toPythonIterator() - return _load_from_socket(port, BatchedSerializer(PickleSerializer())) + sock_info = self._jdf.toPythonIterator() + return _load_from_socket(sock_info, BatchedSerializer(PickleSerializer())) @ignore_unicode_prefix @since(1.3) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index baaa3fe074e9a..0c8996e21ee6c 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -27,6 +27,7 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry +from pyspark.java_gateway import do_server_auth from pyspark.taskcontext import TaskContext from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ @@ -208,9 +209,11 @@ def process(): if __name__ == '__main__': - # Read a local port to connect to from stdin - java_port = int(sys.stdin.readline()) + # Read information about how to connect back to the JVM from the environment. + java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(("127.0.0.1", java_port)) sock_file = sock.makefile("rwb", 65536) + do_server_auth(sock_file, auth_secret) main(sock_file, sock_file) diff --git a/python/setup.py b/python/setup.py index 7e63461d289b2..fa8dc12c92caf 100644 --- a/python/setup.py +++ b/python/setup.py @@ -196,7 +196,7 @@ def _supports_symlinks(): 'pyspark.examples.src.main.python': ['*.py', '*/*.py']}, scripts=scripts, license='http://www.apache.org/licenses/LICENSE-2.0', - install_requires=['py4j==0.10.4'], + install_requires=['py4j==0.10.7'], setup_requires=['pypandoc'], extras_require={ 'ml': ['numpy>=1.7'], diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 9354a3bef7ba5..41f351a5f5542 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -485,9 +485,9 @@ private[spark] class MesosClusterScheduler( .filter { case (key, _) => !replicatedOptionsBlacklist.contains(key) } .toMap (defaultConf ++ driverConf).foreach { case (key, value) => - options ++= Seq("--conf", s""""$key=${shellEscape(value)}"""".stripMargin) } + options ++= Seq("--conf", s"${key}=${value}") } - options + options.map(shellEscape) } /** diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 638ee0f86d06d..a94a8ab515459 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -1139,7 +1139,7 @@ private[spark] class Client( val pyArchivesFile = new File(pyLibPath, "pyspark.zip") require(pyArchivesFile.exists(), s"$pyArchivesFile not found; cannot run pyspark application in YARN mode.") - val py4jFile = new File(pyLibPath, "py4j-0.10.4-src.zip") + val py4jFile = new File(pyLibPath, "py4j-0.10.7-src.zip") require(py4jFile.exists(), s"$py4jFile not found; cannot run pyspark application in YARN mode.") Seq(pyArchivesFile.getAbsolutePath(), py4jFile.getAbsolutePath()) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 59adb7e22d185..f18890c842ed6 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -249,7 +249,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite { // needed locations. val sparkHome = sys.props("spark.test.home") val pythonPath = Seq( - s"$sparkHome/python/lib/py4j-0.10.4-src.zip", + s"$sparkHome/python/lib/py4j-0.10.7-src.zip", s"$sparkHome/python") val extraEnvVars = Map( "PYSPARK_ARCHIVES_PATH" -> pythonPath.map("local:" + _).mkString(File.pathSeparator), diff --git a/sbin/spark-config.sh b/sbin/spark-config.sh index f2d9e6b568a9b..bf3da18c3706e 100755 --- a/sbin/spark-config.sh +++ b/sbin/spark-config.sh @@ -28,6 +28,6 @@ export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"${SPARK_HOME}/conf"}" # Add the PySpark classes to the PYTHONPATH: if [ -z "${PYSPARK_PYTHONPATH_SET}" ]; then export PYTHONPATH="${SPARK_HOME}/python:${PYTHONPATH}" - export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.4-src.zip:${PYTHONPATH}" + export PYTHONPATH="${SPARK_HOME}/python/lib/py4j-0.10.7-src.zip:${PYTHONPATH}" export PYSPARK_PYTHONPATH_SET=1 fi diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index 4c8b177237d23..371eddaa3705d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.catalyst +import java.util.Locale + import com.google.common.collect.Maps +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{StructField, StructType} @@ -131,6 +135,88 @@ package object expressions { def indexOf(exprId: ExprId): Int = { Option(exprIdToOrdinal.get(exprId)).getOrElse(-1) } + + private def unique[T](m: Map[T, Seq[Attribute]]): Map[T, Seq[Attribute]] = { + m.mapValues(_.distinct).map(identity) + } + + /** Map to use for direct case insensitive attribute lookups. */ + @transient private lazy val direct: Map[String, Seq[Attribute]] = { + unique(attrs.groupBy(_.name.toLowerCase(Locale.ROOT))) + } + + /** Map to use for qualified case insensitive attribute lookups. */ + @transient private val qualified: Map[(String, String), Seq[Attribute]] = { + val grouped = attrs.filter(_.qualifier.isDefined).groupBy { a => + (a.qualifier.get.toLowerCase(Locale.ROOT), a.name.toLowerCase(Locale.ROOT)) + } + unique(grouped) + } + + /** Perform attribute resolution given a name and a resolver. */ + def resolve(nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = { + // Collect matching attributes given a name and a lookup. + def collectMatches(name: String, candidates: Option[Seq[Attribute]]): Seq[Attribute] = { + candidates.toSeq.flatMap(_.collect { + case a if resolver(a.name, name) => a.withName(name) + }) + } + + // Find matches for the given name assuming that the 1st part is a qualifier (i.e. table name, + // alias, or subquery alias) and the 2nd part is the actual name. This returns a tuple of + // matched attributes and a list of parts that are to be resolved. + // + // For example, consider an example where "a" is the table name, "b" is the column name, + // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", + // and the second element will be List("c"). + val matches = nameParts match { + case qualifier +: name +: nestedFields => + val key = (qualifier.toLowerCase(Locale.ROOT), name.toLowerCase(Locale.ROOT)) + val attributes = collectMatches(name, qualified.get(key)).filter { a => + resolver(qualifier, a.qualifier.get) + } + (attributes, nestedFields) + case all => + (Nil, all) + } + + // If none of attributes match `table.column` pattern, we try to resolve it as a column. + val (candidates, nestedFields) = matches match { + case (Seq(), _) => + val name = nameParts.head + val attributes = collectMatches(name, direct.get(name.toLowerCase(Locale.ROOT))) + (attributes, nameParts.tail) + case _ => matches + } + + def name = UnresolvedAttribute(nameParts).name + candidates match { + case Seq(a) if nestedFields.nonEmpty => + // One match, but we also need to extract the requested nested field. + // The foldLeft adds ExtractValues for every remaining parts of the identifier, + // and aliased it with the last part of the name. + // For example, consider "a.b.c", where "a" is resolved to an existing attribute. + // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final + // expression as "c". + val fieldExprs = nestedFields.foldLeft(a: Expression) { (e, name) => + ExtractValue(e, Literal(name), resolver) + } + Some(Alias(fieldExprs, nestedFields.last)()) + + case Seq(a) => + // One match, no nested fields, use it. + Some(a) + + case Seq() => + // No matches. + None + + case ambiguousReferences => + // More than one match. + val referenceNames = ambiguousReferences.mkString(", ") + throw new AnalysisException(s"Reference '$name' is ambiguous, could be: $referenceNames.") + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 2ebb2ff323c6b..5c65bbd16d305 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -160,6 +160,10 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { } } + private[this] lazy val childAttributes = AttributeSeq(children.flatMap(_.output)) + + private[this] lazy val outputAttributes = AttributeSeq(output) + /** * Optionally resolves the given strings to a [[NamedExpression]] using the input from all child * nodes of this LogicalPlan. The attribute is expressed as @@ -168,7 +172,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { def resolveChildren( nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = - resolve(nameParts, children.flatMap(_.output), resolver) + childAttributes.resolve(nameParts, resolver) /** * Optionally resolves the given strings to a [[NamedExpression]] based on the output of this @@ -178,7 +182,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { def resolve( nameParts: Seq[String], resolver: Resolver): Option[NamedExpression] = - resolve(nameParts, output, resolver) + outputAttributes.resolve(nameParts, resolver) /** * Given an attribute name, split it to name parts by dot, but @@ -188,105 +192,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { def resolveQuoted( name: String, resolver: Resolver): Option[NamedExpression] = { - resolve(UnresolvedAttribute.parseAttributeName(name), output, resolver) - } - - /** - * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. - * - * This assumes `name` has multiple parts, where the 1st part is a qualifier - * (i.e. table name, alias, or subquery alias). - * See the comment above `candidates` variable in resolve() for semantics the returned data. - */ - private def resolveAsTableColumn( - nameParts: Seq[String], - resolver: Resolver, - attribute: Attribute): Option[(Attribute, List[String])] = { - assert(nameParts.length > 1) - if (attribute.qualifier.exists(resolver(_, nameParts.head))) { - // At least one qualifier matches. See if remaining parts match. - val remainingParts = nameParts.tail - resolveAsColumn(remainingParts, resolver, attribute) - } else { - None - } - } - - /** - * Resolve the given `name` string against the given attribute, returning either 0 or 1 match. - * - * Different from resolveAsTableColumn, this assumes `name` does NOT start with a qualifier. - * See the comment above `candidates` variable in resolve() for semantics the returned data. - */ - private def resolveAsColumn( - nameParts: Seq[String], - resolver: Resolver, - attribute: Attribute): Option[(Attribute, List[String])] = { - if (!attribute.isGenerated && resolver(attribute.name, nameParts.head)) { - Option((attribute.withName(nameParts.head), nameParts.tail.toList)) - } else { - None - } - } - - /** Performs attribute resolution given a name and a sequence of possible attributes. */ - protected def resolve( - nameParts: Seq[String], - input: Seq[Attribute], - resolver: Resolver): Option[NamedExpression] = { - - // A sequence of possible candidate matches. - // Each candidate is a tuple. The first element is a resolved attribute, followed by a list - // of parts that are to be resolved. - // For example, consider an example where "a" is the table name, "b" is the column name, - // and "c" is the struct field name, i.e. "a.b.c". In this case, Attribute will be "a.b", - // and the second element will be List("c"). - var candidates: Seq[(Attribute, List[String])] = { - // If the name has 2 or more parts, try to resolve it as `table.column` first. - if (nameParts.length > 1) { - input.flatMap { option => - resolveAsTableColumn(nameParts, resolver, option) - } - } else { - Seq.empty - } - } - - // If none of attributes match `table.column` pattern, we try to resolve it as a column. - if (candidates.isEmpty) { - candidates = input.flatMap { candidate => - resolveAsColumn(nameParts, resolver, candidate) - } - } - - def name = UnresolvedAttribute(nameParts).name - - candidates.distinct match { - // One match, no nested fields, use it. - case Seq((a, Nil)) => Some(a) - - // One match, but we also need to extract the requested nested field. - case Seq((a, nestedFields)) => - // The foldLeft adds ExtractValues for every remaining parts of the identifier, - // and aliased it with the last part of the name. - // For example, consider "a.b.c", where "a" is resolved to an existing attribute. - // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias the final - // expression as "c". - val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) => - ExtractValue(expr, Literal(fieldName), resolver)) - Some(Alias(fieldExprs, nestedFields.last)()) - - // No matches. - case Seq() => - logTrace(s"Could not find $name in ${input.mkString(", ")}") - None - - // More than one match. - case ambiguousReferences => - val referenceNames = ambiguousReferences.map(_._1).mkString(", ") - throw new AnalysisException( - s"Reference '$name' is ambiguous, could be: $referenceNames.") - } + outputAttributes.resolve(UnresolvedAttribute.parseAttributeName(name), resolver) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index fb700a489c763..4c13daa386a8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2804,7 +2804,7 @@ class Dataset[T] private[sql]( EvaluatePython.javaToPython(rdd) } - private[sql] def collectToPython(): Int = { + private[sql] def collectToPython(): Array[Any] = { EvaluatePython.registerPicklers() withNewExecutionId { val toJava: (Any) => Any = EvaluatePython.toJava(_, schema) @@ -2814,7 +2814,7 @@ class Dataset[T] private[sql]( } } - private[sql] def toPythonIterator(): Int = { + private[sql] def toPythonIterator(): Array[Any] = { withNewExecutionId { PythonRDD.toLocalIteratorAndServe(javaToPython.rdd) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index a3d5b941a6761..2b37047612dfe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -57,30 +57,31 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { for (i <- 0 until 3) { val preferredMirror = Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim - val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" + val filename = s"spark-$version-bin-hadoop2.7.tgz" + val url = s"$preferredMirror/spark/spark-$version/$filename" logInfo(s"Downloading Spark $version from $url") if (Seq("wget", url, "-q", "-P", path).! == 0) { - return + val downloaded = new File(sparkTestingDir, filename).getCanonicalPath + val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath + + Seq("mkdir", targetDir).! + val exitCode = Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! + Seq("rm", downloaded).! + + // For a corrupted file, `tar` returns non-zero values. However, we also need to check + // the extracted file because `tar` returns 0 for empty file. + val sparkSubmit = new File(sparkTestingDir, s"spark-$version/bin/spark-submit") + if (exitCode == 0 && sparkSubmit.exists()) { + return + } else { + Seq("rm", "-rf", targetDir).! + } } logWarning(s"Failed to download Spark $version from $url") } fail(s"Unable to download Spark $version") } - - private def downloadSpark(version: String): Unit = { - tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) - - val downloaded = new File(sparkTestingDir, s"spark-$version-bin-hadoop2.7.tgz").getCanonicalPath - val targetDir = new File(sparkTestingDir, s"spark-$version").getCanonicalPath - - Seq("mkdir", targetDir).! - - Seq("tar", "-xzf", downloaded, "-C", targetDir, "--strip-components=1").! - - Seq("rm", downloaded).! - } - private def genDataDir(name: String): String = { new File(tmpDataDir, name).getCanonicalPath } @@ -125,7 +126,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { PROCESS_TABLES.testingVersions.zipWithIndex.foreach { case (version, index) => val sparkHome = new File(sparkTestingDir, s"spark-$version") if (!sparkHome.exists()) { - downloadSpark(version) + tryDownloadSpark(version, sparkTestingDir.getCanonicalPath) } val args = Seq(