Skip to content

Commit 40f23a4

Browse files
author
Marcelo Vanzin
committed
[SPARK-4563][core] Allow driver to advertise a different network address.
The goal of this feature is to allow the Spark driver to run in an isolated environment, such as a docker container, and be able to use the host's port forwarding mechanism to be able to accept connections from the outside world. The change is restricted to the driver: there is no support for achieving the same thing on executors (or the YARN AM for that matter). Those still need full access to the outside world so that, for example, connections can be made to an executor's block manager. The core of the change is simple: add a new configuration that tells what's the address the driver should bind to, which can be different than the address it advertises to executors (spark.driver.host). Everything else is plumbing the new configuration where it's needed. To use the feature, the host starting the container needs to set up the driver's port range to fall into a range that is being forwarded; this required the block manager port to need a special configuration just for the driver, which falls back to the existing spark.blockManager.port when not set. This way, users can modify the driver settings without affecting the executors; it would theoretically be nice to also have different retry counts for driver and executors, but given that docker (at least) allows forwarding port ranges, we can probably live without that for now. Because of the nature of the feature it's kinda hard to add unit tests; I just added a simple one to make sure the configuration works. This was tested with a docker image running spark-shell with the following command: docker blah blah blah \ -p 38000-38100:38000-38100 \ [image] \ spark-shell \ --num-executors 3 \ --conf spark.shuffle.service.enabled=false \ --conf spark.dynamicAllocation.enabled=false \ --conf spark.driver.host=[host's address] \ --conf spark.driver.port=38000 \ --conf spark.driver.blockManager.port=38020 \ --conf spark.ui.port=38040 Running on YARN; verified the driver works, executors start up and listen on ephemeral ports (instead of using the driver's config), and that caching and shuffling (without the shuffle service) works. Clicked through the UI to make sure all pages (including executor thread dumps) worked. Also tested apps without docker, and ran unit tests.
1 parent 52738d4 commit 40f23a4

File tree

18 files changed

+126
-38
lines changed

18 files changed

+126
-38
lines changed

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,8 +383,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
383383
logInfo("Spark configuration:\n" + _conf.toDebugString)
384384
}
385385

386-
// Set Spark driver host and port system properties
387-
_conf.setIfMissing("spark.driver.host", Utils.localHostName())
386+
// Set Spark driver host and port system properties. This explicitly sets the configuration
387+
// instead of relying on the default value of the config constant.
388+
_conf.setIfMissing(DRIVER_HOST_ADDRESS, conf.get(DRIVER_HOST_ADDRESS))
388389
_conf.setIfMissing("spark.driver.port", "0")
389390

