From 3f8209fe664836242ab849d4e8503c4443798556 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Sat, 17 Oct 2015 17:09:21 -0700 Subject: [PATCH 1/9] [SPARK-11140] [core] Transfer files using network lib when using NettyRpcEnv. This change abstracts the code that serves jars / files to executors so that each RpcEnv can have its own implementation; the akka version uses the existing HTTP-based file serving mechanism, while the netty versions uses the new stream support added to the network lib, which makes file transfers benefit from the easier security configuration of the network library, and should also reduce overhead overall. The change includes a small fix to TransportChannelHandler so that it propagates user events to downstream handlers. --- .../scala/org/apache/spark/SparkContext.scala | 8 +- .../scala/org/apache/spark/SparkEnv.scala | 14 -- .../scala/org/apache/spark/rpc/RpcEnv.scala | 46 +++++ .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 59 +++++- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 189 ++++++++++++++++-- .../spark/rpc/netty/NettyStreamManager.scala | 63 ++++++ .../scala/org/apache/spark/util/Utils.scala | 6 + .../org/apache/spark/rpc/RpcEnvSuite.scala | 40 +++- .../rpc/netty/NettyRpcHandlerSuite.scala | 10 +- .../launcher/AbstractCommandBuilder.java | 2 +- .../client/TransportClientFactory.java | 15 ++ .../server/TransportChannelHandler.java | 1 + 12 files changed, 410 insertions(+), 43 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 7421821e2601..70e4421c4291 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1362,7 +1362,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } val key = if (!isLocal && scheme == "file") { - env.httpFileServer.addFile(new File(uri.getPath)) + env.rpcEnv.fileServer.addFile(new File(uri.getPath)) } else { schemeCorrectedPath } @@ -1613,7 +1613,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli var key = "" if (path.contains("\\")) { // For local paths with backslashes on Windows, URI throws an exception - key = env.httpFileServer.addJar(new File(path)) + key = env.rpcEnv.fileServer.addJar(new File(path)) } else { val uri = new URI(path) key = uri.getScheme match { @@ -1627,7 +1627,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // of the AM to make it show up in the current working directory. val fileName = new Path(uri.getPath).getName() try { - env.httpFileServer.addJar(new File(fileName)) + env.rpcEnv.fileServer.addJar(new File(fileName)) } catch { case e: Exception => // For now just log an error but allow to go through so spark examples work. 
@@ -1638,7 +1638,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } else { try { - env.httpFileServer.addJar(new File(uri.getPath)) + env.rpcEnv.fileServer.addJar(new File(uri.getPath)) } catch { case exc: FileNotFoundException => logError(s"Jar not found at $path") diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 23ae9360f6a2..1649a0ca07c5 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -66,7 +66,6 @@ class SparkEnv ( val blockTransferService: BlockTransferService, val blockManager: BlockManager, val securityManager: SecurityManager, - val httpFileServer: HttpFileServer, val sparkFilesDir: String, val metricsSystem: MetricsSystem, val memoryManager: MemoryManager, @@ -91,7 +90,6 @@ class SparkEnv ( if (!isStopped) { isStopped = true pythonWorkers.values.foreach(_.stop()) - Option(httpFileServer).foreach(_.stop()) mapOutputTracker.stop() shuffleManager.stop() broadcastManager.stop() @@ -360,17 +358,6 @@ object SparkEnv extends Logging { val cacheManager = new CacheManager(blockManager) - val httpFileServer = - if (isDriver) { - val fileServerPort = conf.getInt("spark.fileserver.port", 0) - val server = new HttpFileServer(conf, securityManager, fileServerPort) - server.initialize() - conf.set("spark.fileserver.uri", server.serverUri) - server - } else { - null - } - val metricsSystem = if (isDriver) { // Don't start metrics system right now for Driver. // We need to wait for the task scheduler to give us an app ID. @@ -415,7 +402,6 @@ object SparkEnv extends Logging { blockTransferService, blockManager, securityManager, - httpFileServer, sparkFilesDir, metricsSystem, memoryManager, diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index a560fd10cdf7..54d6ed7201ef 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -17,6 +17,8 @@ package org.apache.spark.rpc +import java.io.File + import scala.concurrent.Future import org.apache.spark.{SecurityManager, SparkConf} @@ -132,8 +134,52 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. */ def deserialize[T](deserializationAction: () => T): T + + /** + * Return the instance of the file server used to serve files. This may be `null` if the + * RpcEnv is not operating in server mode. + */ + def fileServer: RpcEnvFileServer + + /** + * Fetch a file from the given URI. If the URIs returned by the RpcEnvFileServer use the "spark" + * scheme, this method will be called to retrieve the files. + * + * @param uri URI with location of the file. + * @param dest Local destination of file. + * @param overwrite Whether to overwrite the target file if it exists. + */ + def fetchFile(uri: String, dest: File, overwrite: Boolean): Unit + } +/** + * A server used by the RpcEnv to server files to other processes owned by the application. + * + * The file server can return URIs handled by common libraries (such as "http" or "hdfs"), or + * it can return "spark" URIs which will be handled by `RpcEnv#fetchFile`. + */ +private[spark] trait RpcEnvFileServer { + + /** + * Adds a file to be served by this RpcEnv. This is used to serve files from the driver + * to executors when they're stored on the driver's local file system. 
+ * + * @param file Local file to serve. + * @return A URI for the location of the file. + */ + def addFile(file: File): String + + /** + * Adds a jar to be served by this RpcEnv. Similar to `addFile` but for jars added using + * `SparkContext.addJar`. + * + * @param file Local file to serve. + * @return A URI for the location of the file. + */ + def addJar(file: File): String + +} private[spark] case class RpcEnvConfig( conf: SparkConf, diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 3fad595a0d0b..3552ea5aedc8 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -17,6 +17,7 @@ package org.apache.spark.rpc.akka +import java.io.File import java.util.concurrent.ConcurrentHashMap import scala.concurrent.Future @@ -30,7 +31,7 @@ import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} import akka.serialization.JavaSerializer -import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.{HttpFileServer, Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.rpc._ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} @@ -41,7 +42,10 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} * remove Akka from the dependencies. */ private[spark] class AkkaRpcEnv private[akka] ( - val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) + val actorSystem: ActorSystem, + val securityManager: SecurityManager, + conf: SparkConf, + boundPort: Int) extends RpcEnv(conf) with Logging { private val defaultAddress: RpcAddress = { @@ -64,6 +68,8 @@ private[spark] class AkkaRpcEnv private[akka] ( */ private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() + private val _fileServer = new AkkaFileServer(conf, securityManager) + private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { endpointToRef.put(endpoint, endpointRef) refToEndpoint.put(endpointRef, endpoint) @@ -223,6 +229,7 @@ private[spark] class AkkaRpcEnv private[akka] ( override def shutdown(): Unit = { actorSystem.shutdown() + _fileServer.shutdown() } override def stop(endpoint: RpcEndpointRef): Unit = { @@ -241,6 +248,52 @@ private[spark] class AkkaRpcEnv private[akka] ( deserializationAction() } } + + override def fetchFile(uri: String, dest: File, overwrite: Boolean): Unit = { + throw new UnsupportedOperationException( + "AkkaRpcEnv's files should be retrieved using an HTTP client.") + } + + override def fileServer: RpcEnvFileServer = _fileServer + +} + +private[akka] class AkkaFileServer( + conf: SparkConf, + securityManager: SecurityManager) extends RpcEnvFileServer { + + @volatile private var httpFileServer: HttpFileServer = _ + + override def addFile(file: File): String = { + getFileServer().addFile(file) + } + + override def addJar(file: File): String = { + getFileServer().addJar(file) + } + + def shutdown(): Unit = { + if (httpFileServer != null) { + httpFileServer.stop() + } + } + + private def getFileServer(): HttpFileServer = { + if (httpFileServer == null) synchronized { + if (httpFileServer == null) { + httpFileServer = startFileServer() + } + } + httpFileServer + } + + private def startFileServer(): HttpFileServer = { + val fileServerPort = conf.getInt("spark.fileserver.port", 0) + val server = new HttpFileServer(conf, 
securityManager, fileServerPort) + server.initialize() + server + } + } private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { @@ -249,7 +302,7 @@ private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { val (actorSystem, boundPort) = AkkaUtils.createActorSystem( config.name, config.host, config.port, config.conf, config.securityManager) actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") - new AkkaRpcEnv(actorSystem, config.conf, boundPort) + new AkkaRpcEnv(actorSystem, config.securityManager, config.conf, boundPort) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 09093819bb22..63a87a9e804e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -20,10 +20,11 @@ import java.io._ import java.lang.{Boolean => JBoolean} import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer +import java.nio.channels.FileChannel +import java.nio.file.StandardOpenOption import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean -import javax.annotation.Nullable; -import javax.annotation.concurrent.GuardedBy +import javax.annotation.Nullable import scala.collection.mutable import scala.concurrent.{Future, Promise} @@ -31,7 +32,9 @@ import scala.reflect.ClassTag import scala.util.{DynamicVariable, Failure, Success} import scala.util.control.NonFatal -import com.google.common.base.Preconditions +import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter} +import io.netty.handler.timeout.IdleStateEvent + import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.network.TransportContext import org.apache.spark.network.client._ @@ -48,14 +51,16 @@ private[netty] class NettyRpcEnv( host: String, securityManager: SecurityManager) extends RpcEnv(conf) with Logging { - private val transportConf = SparkTransportConf.fromSparkConf( + private[netty] val transportConf = SparkTransportConf.fromSparkConf( conf.clone.set("spark.shuffle.io.numConnectionsPerPeer", "1"), conf.getInt("spark.rpc.io.threads", 0)) private val dispatcher: Dispatcher = new Dispatcher(this) + private val streamManager = new NettyStreamManager(this) + private val transportContext = new TransportContext(transportConf, - new NettyRpcHandler(dispatcher, this)) + new NettyRpcHandler(dispatcher, this, streamManager)) private val clientFactory = { val bootstraps: java.util.List[TransportClientBootstrap] = @@ -70,6 +75,21 @@ private[netty] class NettyRpcEnv( val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") + /** + * A collection of cached clients used for downloading files, to avoid the cost of connection + * establishment when multiple files are downloaded in a short period of time. The clients are + * closed after an inactivity timeout. + * + * Clients are cached for each remote host, although in most cases there will be a single remote + * host serving files. + * + * These TransportClient instances are not used for RPC, so they're never registered with the + * NettyRpcHandler instance. The outcome is that connection / disconnection events are not + * sent for these clients, which is desirable since otherwise other parts of the code might + * think the driver (for example) is disconnecting when that's not the case. 
+ */ + private val fileClients = new mutable.HashMap[RpcAddress, TransportClient]() + // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool // to implement non-blocking send/ask. // TODO: a non-blocking TransportClientFactory.createClient in future @@ -294,6 +314,10 @@ private[netty] class NettyRpcEnv( if (clientConnectionExecutor != null) { clientConnectionExecutor.shutdownNow() } + synchronized { + fileClients.values.foreach(_.close()) + fileClients.clear() + } } override def deserialize[T](deserializationAction: () => T): T = { @@ -302,6 +326,138 @@ private[netty] class NettyRpcEnv( } } + override def fileServer: RpcEnvFileServer = streamManager + + override def fetchFile(uri: String, dest: File, overwrite: Boolean): Unit = { + val parsedUri = new URI(uri) + require(parsedUri.getHost() != null, "Host name must be defined.") + require(parsedUri.getPort() > 0, "Port must be defined.") + require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.") + + val t1 = System.currentTimeMillis() + val callback = new FileDownloadCallback(dest, overwrite) + try { + val client = fileDownloadClient(parsedUri.getHost(), parsedUri.getPort()) + val t2 = System.currentTimeMillis() + client.stream(parsedUri.getPath(), callback) + callback.waitForCompletion() + val t3 = System.currentTimeMillis() + logDebug(s"Downloaded ${parsedUri.getPath()}: " + + s"${Utils.bytesToString(callback.transferred)} bytes in " + + s"${t3 - t1}ms (${t2 - t1}ms spent in setup)") + } finally { + callback.dispose() + } + } + + private def fileDownloadClient(host: String, port: Int): TransportClient = synchronized { + if (stopped.get()) { + throw new IllegalStateException("RpcEnv already stopped.") + } + + val address = RpcAddress(host, port) + fileClients.get(address).filter(_.isActive()).getOrElse { + // Create a new client and install a handler that will respond to IdleStateEvent. The events + // are generated by the IdleStateHandler installed by the TransportContext when creating + // clients, and the timeout value is controlled by the transport configuration. 
+ val c = clientFactory.createUnmanagedClient(host, port) + c.getChannel().pipeline().addLast("rpcEnvTimeoutHandler", new TimeoutHandler(c)) + fileClients.put(address, c) + c + } + } + + private class TimeoutHandler(client: TransportClient) extends ChannelInboundHandlerAdapter { + + override def userEventTriggered(ctx: ChannelHandlerContext, evt: Object): Unit = { + if (evt.isInstanceOf[IdleStateEvent] && client.isActive()) { + logDebug(s"Closing transport client $client after idle timeout.") + val socketAddr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + val address = RpcAddress(socketAddr.getHostName(), socketAddr.getPort()) + ctx.close() + NettyRpcEnv.this.synchronized { + fileClients.remove(address) + } + } + ctx.fireUserEventTriggered(evt) + } + + } + + private class FileDownloadCallback(dest: File, overwrite: Boolean) extends StreamCallback { + + @volatile var error: Throwable = null + @volatile var transferred = 0L + + private val temp = File.createTempFile(dest.getName(), ".tmp", dest.getParentFile()) + private val out = FileChannel.open(temp.toPath(), StandardOpenOption.WRITE) + private val lock = new Object() + @volatile private var complete = false + + override def onData(streamId: String, buf: ByteBuffer): Unit = { + val count = buf.remaining() + while (buf.remaining() > 0) { + out.write(buf) + } + transferred += count + } + + override def onComplete(streamId: String): Unit = { + try { + out.close() + if (dest.exists()) { + if (overwrite) { + // Explicitly delete the file since `renameTo` doesn't work on Win32 if the target + // already exists. + require(dest.delete(), s"Failed to delete $dest.") + } else { + throw new IOException(s"Destination file $dest already exists.") + } + } + if (!temp.renameTo(dest)) { + throw new IOException(s"Failed to rename temp file to $dest.") + } + } catch { + case e: Exception => + temp.delete() + throw e + } finally { + finish() + } + } + + override def onFailure(streamId: String, cause: Throwable): Unit = { + error = cause + dispose() + finish() + } + + def waitForCompletion(): Unit = { + lock.synchronized { + while (!complete) { + lock.wait(TimeUnit.SECONDS.toMillis(5)) + logDebug(s"${dest.getName()}: transferred ${Utils.bytesToString(transferred)} bytes.") + } + } + if (error != null) { + throw error + } + } + + def dispose(): Unit = { + Utils.tryLogNonFatalError { out.close() } + temp.delete() + } + + private def finish(): Unit = { + complete = true + lock.synchronized { + lock.notifyAll() + } + } + + } + } private[netty] object NettyRpcEnv extends Logging { @@ -423,7 +579,7 @@ private[netty] class NettyRpcEndpointRef( override def toString: String = s"NettyRpcEndpointRef(${_address})" - def toURI: URI = new URI(s"spark://${_address}") + def toURI: URI = new URI(_address.toString) final override def equals(that: Any): Boolean = that match { case other: NettyRpcEndpointRef => _address == other._address @@ -474,7 +630,9 @@ private[netty] case class RpcFailure(e: Throwable) * with different `RpcAddress` information). */ private[netty] class NettyRpcHandler( - dispatcher: Dispatcher, nettyEnv: NettyRpcEnv) extends RpcHandler with Logging { + dispatcher: Dispatcher, + nettyEnv: NettyRpcEnv, + streamManager: StreamManager) extends RpcHandler with Logging { // TODO: Can we add connection callback (channel registered) to the underlying framework? // A variable to track whether we should dispatch the RemoteProcessConnected message. 
@@ -501,13 +659,15 @@ private[netty] class NettyRpcHandler( dispatcher.postRemoteMessage(messageToDispatch, callback) } - override def getStreamManager: StreamManager = new OneForOneStreamManager + override def getStreamManager: StreamManager = streamManager override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr)) + if (clients.containsKey(client)) { + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr)) + } } else { // If the channel is closed before connecting, its remoteAddress will be null. // See java.net.Socket.getRemoteSocketAddress @@ -519,10 +679,11 @@ private[netty] class NettyRpcHandler( override def connectionTerminated(client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - clients.remove(client) - nettyEnv.removeOutbox(clientAddr) - dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) + if (clients.remove(client) != null) { + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + nettyEnv.removeOutbox(clientAddr) + dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) + } } else { // If the channel is closed before connecting, its remoteAddress will be null. In this case, // we can ignore it since we don't fire "Associated". diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala new file mode 100644 index 000000000000..eb1d2604fb23 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.rpc.netty + +import java.io.File +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} +import org.apache.spark.network.server.StreamManager +import org.apache.spark.rpc.RpcEnvFileServer + +/** + * StreamManager implementation for serving files from a NettyRpcEnv. 
+ */ +private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) + extends StreamManager with RpcEnvFileServer { + + private val files = new ConcurrentHashMap[String, File]() + private val jars = new ConcurrentHashMap[String, File]() + + override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = { + throw new UnsupportedOperationException() + } + + override def openStream(streamId: String): ManagedBuffer = { + val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2) + val file = ftype match { + case "files" => files.get(fname) + case "jars" => jars.get(fname) + case _ => throw new IllegalArgumentException(s"Invalid file type: $ftype") + } + + require(file != null, s"File not found: $streamId") + new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) + } + + override def addFile(file: File): String = { + require(files.putIfAbsent(file.getName(), file) == null, + s"File ${file.getName()} already registered.") + s"${rpcEnv.address.toSparkURL}/files/${file.getName()}" + } + + override def addJar(file: File): String = { + require(jars.putIfAbsent(file.getName(), file) == null, + s"JAR ${file.getName()} already registered.") + s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}" + } + +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 5a976ee839b1..2c15b909ac80 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -535,6 +535,12 @@ private[spark] object Utils extends Logging { val uri = new URI(url) val fileOverwrite = conf.getBoolean("spark.files.overwrite", defaultValue = false) Option(uri.getScheme).getOrElse("file") match { + case "spark" => + if (SparkEnv.get == null) { + throw new IllegalStateException( + "Cannot retrieve files with 'spark' scheme without an active SparkEnv.") + } + SparkEnv.get.rpcEnv.fetchFile(url, new File(targetDir, filename), fileOverwrite) case "http" | "https" | "ftp" => var uc: URLConnection = null if (securityMgr.isAuthenticationEnabled()) { diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 834e4743df86..b8e688c87c62 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.rpc -import java.io.NotSerializableException +import java.io.{File, NotSerializableException} +import java.util.UUID +import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} import scala.collection.mutable @@ -25,27 +27,36 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps +import com.google.common.io.Files +import org.mockito.Mockito.{mock, when} import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Eventually._ -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite} +import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.util.Utils /** * Common tests for an RpcEnv implementation. 
*/ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { + val conf = new SparkConf() var env: RpcEnv = _ override def beforeAll(): Unit = { - val conf = new SparkConf() env = createRpcEnv(conf, "local", 12345) + + val sparkEnv = mock(classOf[SparkEnv]) + when(sparkEnv.rpcEnv).thenReturn(env) + SparkEnv.set(sparkEnv) } override def afterAll(): Unit = { if (env != null) { env.shutdown() } + SparkEnv.set(null) } def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv @@ -713,6 +724,29 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { assert(shortTimeout.timeoutProp.r.findAllIn(reply4).length === 1) } + test("file server") { + val tempDir = Utils.createTempDir() + val file = new File(tempDir, "file") + Files.write(UUID.randomUUID().toString(), file, UTF_8) + val jar = new File(tempDir, "jar") + Files.write(UUID.randomUUID().toString(), jar, UTF_8) + + val fileUri = env.fileServer.addFile(file) + val jarUri = env.fileServer.addJar(jar) + + val destDir = Utils.createTempDir() + val destFile = new File(destDir, file.getName()) + val destJar = new File(destDir, jar.getName()) + + val sm = new SecurityManager(conf) + val hc = SparkHadoopUtil.get.conf + Utils.fetchFile(fileUri, destDir, conf, sm, hc, 0L, false) + Utils.fetchFile(jarUri, destDir, conf, sm, hc, 0L, false) + + assert(Files.equal(file, destFile)) + assert(Files.equal(jar, destJar)) + } + } class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index f9d8e80c98b6..ccca795683da 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -25,17 +25,19 @@ import org.mockito.Matchers._ import org.apache.spark.SparkFunSuite import org.apache.spark.network.client.{TransportResponseHandler, TransportClient} +import org.apache.spark.network.server.StreamManager import org.apache.spark.rpc._ class NettyRpcHandlerSuite extends SparkFunSuite { val env = mock(classOf[NettyRpcEnv]) - when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())). 
- thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) + val sm = mock(classOf[StreamManager]) + when(env.deserialize(any(classOf[TransportClient]), any(classOf[Array[Byte]]))(any())) + .thenReturn(RequestMessage(RpcAddress("localhost", 12345), null, null, false)) test("receive") { val dispatcher = mock(classOf[Dispatcher]) - val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) @@ -47,7 +49,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { test("connectionTerminated") { val dispatcher = mock(classOf[Dispatcher]) - val nettyRpcHandler = new NettyRpcHandler(dispatcher, env) + val nettyRpcHandler = new NettyRpcHandler(dispatcher, env, sm) val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 3ee6bd92e47f..55fe156cf665 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -148,7 +148,7 @@ List buildClassPath(String appClassPath) throws IOException { String scala = getScalaVersion(); List projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx", "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", - "yarn", "launcher"); + "yarn", "launcher", "network/common", "network/shuffle", "network/yarn"); if (prependClasses) { if (!isTesting) { System.err.println( diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 4952ffb44bb8..ea5f262be83f 100644 --- a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -158,6 +158,21 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO } } + /** + * Creates a new client that will not be managed by this factory. + * + * The client will not be shared with callers of {@link #createClient(String, int)} and will + * not be closed when the factory is disposed of. + * + * As with {@link #createClient(String, int)}, this method is blocking. + */ + public TransportClient createUnmanagedClient(String remoteHost, int remotePort) + throws IOException { + + InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort); + return createClient(address); + } + /** Create a completely new {@link TransportClient} to the remote address. 
*/ private TransportClient createClient(InetSocketAddress address) throws IOException { logger.debug("Creating new connection to " + address); diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java index 8e0ee709e38e..93616a3a8ad7 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java @@ -123,5 +123,6 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc ctx.close(); } } + ctx.fireUserEventTriggered(evt); } } From e9f375265fa917bfb440fbebd98179257c021170 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 9 Nov 2015 13:15:23 -0800 Subject: [PATCH 2/9] Change the file retrieval API. This reuses more of the existing code in Utils, and also makes some subsequent changes built on top of this API possible, at the expense of some efficiency (using input streams instead of channels). If desired the Utils class can later be changed to use channels to regain the lost efficiency. --- .../scala/org/apache/spark/rpc/RpcEnv.scala | 10 +- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 3 +- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 106 ++++++------------ .../scala/org/apache/spark/util/Utils.scala | 5 +- 4 files changed, 47 insertions(+), 77 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 54d6ed7201ef..3d7d281b0dd6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -18,6 +18,7 @@ package org.apache.spark.rpc import java.io.File +import java.nio.channels.ReadableByteChannel import scala.concurrent.Future @@ -142,14 +143,13 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { def fileServer: RpcEnvFileServer /** - * Fetch a file from the given URI. If the URIs returned by the RpcEnvFileServer use the "spark" - * scheme, this method will be called to retrieve the files. + * Open a channel to download a file from the given URI. If the URIs returned by the + * RpcEnvFileServer use the "spark" scheme, this method will be called by the Utils class to + * retrieve the files. * * @param uri URI with location of the file. - * @param dest Local destination of file. - * @param overwrite Whether to overwrite the target file if it exists. 
*/ - def fetchFile(uri: String, dest: File, overwrite: Boolean): Unit + def openChannel(uri: String): ReadableByteChannel } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 3552ea5aedc8..1e23529e5523 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -18,6 +18,7 @@ package org.apache.spark.rpc.akka import java.io.File +import java.nio.channels.ReadableByteChannel import java.util.concurrent.ConcurrentHashMap import scala.concurrent.Future @@ -249,7 +250,7 @@ private[spark] class AkkaRpcEnv private[akka] ( } } - override def fetchFile(uri: String, dest: File, overwrite: Boolean): Unit = { + override def openChannel(uri: String): ReadableByteChannel = { throw new UnsupportedOperationException( "AkkaRpcEnv's files should be retrieved using an HTTP client.") } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 63a87a9e804e..d4b5ef421a1d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -20,8 +20,7 @@ import java.io._ import java.lang.{Boolean => JBoolean} import java.net.{InetSocketAddress, URI} import java.nio.ByteBuffer -import java.nio.channels.FileChannel -import java.nio.file.StandardOpenOption +import java.nio.channels.{Pipe, ReadableByteChannel, WritableByteChannel} import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.Nullable @@ -328,26 +327,26 @@ private[netty] class NettyRpcEnv( override def fileServer: RpcEnvFileServer = streamManager - override def fetchFile(uri: String, dest: File, overwrite: Boolean): Unit = { + override def openChannel(uri: String): ReadableByteChannel = { val parsedUri = new URI(uri) require(parsedUri.getHost() != null, "Host name must be defined.") require(parsedUri.getPort() > 0, "Port must be defined.") require(parsedUri.getPath() != null && parsedUri.getPath().nonEmpty, "Path must be defined.") - val t1 = System.currentTimeMillis() - val callback = new FileDownloadCallback(dest, overwrite) + val pipe = Pipe.open() + val source = new FileDownloadChannel(pipe.source()) try { + val callback = new FileDownloadCallback(pipe.sink(), source) val client = fileDownloadClient(parsedUri.getHost(), parsedUri.getPort()) - val t2 = System.currentTimeMillis() client.stream(parsedUri.getPath(), callback) - callback.waitForCompletion() - val t3 = System.currentTimeMillis() - logDebug(s"Downloaded ${parsedUri.getPath()}: " + - s"${Utils.bytesToString(callback.transferred)} bytes in " + - s"${t3 - t1}ms (${t2 - t1}ms spent in setup)") - } finally { - callback.dispose() + } catch { + case e: Exception => + pipe.sink().close() + source.close() + throw e } + + source } private def fileDownloadClient(host: String, port: Int): TransportClient = synchronized { @@ -384,76 +383,43 @@ private[netty] class NettyRpcEnv( } - private class FileDownloadCallback(dest: File, overwrite: Boolean) extends StreamCallback { + private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel { - @volatile var error: Throwable = null - @volatile var transferred = 0L + @volatile private var error: Throwable = _ - private val temp = File.createTempFile(dest.getName(), ".tmp", dest.getParentFile()) - private val out = 
FileChannel.open(temp.toPath(), StandardOpenOption.WRITE) - private val lock = new Object() - @volatile private var complete = false + def setError(e: Throwable): Unit = error = e - override def onData(streamId: String, buf: ByteBuffer): Unit = { - val count = buf.remaining() - while (buf.remaining() > 0) { - out.write(buf) + override def read(dst: ByteBuffer): Int = { + if (error != null) { + throw error } - transferred += count + source.read(dst) } - override def onComplete(streamId: String): Unit = { - try { - out.close() - if (dest.exists()) { - if (overwrite) { - // Explicitly delete the file since `renameTo` doesn't work on Win32 if the target - // already exists. - require(dest.delete(), s"Failed to delete $dest.") - } else { - throw new IOException(s"Destination file $dest already exists.") - } - } - if (!temp.renameTo(dest)) { - throw new IOException(s"Failed to rename temp file to $dest.") - } - } catch { - case e: Exception => - temp.delete() - throw e - } finally { - finish() - } - } + override def close(): Unit = source.close() - override def onFailure(streamId: String, cause: Throwable): Unit = { - error = cause - dispose() - finish() - } + override def isOpen(): Boolean = source.isOpen() - def waitForCompletion(): Unit = { - lock.synchronized { - while (!complete) { - lock.wait(TimeUnit.SECONDS.toMillis(5)) - logDebug(s"${dest.getName()}: transferred ${Utils.bytesToString(transferred)} bytes.") - } - } - if (error != null) { - throw error + } + + private class FileDownloadCallback( + sink: WritableByteChannel, + source: FileDownloadChannel) extends StreamCallback { + + override def onData(streamId: String, buf: ByteBuffer): Unit = { + while (buf.remaining() > 0) { + sink.write(buf) } } - def dispose(): Unit = { - Utils.tryLogNonFatalError { out.close() } - temp.delete() + override def onComplete(streamId: String): Unit = { + sink.close() } - private def finish(): Unit = { - complete = true - lock.synchronized { - lock.notifyAll() - } + override def onFailure(streamId: String, cause: Throwable): Unit = { + logError(s"Error downloading stream $streamId.", cause) + source.setError(cause) + sink.close() } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b8f316dba856..24659301ddf9 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -21,6 +21,7 @@ import java.io._ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer +import java.nio.channels.Channels import java.util.{Properties, Locale, Random, UUID} import java.util.concurrent._ import javax.net.ssl.HttpsURLConnection @@ -541,7 +542,9 @@ private[spark] object Utils extends Logging { throw new IllegalStateException( "Cannot retrieve files with 'spark' scheme without an active SparkEnv.") } - SparkEnv.get.rpcEnv.fetchFile(url, new File(targetDir, filename), fileOverwrite) + val source = SparkEnv.get.rpcEnv.openChannel(url) + val is = Channels.newInputStream(source) + downloadFile(url, is, targetFile, fileOverwrite) case "http" | "https" | "ftp" => var uc: URLConnection = null if (securityMgr.isAuthenticationEnabled()) { From a222c0344738ec1bd19e6cffea83d11d1b0ad2fc Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 9 Nov 2015 16:28:43 -0800 Subject: [PATCH 3/9] Handle race in timeout handling. 
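The race: a download can look up a cached client at the same moment the idle handler decides to close it, leaving the download with a dead channel. The fix marks the client as in use under the client's own monitor before streaming, and the handler only closes the channel when that flag is clear, checking it under the same monitor. Roughly, excerpted and simplified from the diff below (the real handler also drops the closed client from the fileClients cache):

    // Download path: claim the client before using it, then re-check liveness
    // in case the idle handler won the race and already closed it.
    val client = fileClients.get(address).filter(_.isActive())
      .getOrElse(newDownloadClient(address))
    client.synchronized {
      client.getChannel().pipeline().get(classOf[TimeoutHandler]).setInUse(true)
      if (client.isActive()) client else newDownloadClient(address)
    }

    // Timeout handler: only close idle clients that no download has claimed,
    // deciding under the same monitor so the two paths cannot interleave.
    override def userEventTriggered(ctx: ChannelHandlerContext, evt: Object): Unit = {
      client.synchronized {
        if (!inUse && evt.isInstanceOf[IdleStateEvent] && client.isActive()) {
          ctx.close()
        }
      }
      ctx.fireUserEventTriggered(evt)
    }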
--- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 71 ++++++++++++++----- 1 file changed, 52 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index d4b5ef421a1d..b50c5ecfe5cf 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -336,8 +336,8 @@ private[netty] class NettyRpcEnv( val pipe = Pipe.open() val source = new FileDownloadChannel(pipe.source()) try { - val callback = new FileDownloadCallback(pipe.sink(), source) val client = fileDownloadClient(parsedUri.getHost(), parsedUri.getPort()) + val callback = new FileDownloadCallback(pipe.sink(), source, client) client.stream(parsedUri.getPath(), callback) } catch { case e: Exception => @@ -355,27 +355,49 @@ private[netty] class NettyRpcEnv( } val address = RpcAddress(host, port) - fileClients.get(address).filter(_.isActive()).getOrElse { - // Create a new client and install a handler that will respond to IdleStateEvent. The events - // are generated by the IdleStateHandler installed by the TransportContext when creating - // clients, and the timeout value is controlled by the transport configuration. - val c = clientFactory.createUnmanagedClient(host, port) - c.getChannel().pipeline().addLast("rpcEnvTimeoutHandler", new TimeoutHandler(c)) - fileClients.put(address, c) - c + val client = fileClients.get(address).filter(_.isActive()).getOrElse(newDownloadClient(address)) + + // Tell the timeout handler this client is in use. This will prevent the handler from + // closing the client if the timeout even triggers before data starts flowing for this + // download. + client.synchronized { + val timeoutHandler = client.getChannel().pipeline().get(classOf[TimeoutHandler]) + timeoutHandler.setInUse(true) + + // After notifying the timeout handler, check that the client is really active, and if not, + // create a new one. + if (client.isActive()) client else newDownloadClient(address) } } + /** + * Create a new client and install a handler that will respond to IdleStateEvent. The events + * are generated by the IdleStateHandler installed by the TransportContext when creating + * clients, and the timeout value is controlled by the transport configuration. 
+ */ + private def newDownloadClient(addr: RpcAddress): TransportClient = { + val c = clientFactory.createUnmanagedClient(addr.host, addr.port) + c.getChannel().pipeline().addLast("rpcEnvTimeoutHandler", new TimeoutHandler(c)) + fileClients.put(addr, c) + c + } + private class TimeoutHandler(client: TransportClient) extends ChannelInboundHandlerAdapter { + @volatile private var inUse = true + + def setInUse(inUse: Boolean): Unit = this.inUse = inUse + override def userEventTriggered(ctx: ChannelHandlerContext, evt: Object): Unit = { - if (evt.isInstanceOf[IdleStateEvent] && client.isActive()) { - logDebug(s"Closing transport client $client after idle timeout.") - val socketAddr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] - val address = RpcAddress(socketAddr.getHostName(), socketAddr.getPort()) - ctx.close() - NettyRpcEnv.this.synchronized { - fileClients.remove(address) + client.synchronized { + if (!inUse && evt.isInstanceOf[IdleStateEvent] && client.isActive()) { + logDebug(s"Closing transport client $client after idle timeout.") + val socketAddr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + val address = RpcAddress(socketAddr.getHostName(), socketAddr.getPort()) + ctx.close() + NettyRpcEnv.this.synchronized { + fileClients.remove(address) + } } } ctx.fireUserEventTriggered(evt) @@ -404,7 +426,8 @@ private[netty] class NettyRpcEnv( private class FileDownloadCallback( sink: WritableByteChannel, - source: FileDownloadChannel) extends StreamCallback { + source: FileDownloadChannel, + client: TransportClient) extends StreamCallback { override def onData(streamId: String, buf: ByteBuffer): Unit = { while (buf.remaining() > 0) { @@ -414,12 +437,22 @@ private[netty] class NettyRpcEnv( override def onComplete(streamId: String): Unit = { sink.close() + releaseClient() } override def onFailure(streamId: String, cause: Throwable): Unit = { logError(s"Error downloading stream $streamId.", cause) - source.setError(cause) - sink.close() + try { + source.setError(cause) + sink.close() + } finally { + releaseClient() + } + } + + private def releaseClient(): Unit = client.synchronized { + val timeoutHandler = client.getChannel().pipeline().get(classOf[TimeoutHandler]) + timeoutHandler.setInUse(false) } } From 71ac0cfdb25d7bb29e180c704a8eac840c7fccd4 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 10 Nov 2015 13:06:40 -0800 Subject: [PATCH 4/9] Fix locking. 
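The previous commit had the idle-timeout handler take the client's monitor and then the NettyRpcEnv lock (to remove the entry from fileClients), while fileDownloadClient takes the NettyRpcEnv lock and then the client's monitor, so the two paths acquire the monitors in opposite orders and can deadlock. The fix is to decide whether the client timed out while holding only the client's monitor, and to touch fileClients only after releasing it. A simplified sketch of the pattern, with the remote address captured inside the synchronized block so it is still in scope afterwards (names as in the surrounding code):

    override def userEventTriggered(ctx: ChannelHandlerContext, evt: Object): Unit = {
      // Decide, and capture what the cleanup needs, under the client's monitor only.
      val timedOutAddress: Option[RpcAddress] = client.synchronized {
        if (!inUse && evt.isInstanceOf[IdleStateEvent] && client.isActive()) {
          val socketAddr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress]
          ctx.close()
          Some(RpcAddress(socketAddr.getHostName(), socketAddr.getPort()))
        } else {
          None
        }
      }
      ctx.fireUserEventTriggered(evt)
      // Only now take the env-level lock; the two monitors are never held together here.
      timedOutAddress.foreach { addr =>
        NettyRpcEnv.this.synchronized { fileClients.remove(addr) }
      }
    }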
--- .../org/apache/spark/rpc/netty/NettyRpcEnv.scala | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index b50c5ecfe5cf..ab04d2b54030 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -389,18 +389,23 @@ private[netty] class NettyRpcEnv( def setInUse(inUse: Boolean): Unit = this.inUse = inUse override def userEventTriggered(ctx: ChannelHandlerContext, evt: Object): Unit = { - client.synchronized { + val timedOut = client.synchronized { if (!inUse && evt.isInstanceOf[IdleStateEvent] && client.isActive()) { logDebug(s"Closing transport client $client after idle timeout.") val socketAddr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] val address = RpcAddress(socketAddr.getHostName(), socketAddr.getPort()) ctx.close() - NettyRpcEnv.this.synchronized { - fileClients.remove(address) - } + true + } else { + false } } ctx.fireUserEventTriggered(evt) + if (timedOut) { + NettyRpcEnv.this.synchronized { + fileClients.remove(address) + } + } } } From 11f61a88e4a9a1a1a17572c8579143fc856667d9 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 10 Nov 2015 13:43:21 -0800 Subject: [PATCH 5/9] Doc updates. --- docs/configuration.md | 2 ++ docs/security.md | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/configuration.md b/docs/configuration.md index c276e8e90dec..cdd4099c46a1 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1017,6 +1017,7 @@ Apart from these, the following properties are also available, and may be useful (random) Port for the executor to listen on. This is used for communicating with the driver. + This is only relevant when using the akka-based RPC backend. @@ -1024,6 +1025,7 @@ Apart from these, the following properties are also available, and may be useful (random) Port for the driver's HTTP file server to listen on. + This is only relevant when using the Akka RPC backend. diff --git a/docs/security.md b/docs/security.md index 177109415180..e1af221d446b 100644 --- a/docs/security.md +++ b/docs/security.md @@ -149,7 +149,8 @@ configure those ports. (random) Schedule tasks spark.executor.port - Akka-based. Set to "0" to choose a port randomly. + Akka-based. Set to "0" to choose a port randomly. Only used if Akka RPC backend is + configured. Executor @@ -157,7 +158,7 @@ configure those ports. (random) File server for files and jars spark.fileserver.port - Jetty-based + Jetty-based. Only used if Akka RPC backend is configured. Executor From fcdc498294e7a8b4a91b7e6ce4e6dc7565e0b4ef Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 10 Nov 2015 16:21:40 -0800 Subject: [PATCH 6/9] Doc consistency fix. --- docs/configuration.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/configuration.md b/docs/configuration.md index cdd4099c46a1..967cc55c8d9a 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -1017,7 +1017,7 @@ Apart from these, the following properties are also available, and may be useful (random) Port for the executor to listen on. This is used for communicating with the driver. - This is only relevant when using the akka-based RPC backend. + This is only relevant when using the Akka RPC backend. 
From 7cc83e73fc0223c716e73b1ec269c484159092f3 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 18 Nov 2015 14:47:21 -0800 Subject: [PATCH 7/9] Fix borked merge. --- core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 14ae42c567e1..afad7d9928cd 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -25,6 +25,7 @@ import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.Nullable +import scala.collection.mutable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag import scala.util.{DynamicVariable, Failure, Success} From cfd01bddeeedc99b84e756afe97b9731b8e95fc6 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 18 Nov 2015 18:53:46 -0800 Subject: [PATCH 8/9] Use separate client factory for file downloads. This makes the code friendlier to SPARK-11097 (if and when it's implemented), and also avoids custom timeout handling code by reusing features of the transport library. Need to test on a real cluster, though... --- .../apache/spark/rpc/netty/NettyRpcEnv.scala | 144 +++++------------- 1 file changed, 38 insertions(+), 106 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index afad7d9928cd..b616e9d8956a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -25,15 +25,11 @@ import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean import javax.annotation.Nullable -import scala.collection.mutable import scala.concurrent.{Future, Promise} import scala.reflect.ClassTag import scala.util.{DynamicVariable, Failure, Success} import scala.util.control.NonFatal -import io.netty.channel.{ChannelHandlerContext, ChannelInboundHandlerAdapter} -import io.netty.handler.timeout.IdleStateEvent - import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.network.TransportContext import org.apache.spark.network.client._ @@ -62,33 +58,28 @@ private[netty] class NettyRpcEnv( private val transportContext = new TransportContext(transportConf, new NettyRpcHandler(dispatcher, this, streamManager)) - private val clientFactory = { - val bootstraps: java.util.List[TransportClientBootstrap] = - if (securityManager.isAuthenticationEnabled()) { - java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, - securityManager.isSaslEncryptionEnabled())) - } else { - java.util.Collections.emptyList[TransportClientBootstrap] - } - transportContext.createClientFactory(bootstraps) + private def createClientBootstraps(): java.util.List[TransportClientBootstrap] = { + if (securityManager.isAuthenticationEnabled()) { + java.util.Arrays.asList(new SaslClientBootstrap(transportConf, "", securityManager, + securityManager.isSaslEncryptionEnabled())) + } else { + java.util.Collections.emptyList[TransportClientBootstrap] + } } - val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") + private val clientFactory = transportContext.createClientFactory(createClientBootstraps()) /** - * A collection of cached clients used for downloading files, to avoid 
the cost of connection - * establishment when multiple files are downloaded in a short period of time. The clients are - * closed after an inactivity timeout. + * A separate client factory for file downloads. This avoids using the same RPC handler as + * the main RPC context, so that events caused by these clients are kept isolated from the + * main RPC traffic. * - * Clients are cached for each remote host, although in most cases there will be a single remote - * host serving files. - * - * These TransportClient instances are not used for RPC, so they're never registered with the - * NettyRpcHandler instance. The outcome is that connection / disconnection events are not - * sent for these clients, which is desirable since otherwise other parts of the code might - * think the driver (for example) is disconnecting when that's not the case. + * It also allows for different configuration of certain properties, such as the number of + * connections per peer. */ - private val fileClients = new mutable.HashMap[RpcAddress, TransportClient]() + @volatile private var fileDownloadFactory: TransportClientFactory = _ + + val timeoutScheduler = ThreadUtils.newDaemonSingleThreadScheduledExecutor("netty-rpc-env-timeout") // Because TransportClientFactory.createClient is blocking, we need to run it in this thread pool // to implement non-blocking send/ask. @@ -314,9 +305,8 @@ private[netty] class NettyRpcEnv( if (clientConnectionExecutor != null) { clientConnectionExecutor.shutdownNow() } - synchronized { - fileClients.values.foreach(_.close()) - fileClients.clear() + if (fileDownloadFactory != null) { + fileDownloadFactory.close() } } @@ -337,7 +327,7 @@ private[netty] class NettyRpcEnv( val pipe = Pipe.open() val source = new FileDownloadChannel(pipe.source()) try { - val client = fileDownloadClient(parsedUri.getHost(), parsedUri.getPort()) + val client = downloadClient(parsedUri.getHost(), parsedUri.getPort()) val callback = new FileDownloadCallback(pipe.sink(), source, client) client.stream(parsedUri.getPath(), callback) } catch { @@ -350,65 +340,20 @@ private[netty] class NettyRpcEnv( source } - private def fileDownloadClient(host: String, port: Int): TransportClient = synchronized { - if (stopped.get()) { - throw new IllegalStateException("RpcEnv already stopped.") - } - - val address = RpcAddress(host, port) - val client = fileClients.get(address).filter(_.isActive()).getOrElse(newDownloadClient(address)) - - // Tell the timeout handler this client is in use. This will prevent the handler from - // closing the client if the timeout even triggers before data starts flowing for this - // download. - client.synchronized { - val timeoutHandler = client.getChannel().pipeline().get(classOf[TimeoutHandler]) - timeoutHandler.setInUse(true) - - // After notifying the timeout handler, check that the client is really active, and if not, - // create a new one. - if (client.isActive()) client else newDownloadClient(address) - } - } - - /** - * Create a new client and install a handler that will respond to IdleStateEvent. The events - * are generated by the IdleStateHandler installed by the TransportContext when creating - * clients, and the timeout value is controlled by the transport configuration. 
- */ - private def newDownloadClient(addr: RpcAddress): TransportClient = { - val c = clientFactory.createUnmanagedClient(addr.host, addr.port) - c.getChannel().pipeline().addLast("rpcEnvTimeoutHandler", new TimeoutHandler(c)) - fileClients.put(addr, c) - c - } - - private class TimeoutHandler(client: TransportClient) extends ChannelInboundHandlerAdapter { - - @volatile private var inUse = true - - def setInUse(inUse: Boolean): Unit = this.inUse = inUse - - override def userEventTriggered(ctx: ChannelHandlerContext, evt: Object): Unit = { - val timedOut = client.synchronized { - if (!inUse && evt.isInstanceOf[IdleStateEvent] && client.isActive()) { - logDebug(s"Closing transport client $client after idle timeout.") - val socketAddr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] - val address = RpcAddress(socketAddr.getHostName(), socketAddr.getPort()) - ctx.close() - true - } else { - false - } - } - ctx.fireUserEventTriggered(evt) - if (timedOut) { - NettyRpcEnv.this.synchronized { - fileClients.remove(address) + private def downloadClient(host: String, port: Int): TransportClient = { + if (fileDownloadFactory == null) synchronized { + if (fileDownloadFactory == null) { + val clone = conf.clone() + conf.getOption("spark.files.maxDownloadClients").foreach { v => + clone.set("spark.rpc.io.numConnectionsPerPeer", v) } + val ioThreads = clone.getInt("spark.files.io.threads", 1) + val downloadConf = SparkTransportConf.fromSparkConf(clone, "rpc", ioThreads) + val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true) + fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps()) } } - + fileDownloadFactory.createClient(host, port) } private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel { @@ -443,22 +388,12 @@ private[netty] class NettyRpcEnv( override def onComplete(streamId: String): Unit = { sink.close() - releaseClient() } override def onFailure(streamId: String, cause: Throwable): Unit = { logError(s"Error downloading stream $streamId.", cause) - try { - source.setError(cause) - sink.close() - } finally { - releaseClient() - } - } - - private def releaseClient(): Unit = client.synchronized { - val timeoutHandler = client.getChannel().pipeline().get(classOf[TimeoutHandler]) - timeoutHandler.setInUse(false) + source.setError(cause) + sink.close() } } @@ -669,10 +604,8 @@ private[netty] class NettyRpcHandler( override def exceptionCaught(cause: Throwable, client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { - if (clients.containsKey(client)) { - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr)) - } + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + dispatcher.postToAll(RemoteProcessConnectionError(cause, clientAddr)) } else { // If the channel is closed before connecting, its remoteAddress will be null. 
// See java.net.Socket.getRemoteSocketAddress @@ -684,11 +617,10 @@ private[netty] class NettyRpcHandler( override def connectionTerminated(client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { - if (clients.remove(client) != null) { - val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - nettyEnv.removeOutbox(clientAddr) - dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) - } + clients.remove(client) + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + nettyEnv.removeOutbox(clientAddr) + dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) } else { // If the channel is closed before connecting, its remoteAddress will be null. In this case, // we can ignore it since we don't fire "Associated". From 0e8e4bb10605d218e8c763ac4fd66bf48113b794 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 19 Nov 2015 11:45:57 -0800 Subject: [PATCH 9/9] Allow separate configs for all network opts in file download client. But use the rpc config as the default. Also now properly tested on a real cluster, and verified idle connections are closed. --- .../org/apache/spark/rpc/netty/NettyRpcEnv.scala | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index b616e9d8956a..581a2b071014 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -343,12 +343,20 @@ private[netty] class NettyRpcEnv( private def downloadClient(host: String, port: Int): TransportClient = { if (fileDownloadFactory == null) synchronized { if (fileDownloadFactory == null) { + val module = "files" + val prefix = "spark.rpc.io." val clone = conf.clone() - conf.getOption("spark.files.maxDownloadClients").foreach { v => - clone.set("spark.rpc.io.numConnectionsPerPeer", v) + + // Copy any RPC configuration that is not overridden in the spark.files namespace. + conf.getAll.foreach { case (key, value) => + if (key.startsWith(prefix)) { + val opt = key.substring(prefix.length()) + clone.setIfMissing(s"spark.$module.io.$opt", value) + } } + val ioThreads = clone.getInt("spark.files.io.threads", 1) - val downloadConf = SparkTransportConf.fromSparkConf(clone, "rpc", ioThreads) + val downloadConf = SparkTransportConf.fromSparkConf(clone, module, ioThreads) val downloadContext = new TransportContext(downloadConf, new NoOpRpcHandler(), true) fileDownloadFactory = downloadContext.createClientFactory(createClientBootstraps()) }