diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index a7368f9f3dfb..352ad2a7980e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -39,7 +39,7 @@ private[deploy] object DeployMessages {
       port: Int,
       cores: Int,
       memory: Int,
-      webUiPort: Int,
+      workerWebUiUrl: String,
       publicAddress: String)
     extends DeployMessage {
     Utils.checkHost(host, "Required hostname")
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index a70ecdb37537..7574e1665c1c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -122,7 +122,8 @@ private[spark] class Master(
     // Listen for remote client disconnection events, since they don't go through Akka's watch()
     context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
     webUi.bind()
-    masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort
+    val masterWebUiUrlPrefix = conf.get("spark.http.policy", "http") + "://"
+    masterWebUiUrl = masterWebUiUrlPrefix + masterPublicAddress + ":" + webUi.boundPort
     context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut)
 
     masterMetricsSystem.registerSource(masterSource)
@@ -190,7 +191,7 @@ private[spark] class Master(
       System.exit(0)
     }
 
-    case RegisterWorker(id, workerHost, workerPort, cores, memory, workerUiPort, publicAddress) =>
+    case RegisterWorker(id, workerHost, workerPort, cores, memory, workerWebUiUrl, publicAddress) =>
     {
       logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
         workerHost, workerPort, cores, Utils.megabytesToString(memory)))
