Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ private[deploy] object DeployMessages {
port: Int,
cores: Int,
memory: Int,
webUiPort: Int,
workerWebUiUrl: String,
publicAddress: String)
extends DeployMessage {
Utils.checkHost(host, "Required hostname")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ private[spark] class Master(
// Listen for remote client disconnection events, since they don't go through Akka's watch()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
webUi.bind()
masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort
val masterWebUiUrlPrefix = conf.get("spark.http.policy") + "://"
masterWebUiUrl = masterWebUiUrlPrefix + masterPublicAddress + ":" + webUi.boundPort
context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut)

masterMetricsSystem.registerSource(masterSource)
Expand Down Expand Up @@ -190,7 +191,7 @@ private[spark] class Master(
System.exit(0)
}

case RegisterWorker(id, workerHost, workerPort, cores, memory, workerUiPort, publicAddress) =>
case RegisterWorker(id, workerHost, workerPort, cores, memory, workerWebUiUrl, publicAddress) =>
{
logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
workerHost, workerPort, cores, Utils.megabytesToString(memory)))
Expand All @@ -200,7 +201,7 @@ private[spark] class Master(
sender ! RegisterWorkerFailed("Duplicate worker ID")
} else {
val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory,
sender, workerUiPort, publicAddress)
sender, workerWebUiUrl, publicAddress)
if (registerWorker(worker)) {
persistenceEngine.addWorker(worker)
sender ! RegisteredWorker(masterUrl, masterWebUiUrl)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ private[spark] class WorkerInfo(
val cores: Int,
val memory: Int,
val actor: ActorRef,
val webUiPort: Int,
val webUiAddress: String,
val publicAddress: String)
extends Serializable {

Expand Down Expand Up @@ -99,10 +99,6 @@ private[spark] class WorkerInfo(
coresUsed -= driver.desc.cores
}

def webUiAddress : String = {
"http://" + this.publicAddress + ":" + this.webUiPort
}

def setState(state: WorkerState.Value) = {
this.state = state
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ private[spark] class Worker(
var activeMasterUrl: String = ""
var activeMasterWebUiUrl : String = ""
val akkaUrl = "akka.tcp://%s@%s:%s/user/%s".format(actorSystemName, host, port, actorName)
var workerWebUiUrl: String = _
@volatile var registered = false
@volatile var connected = false
val workerId = generateWorkerId()
Expand Down Expand Up @@ -130,8 +131,9 @@ private[spark] class Worker(
logInfo("Spark home: " + sparkHome)
createWorkDir()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
webUi = new WorkerWebUI(this, workDir, Some(webUiPort))
webUi = new WorkerWebUI(this, workDir, webUiPort)
webUi.bind()
workerWebUiUrl = conf.get("spark.http.policy") + "://" + publicAddress + ":" + webUi.boundPort
registerWithMaster()

metricsSystem.registerSource(workerSource)
Expand All @@ -157,7 +159,7 @@ private[spark] class Worker(
for (masterUrl <- masterUrls) {
logInfo("Connecting to master " + masterUrl + "...")
val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress)
actor ! RegisterWorker(workerId, host, port, cores, memory, workerWebUiUrl, publicAddress)
}
}

Expand Down Expand Up @@ -369,7 +371,8 @@ private[spark] class Worker(
private[spark] object Worker extends Logging {
def main(argStrings: Array[String]) {
SignalLogger.register(log)
val args = new WorkerArguments(argStrings)
val conf = new SparkConf
val args = new WorkerArguments(argStrings, conf)
val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
args.memory, args.masters, args.workDir)
actorSystem.awaitTermination()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ package org.apache.spark.deploy.worker
import java.lang.management.ManagementFactory

import org.apache.spark.util.{IntParam, MemoryParam, Utils}
import org.apache.spark.SparkConf

/**
* Command-line parser for the worker.
*/
private[spark] class WorkerArguments(args: Array[String]) {
private[spark] class WorkerArguments(args: Array[String], conf: SparkConf) {
var host = Utils.localHostName()
var port = 0
var webUiPort = 8081
Expand All @@ -49,6 +50,9 @@ private[spark] class WorkerArguments(args: Array[String]) {
if (System.getenv("SPARK_WORKER_DIR") != null) {
workDir = System.getenv("SPARK_WORKER_DIR")
}
if (conf.contains("worker.ui.port")) {
webUiPort = conf.get("worker.ui.port").toInt
}

parse(args.toList)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ private[spark]
class WorkerWebUI(
val worker: Worker,
val workDir: File,
port: Option[Int] = None)
extends WebUI(worker.securityMgr, WorkerWebUI.getUIPort(port, worker.conf), worker.conf)
port: Int)
extends WebUI(worker.securityMgr, port, worker.conf)
with Logging {

val timeout = AkkaUtils.askTimeout(worker.conf)
Expand All @@ -54,10 +54,5 @@ class WorkerWebUI(
}

private[spark] object WorkerWebUI {
val DEFAULT_PORT = 8081
val STATIC_RESOURCE_BASE = SparkUI.STATIC_RESOURCE_DIR

def getUIPort(requestedPort: Option[Int], conf: SparkConf): Int = {
requestedPort.getOrElse(conf.getInt("worker.ui.port", WorkerWebUI.DEFAULT_PORT))
}
}
49 changes: 47 additions & 2 deletions core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import scala.language.implicitConversions
import scala.util.{Failure, Success, Try}
import scala.xml.Node

import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.{Connector, Server}
import org.eclipse.jetty.server.handler._
import org.eclipse.jetty.servlet._
import org.eclipse.jetty.util.thread.QueuedThreadPool
Expand All @@ -35,6 +35,8 @@ import org.json4s.jackson.JsonMethods.{pretty, render}

import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.util.Utils
import org.eclipse.jetty.server.nio.SelectChannelConnector
import org.eclipse.jetty.server.ssl.SslSelectChannelConnector

/**
* Utilities for launching a web server using Jetty's HTTP Server class
Expand Down Expand Up @@ -182,7 +184,8 @@ private[spark] object JettyUtils extends Logging {

@tailrec
def connect(currentPort: Int): (Server, Int) = {
val server = new Server(new InetSocketAddress(hostName, currentPort))
val server = new Server
server.addConnector(getConnector(currentPort, conf))
val pool = new QueuedThreadPool
pool.setDaemon(true)
server.setThreadPool(pool)
Expand Down Expand Up @@ -215,6 +218,48 @@ private[spark] object JettyUtils extends Logging {
private def attachPrefix(basePath: String, relativePath: String): String = {
if (basePath == "") relativePath else (basePath + relativePath).stripSuffix("/")
}

private def getConnector(port: Int, conf: SparkConf): Connector = {
val https = getHttpPolicy(conf)
if (https) {
buildSslSelectChannelConnector(port, conf)
} else {
conf.set("spark.http.policy", "http")
val connector = new SelectChannelConnector
connector.setPort(port)
connector
}
}

private def buildSslSelectChannelConnector(port: Int, conf: SparkConf): Connector =
{
val connector = new SslSelectChannelConnector
connector.setPort(port)

val context = connector.getSslContextFactory
val needAuth = conf.getBoolean("spark.client.https.need-auth", false)
context.setNeedClientAuth(needAuth)
context.setKeyManagerPassword(conf.get("spark.ssl.server.keystore.keypassword"))
if (conf.contains("spark.ssl.server.keystore.location")) {
context.setKeyStorePath(conf.get("spark.ssl.server.keystore.location"))
context.setKeyStorePassword(conf.get("spark.ssl.server.keystore.password"))
context.setKeyStoreType(conf.get("spark.ssl.server.keystore.type", "jks"))
}
if (needAuth && conf.contains("spark.ssl.server.truststore.location")) {
context.setTrustStore(conf.get("spark.ssl.server.truststore.location"))
context.setTrustStorePassword(conf.get("spark.ssl.server.truststore.password"))
context.setTrustStoreType(conf.get("spark.ssl.server.truststore.type", "jks"))
}
connector
}

def getHttpPolicy(conf: SparkConf): Boolean = {
if (conf.contains("spark.http.policy") && conf.get("spark.http.policy").equals("https")) {
true
} else {
false
}
}
}

private[spark] case class ServerInfo(
Expand Down
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/ui/SparkUI.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,9 @@ private[spark] class SparkUI(
*/
private[spark] def appUIHostPort = publicHostName + ":" + boundPort

private[spark] def appUIAddress = s"http://$appUIHostPort"
private def appUiAddressPrefix = conf.get("spark.http.policy")

private[spark] def appUIAddress = s"$appUiAddressPrefix://$appUIHostPort"
}

private[spark] object SparkUI {
Expand Down
6 changes: 5 additions & 1 deletion core/src/main/scala/org/apache/spark/ui/WebUI.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ private[spark] abstract class WebUI(
assert(!serverInfo.isDefined, "Attempted to bind %s more than once!".format(className))
try {
serverInfo = Some(startJettyServer("0.0.0.0", port, handlers, conf))
logInfo("Started %s at http://%s:%d".format(className, publicHostName, boundPort))
if (conf.get("spark.http.policy").equals("https")) {
logInfo("Started %s at https://%s:%d".format(className, publicHostName, boundPort))
} else {
logInfo("Started %s at http://%s:%d".format(className, publicHostName, boundPort))
}
} catch {
case e: Exception =>
logError("Failed to bind %s".format(className), e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ class JsonProtocolSuite extends FunSuite {
createDriverDesc(), new Date())

def createWorkerInfo(): WorkerInfo = {
val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, 80, "publicAddress")
val workerInfo = new WorkerInfo("id", "host", 8080, 4, 1234, null, "http://host:8080", "publicAddress")
workerInfo.lastHeartbeat = JsonConstants.currTimeInMillis
workerInfo
}
Expand Down