390391
_conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER)

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import org.apache.spark.annotation.DeveloperApi
2929
import org.apache.spark.api.python.PythonWorkerFactory
3030
import org.apache.spark.broadcast.BroadcastManager
3131
import org.apache.spark.internal.Logging
32+
import org.apache.spark.internal.config._
3233
import org.apache.spark.memory.{MemoryManager, StaticMemoryManager, UnifiedMemoryManager}
3334
import org.apache.spark.metrics.MetricsSystem
3435
import org.apache.spark.network.netty.NettyBlockTransferService
@@ -160,12 +161,14 @@ object SparkEnv extends Logging {
160161
mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
161162
assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!")
162163
assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!")
163-
val hostname = conf.get("spark.driver.host")
164+
val bindAddress = conf.get(DRIVER_BIND_ADDRESS)
165+
val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS)
164166
val port = conf.get("spark.driver.port").toInt
165167
create(
166168
conf,
167169
SparkContext.DRIVER_IDENTIFIER,
168-
hostname,
170+
bindAddress,
171+
advertiseAddress,
169172
port,
170173
isDriver = true,
171174
isLocal = isLocal,
@@ -190,6 +193,7 @@ object SparkEnv extends Logging {
190193
conf,
191194
executorId,
192195
hostname,
196+
hostname,
193197
port,
194198
isDriver = false,
195199
isLocal = isLocal,
@@ -205,7 +209,8 @@ object SparkEnv extends Logging {
205209
private def create(
206210
conf: SparkConf,
207211
executorId: String,
208-
hostname: String,
212+
bindAddress: String,
213+
advertiseAddress: String,
209214
port: Int,
210215
isDriver: Boolean,
211216
isLocal: Boolean,
@@ -221,8 +226,8 @@ object SparkEnv extends Logging {
221226
val securityManager = new SecurityManager(conf)
222227

223228
val systemName = if (isDriver) driverSystemName else executorSystemName
224-
val rpcEnv = RpcEnv.create(systemName, hostname, port, conf, securityManager,
225-
clientMode = !isDriver)
229+
val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port, conf,
230+
securityManager, clientMode = !isDriver)
226231

227232
// Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied.
228233
// In the non-driver case, the RPC env's address may be null since it may not be listening
@@ -309,8 +314,15 @@ object SparkEnv extends Logging {
309314
UnifiedMemoryManager(conf, numUsableCores)
310315
}
311316

317+
val blockManagerPort = if (isDriver) {
318+
conf.get(DRIVER_BLOCK_MANAGER_PORT)
319+
} else {
320+
conf.get(BLOCK_MANAGER_PORT)
321+
}
322+
312323
val blockTransferService =
313-
new NettyBlockTransferService(conf, securityManager, hostname, numUsableCores)
324+
new NettyBlockTransferService(conf, securityManager, bindAddress, advertiseAddress,
325+
blockManagerPort, numUsableCores)
314326

315327
val blockManagerMaster = new BlockManagerMaster(registerOrLookupEndpoint(
316328
BlockManagerMaster.DRIVER_ENDPOINT_NAME,

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.internal
1919

2020
import org.apache.spark.launcher.SparkLauncher
2121
import org.apache.spark.network.util.ByteUnit
22+
import org.apache.spark.util.Utils
2223

2324
package object config {
2425

@@ -143,4 +144,23 @@ package object config {
143144
.internal()
144145
.stringConf
145146
.createWithDefaultString("AES/CTR/NoPadding")
147+
148+
private[spark] val DRIVER_BIND_ADDRESS = ConfigBuilder("spark.driver.bindAddress")
149+
.doc("Address where to bind network listen sockets on the driver.")
150+
.stringConf
151+
.createWithDefault(Utils.localHostName())
152+
153+
private[spark] val DRIVER_HOST_ADDRESS = ConfigBuilder("spark.driver.host")
154+
.doc("Address of driver endpoints.")
155+
.fallbackConf(DRIVER_BIND_ADDRESS)
156+
157+
private[spark] val BLOCK_MANAGER_PORT = ConfigBuilder("spark.blockManager.port")
158+
.doc("Port to use for the block manager when a more specific setting is not provided.")
159+
.intConf
160+
.createWithDefault(0)
161+
162+
private[spark] val DRIVER_BLOCK_MANAGER_PORT = ConfigBuilder("spark.driver.blockManager.port")
163+
.doc("Port to use for the block managed on the driver.")
164+
.fallbackConf(BLOCK_MANAGER_PORT)
165+
146166
}

core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ import org.apache.spark.util.Utils
4242
private[spark] class NettyBlockTransferService(
4343
conf: SparkConf,
4444
securityManager: SecurityManager,
45+
bindAddress: String,
4546
override val hostName: String,
47+
_port: Int,
4648
numCores: Int)
4749
extends BlockTransferService {
4850

@@ -75,12 +77,11 @@ private[spark] class NettyBlockTransferService(
7577
/** Creates and binds the TransportServer, possibly trying multiple ports. */
7678
private def createServer(bootstraps: List[TransportServerBootstrap]): TransportServer = {
7779
def startService(port: Int): (TransportServer, Int) = {
78-
val server = transportContext.createServer(hostName, port, bootstraps.asJava)
80+
val server = transportContext.createServer(bindAddress, port, bootstraps.asJava)
7981
(server, server.getPort)
8082
}
8183

82-
val portToTry = conf.getInt("spark.blockManager.port", 0)
83-
Utils.startServiceOnPort(portToTry, startService, conf, getClass.getName)._1
84+
Utils.startServiceOnPort(_port, startService, conf, getClass.getName)._1
8485
}
8586

8687
override def fetchBlocks(

core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,19 @@ private[spark] object RpcEnv {
4040
conf: SparkConf,
4141
securityManager: SecurityManager,
4242
clientMode: Boolean = false): RpcEnv = {
43-
val config = RpcEnvConfig(conf, name, host, port, securityManager, clientMode)
43+
create(name, host, host, port, conf, securityManager, clientMode)
44+
}
45+
46+
def create(
47+
name: String,
48+
bindAddress: String,
49+
advertiseAddress: String,
50+
port: Int,
51+
conf: SparkConf,
52+
securityManager: SecurityManager,
53+
clientMode: Boolean): RpcEnv = {
54+
val config = RpcEnvConfig(conf, name, bindAddress, advertiseAddress, port, securityManager,
55+
clientMode)
4456
new NettyRpcEnvFactory().create(config)
4557
}
4658
}
@@ -186,7 +198,8 @@ private[spark] trait RpcEnvFileServer {
186198
private[spark] case class RpcEnvConfig(
187199
conf: SparkConf,
188200
name: String,
189-
host: String,
201+
bindAddress: String,
202+
advertiseAddress: String,
190203
port: Int,
191204
securityManager: SecurityManager,
192205
clientMode: Boolean)

core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,14 @@ private[netty] class NettyRpcEnv(
108108
}
109109
}
110110

111-
def startServer(port: Int): Unit = {
111+
def startServer(bindAddress: String, port: Int): Unit = {
112112
val bootstraps: java.util.List[TransportServerBootstrap] =
113113
if (securityManager.isAuthenticationEnabled()) {
114114
java.util.Arrays.asList(new SaslServerBootstrap(transportConf, securityManager))
115115
} else {
116116
java.util.Collections.emptyList()
117117
}
118-
server = transportContext.createServer(host, port, bootstraps)
118+
server = transportContext.createServer(bindAddress, port, bootstraps)
119119
dispatcher.registerRpcEndpoint(
120120
RpcEndpointVerifier.NAME, new RpcEndpointVerifier(this, dispatcher))
121121
}
@@ -441,10 +441,11 @@ private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging {
441441
val javaSerializerInstance =
442442
new JavaSerializer(sparkConf).newInstance().asInstanceOf[JavaSerializerInstance]
443443
val nettyEnv =
444-
new NettyRpcEnv(sparkConf, javaSerializerInstance, config.host, config.securityManager)
444+
new NettyRpcEnv(sparkConf, javaSerializerInstance, config.advertiseAddress,
445+
config.securityManager)
445446
if (!config.clientMode) {
446447
val startNettyRpcEnv: Int => (NettyRpcEnv, Int) = { actualPort =>
447-
nettyEnv.startServer(actualPort)
448+
nettyEnv.startServer(config.bindAddress, actualPort)
448449
(nettyEnv, nettyEnv.address.port)
449450
}
450451
try {

core/src/main/scala/org/apache/spark/ui/WebUI.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import org.json4s.JsonAST.{JNothing, JValue}
2828

2929
import org.apache.spark.{SecurityManager, SparkConf, SSLOptions}
3030
import org.apache.spark.internal.Logging
31+
import org.apache.spark.internal.config._
3132
import org.apache.spark.ui.JettyUtils._
3233
import org.apache.spark.util.Utils
3334

@@ -50,8 +51,8 @@ private[spark] abstract class WebUI(
5051
protected val handlers = ArrayBuffer[ServletContextHandler]()
5152
protected val pageToHandlers = new HashMap[WebUIPage, ArrayBuffer[ServletContextHandler]]
5253
protected var serverInfo: Option[ServerInfo] = None
53-
protected val localHostName = Utils.localHostNameForURI()
54-
protected val publicHostName = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(localHostName)
54+
protected val publicHostName = Option(conf.getenv("SPARK_PUBLIC_DNS")).getOrElse(
55+
conf.get(DRIVER_HOST_ADDRESS))
5556
private val className = Utils.getFormattedClassName(this)
5657

5758
def getBasePath: String = basePath

core/src/main/scala/org/apache/spark/util/Utils.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2079,9 +2079,9 @@ private[spark] object Utils extends Logging {
20792079
case e: Exception if isBindCollision(e) =>
20802080
if (offset >= maxRetries) {
20812081
val exceptionMessage = s"${e.getMessage}: Service$serviceString failed after " +
2082-
s"$maxRetries retries! Consider explicitly setting the appropriate port for the " +
2083-
s"service$serviceString (for example spark.ui.port for SparkUI) to an available " +
2084-
"port or increasing spark.port.maxRetries."
2082+
s"$maxRetries retries (starting from $startPort)! Consider explicitly setting " +
2083+
s"the appropriate port for the service$serviceString (for example spark.ui.port " +
2084+
s"for SparkUI) to an available port or increasing spark.port.maxRetries."
20852085
val exception = new BindException(exceptionMessage)
20862086
// restore original stack trace
20872087
exception.setStackTrace(e.getStackTrace)

core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,13 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi
108108
when(blockManager.getBlockData(blockId)).thenReturn(blockBuffer)
109109

110110
val securityManager0 = new SecurityManager(conf0)
111-
val exec0 = new NettyBlockTransferService(conf0, securityManager0, "localhost", numCores = 1)
111+
val exec0 = new NettyBlockTransferService(conf0, securityManager0, "localhost", "localhost", 0,
112+
1)
112113
exec0.init(blockManager)
113114

114115
val securityManager1 = new SecurityManager(conf1)
115-
val exec1 = new NettyBlockTransferService(conf1, securityManager1, "localhost", numCores = 1)
116+
val exec1 = new NettyBlockTransferService(conf1, securityManager1, "localhost", "localhost", 0,
117+
1)
116118
exec1.init(blockManager)
117119

118120
val result = fetchBlock(exec0, exec1, "1", blockId) match {

core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import org.mockito.Mockito.mock
2323
import org.scalatest._
2424

2525
import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
26+
import org.apache.spark.internal.config._
2627
import org.apache.spark.network.BlockDataManager
2728

2829
class NettyBlockTransferServiceSuite
@@ -86,10 +87,10 @@ class NettyBlockTransferServiceSuite
8687
private def createService(port: Int): NettyBlockTransferService = {
8788
val conf = new SparkConf()
8889
.set("spark.app.id", s"test-${getClass.getName}")
89-
.set("spark.blockManager.port", port.toString)
9090
val securityManager = new SecurityManager(conf)
9191
val blockDataManager = mock(classOf[BlockDataManager])
92-
val service = new NettyBlockTransferService(conf, securityManager, "localhost", numCores = 1)
92+
val service = new NettyBlockTransferService(conf, securityManager, "localhost", "localhost",
93+
port, 1)
9394
service.init(blockDataManager)
9495
service
9596
}

0 commit comments

Comments
 (0)