@@ -200,7 +201,7 @@ private[spark] class Master(
         sender ! RegisterWorkerFailed("Duplicate worker ID")
       } else {
         val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory,
-          sender, workerUiPort, publicAddress)
+          sender, workerWebUiUrl, publicAddress)
         if (registerWorker(worker)) {
           persistenceEngine.addWorker(worker)
           sender ! RegisteredWorker(masterUrl, masterWebUiUrl)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index c5fa9cf7d7c2..0335894bae51 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -30,7 +30,7 @@ private[spark] class WorkerInfo(
     val cores: Int,
     val memory: Int,
     val actor: ActorRef,
-    val webUiPort: Int,
+    val webUiAddress: String,
     val publicAddress: String)
   extends Serializable {
 
@@ -99,10 +99,6 @@ private[spark] class WorkerInfo(
     coresUsed -= driver.desc.cores
   }
 
-  def webUiAddress : String = {
-    "http://" + this.publicAddress + ":" + this.webUiPort
-  }
-
   def setState(state: WorkerState.Value) = {
     this.state = state
   }
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index fb5252da9651..dd97016e3e32 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -78,6 +78,7 @@ private[spark] class Worker(
   var activeMasterUrl: String = ""
   var activeMasterWebUiUrl : String = ""
   val akkaUrl = "akka.tcp://%s@%s:%s/user/%s".format(actorSystemName, host, port, actorName)
+  var workerWebUiUrl: String = _
   @volatile var registered = false
   @volatile var connected = false
   val workerId = generateWorkerId()
@@ -130,8 +131,10 @@ private[spark] class Worker(
     logInfo("Spark home: " + sparkHome)
     createWorkDir()
     context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
-    webUi = new WorkerWebUI(this, workDir, Some(webUiPort))
+    webUi = new WorkerWebUI(this, workDir, webUiPort)
     webUi.bind()
+    workerWebUiUrl =
+      conf.get("spark.http.policy", "http") + "://" + publicAddress + ":" + webUi.boundPort
     registerWithMaster()
 
     metricsSystem.registerSource(workerSource)
@@ -157,7 +160,7 @@ private[spark] class Worker(
     for (masterUrl <- masterUrls) {
       logInfo("Connecting to master " + masterUrl + "...")
       val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
-      actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress)
+      actor ! RegisterWorker(workerId, host, port, cores, memory, workerWebUiUrl, publicAddress)
     }
   }
 
@@ -369,7 +372,8 @@ private[spark] class Worker(
 private[spark] object Worker extends Logging {
   def main(argStrings: Array[String]) {
     SignalLogger.register(log)
-    val args = new WorkerArguments(argStrings)
+    val conf = new SparkConf
+    val args = new WorkerArguments(argStrings, conf)
     val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
       args.memory, args.masters, args.workDir)
     actorSystem.awaitTermination()
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
index dc5158102054..809c4c650178 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala
@@ -20,11 +20,12 @@ package org.apache.spark.deploy.worker
 import java.lang.management.ManagementFactory
 
+import org.apache.spark.SparkConf
 import org.apache.spark.util.{IntParam, MemoryParam, Utils}
 
 /**
  * Command-line parser for the worker.
  */
-private[spark] class WorkerArguments(args: Array[String]) {
+private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
   var host = Utils.localHostName()
   var port = 0
   var webUiPort = 8081
@@ -49,6 +50,8 @@ private[spark] class WorkerArguments(args: Array[String]) {
   if (System.getenv("SPARK_WORKER_DIR") != null) {
     workDir = System.getenv("SPARK_WORKER_DIR")
   }
+  // "worker.ui.port" from SparkConf overrides the default; command-line args still win below.
+  webUiPort = conf.getInt("worker.ui.port", webUiPort)
 
   parse(args.toList)
 
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index 0ad2edba2227..30e25a37ed1e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -33,8 +33,8 @@ private[spark]
 class WorkerWebUI(
     val worker: Worker,
     val workDir: File,
-    port: Option[Int] = None)
-  extends WebUI(worker.securityMgr, WorkerWebUI.getUIPort(port, worker.conf), worker.conf)
+    port: Int)
+  extends WebUI(worker.securityMgr, port, worker.conf)
   with Logging {
 
   val timeout = AkkaUtils.askTimeout(worker.conf)
@@ -54,10 +54,5 @@ class WorkerWebUI(
 }
 
 private[spark] object WorkerWebUI {
-  val DEFAULT_PORT = 8081
   val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR
-
-  def getUIPort(requestedPort: Option[Int], conf: SparkConf): Int = {
-    requestedPort.getOrElse(conf.getInt("worker.ui.port", WorkerWebUI.DEFAULT_PORT))
-  }
 }
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index a2535e3c1c41..bc57fbd9d8cf 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -26,7 +26,9 @@ import scala.language.implicitConversions
 import scala.util.{Failure, Success, Try}
 import scala.xml.Node
 
-import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.server.{Connector, Server}
 import org.eclipse.jetty.server.handler._
+import org.eclipse.jetty.server.nio.SelectChannelConnector
+import org.eclipse.jetty.server.ssl.SslSelectChannelConnector
 import org.eclipse.jetty.servlet._
 import org.eclipse.jetty.util.thread.QueuedThreadPool
@@ -182,7 +184,8 @@ private[spark] object JettyUtils extends Logging {
 
     @tailrec
     def connect(currentPort: Int): (Server, Int) = {
-      val server = new Server(new InetSocketAddress(hostName, currentPort))
+      val server = new Server
+      server.addConnector(getConnector(currentPort, conf))
       val pool = new QueuedThreadPool
       pool.setDaemon(true)
       server.setThreadPool(pool)
@@ -215,6 +218,46 @@ private[spark] object JettyUtils extends Logging {
   private def attachPrefix(basePath: String, relativePath: String): String = {
     if (basePath == "") relativePath else (basePath + relativePath).stripSuffix("/")
   }
+
+  /** Returns a plain or SSL connector for `port`, depending on "spark.http.policy". */
+  private def getConnector(port: Int, conf: SparkConf): Connector = {
+    if (getHttpPolicy(conf)) {
+      buildSslSelectChannelConnector(port, conf)
+    } else {
+      val connector = new SelectChannelConnector
+      connector.setPort(port)
+      connector
+    }
+  }
+
+  /** Builds an SSL connector configured from the "spark.ssl.server.*" settings. */
+  private def buildSslSelectChannelConnector(port: Int, conf: SparkConf): Connector = {
+    val connector = new SslSelectChannelConnector
+    connector.setPort(port)
+
+    val context = connector.getSslContextFactory
+    val needAuth = conf.getBoolean("spark.client.https.need-auth", false)
+    context.setNeedClientAuth(needAuth)
+    if (conf.contains("spark.ssl.server.keystore.location")) {
+      context.setKeyStorePath(conf.get("spark.ssl.server.keystore.location"))
+      context.setKeyStorePassword(conf.get("spark.ssl.server.keystore.password"))
+      // The key password defaults to the keystore password when not set separately.
+      context.setKeyManagerPassword(conf.get("spark.ssl.server.keystore.keypassword",
+        conf.get("spark.ssl.server.keystore.password")))
+      context.setKeyStoreType(conf.get("spark.ssl.server.keystore.type", "jks"))
+    }
+    if (needAuth && conf.contains("spark.ssl.server.truststore.location")) {
+      context.setTrustStore(conf.get("spark.ssl.server.truststore.location"))
+      context.setTrustStorePassword(conf.get("spark.ssl.server.truststore.password"))
+      context.setTrustStoreType(conf.get("spark.ssl.server.truststore.type", "jks"))
+    }
+    connector
+  }
+
+  /** True iff "spark.http.policy" is "https"; unset or any other value means plain HTTP. */
+  def getHttpPolicy(conf: SparkConf): Boolean = {
+    conf.get("spark.http.policy", "http") == "https"
+  }
 }
 
 private[spark] case class ServerInfo(
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index 097a1b81e1dd..727ad893e27e 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -97,7 +97,9 @@ private[spark] class SparkUI(
    */
   private[spark] def appUIHostPort = publicHostName + ":" + boundPort
 
-  private[spark] def appUIAddress = s"http://$appUIHostPort"
+  private def appUiAddressPrefix = conf.get("spark.http.policy", "http")
+
+  private[spark] def appUIAddress = s"$appUiAddressPrefix://$appUIHostPort"
 }
 
 private[spark] object SparkUI {
diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
index 856273e1d4e2..3a4d2563279f 100644
--- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
@@ -98,7 +98,8 @@ private[spark] abstract class WebUI(
     assert(!serverInfo.isDefined, "Attempted to bind %s more than once!".format(className))
     try {
       serverInfo = Some(startJettyServer("0.0.0.0", port, handlers, conf))
-      logInfo("Started %s at http://%s:%d".format(className, publicHostName, boundPort))
+      val scheme = if (getHttpPolicy(conf)) "https" else "http"
+      logInfo("Started %s at %s://%s:%d".format(className, scheme, publicHostName, boundPort))
     } catch {
       case e: Exception =>
         logError("Failed to bind %s".format(className), e)
diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
index 093394ad6d14..39bf5e052aa5 100644
--- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
@@ -111,7 +111,8 @@ class JsonProtocolSuite extends FunSuite {
     createDriverDesc(), new Date())
 
   def createWorkerInfo(): WorkerInfo = {
-    val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress")
+    val workerInfo = new WorkerInfo(
+      "id", "host", 8080, 4, 1234, null, "http://host:8080", "publicAddress")
     workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis
     workerInfo
   }