Skip to content

Commit af9d9cb

Browse files
author
Andrew Or
committed
Integrate REST protocol in standalone mode
This commit embeds the REST server in the standalone Master and routes applications submitted via spark-submit through the REST client. This is the first working end-to-end implementation of a stable submission interface in standalone cluster mode.
1 parent 53e7c0e commit af9d9cb

11 files changed

+297
-95
lines changed

core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ import org.apache.spark.util.MemoryParam
2929
* Command-line parser for the driver client.
3030
*/
3131
private[spark] class ClientArguments(args: Array[String]) {
32-
val defaultCores = 1
33-
val defaultMemory = 512
32+
import ClientArguments._
3433

3534
var cmd: String = "" // 'launch' or 'kill'
3635
var logLevel = Level.WARN
@@ -39,9 +38,9 @@ private[spark] class ClientArguments(args: Array[String]) {
3938
var master: String = ""
4039
var jarUrl: String = ""
4140
var mainClass: String = ""
42-
var supervise: Boolean = false
43-
var memory: Int = defaultMemory
44-
var cores: Int = defaultCores
41+
var supervise: Boolean = DEFAULT_SUPERVISE
42+
var memory: Int = DEFAULT_MEMORY
43+
var cores: Int = DEFAULT_CORES
4544
private var _driverOptions = ListBuffer[String]()
4645
def driverOptions = _driverOptions.toSeq
4746

@@ -106,9 +105,10 @@ private[spark] class ClientArguments(args: Array[String]) {
106105
|Usage: DriverClient kill <active-master> <driver-id>
107106
|
108107
|Options:
109-
| -c CORES, --cores CORES Number of cores to request (default: $defaultCores)
110-
| -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $defaultMemory)
108+
| -c CORES, --cores CORES Number of cores to request (default: $DEFAULT_CORES)
109+
| -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $DEFAULT_MEMORY)
111110
| -s, --supervise Whether to restart the driver on failure
111+
| (default: $DEFAULT_SUPERVISE)
112112
| -v, --verbose Print more debugging output
113113
""".stripMargin
114114
System.err.println(usage)
@@ -117,6 +117,10 @@ private[spark] class ClientArguments(args: Array[String]) {
117117
}
118118

119119
object ClientArguments {
120+
private[spark] val DEFAULT_CORES = 1
121+
private[spark] val DEFAULT_MEMORY = 512 // MB
122+
private[spark] val DEFAULT_SUPERVISE = false
123+
120124
def isValidJarUrl(s: String): Boolean = {
121125
try {
122126
val uri = new URI(s)

core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
2525

2626
import org.apache.spark.executor.ExecutorURLClassLoader
2727
import org.apache.spark.util.Utils
28+
import org.apache.spark.deploy.rest.StandaloneRestClient
2829

2930
/**
3031
* Main gateway of launching a Spark application.
@@ -72,6 +73,16 @@ object SparkSubmit {
7273
if (appArgs.verbose) {
7374
printStream.println(appArgs)
7475
}
76+
77+
// In standalone cluster mode, use the brand new REST client to submit the application
78+
val doingRest = appArgs.master.startsWith("spark://") && appArgs.deployMode == "cluster"
79+
if (doingRest) {
80+
println("Submitting driver through the REST interface.")
81+
new StandaloneRestClient().submitDriver(appArgs)
82+
println("Done submitting driver.")
83+
return
84+
}
85+
7586
val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
7687
launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose)
7788
}

core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,15 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
104104
.orElse(sparkProperties.get("spark.master"))
105105
.orElse(env.get("MASTER"))
106106
.orNull
107+
driverExtraClassPath = Option(driverExtraClassPath)
108+
.orElse(sparkProperties.get("spark.driver.extraClassPath"))
109+
.orNull
110+
driverExtraJavaOptions = Option(driverExtraJavaOptions)
111+
.orElse(sparkProperties.get("spark.driver.extraJavaOptions"))
112+
.orNull
113+
driverExtraLibraryPath = Option(driverExtraLibraryPath)
114+
.orElse(sparkProperties.get("spark.driver.extraLibraryPath"))
115+
.orNull
107116
driverMemory = Option(driverMemory)
108117
.orElse(sparkProperties.get("spark.driver.memory"))
109118
.orElse(env.get("SPARK_DRIVER_MEMORY"))

core/src/main/scala/org/apache/spark/deploy/master/Master.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ import org.apache.spark.deploy.history.HistoryServer
4343
import org.apache.spark.deploy.master.DriverState.DriverState
4444
import org.apache.spark.deploy.master.MasterMessages._
4545
import org.apache.spark.deploy.master.ui.MasterWebUI
46+
import org.apache.spark.deploy.rest.StandaloneRestServer
4647
import org.apache.spark.metrics.MetricsSystem
4748
import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus}
4849
import org.apache.spark.ui.SparkUI
@@ -121,6 +122,8 @@ private[spark] class Master(
121122
throw new SparkException("spark.deploy.defaultCores must be positive")
122123
}
123124

125+
val restServer = new StandaloneRestServer(this, host, 6677)
126+
124127
override def preStart() {
125128
logInfo("Starting Spark master at " + masterUrl)
126129
logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}")

core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@ private[spark] object KillDriverResponseField extends StandaloneRestProtocolFiel
2727
case object MESSAGE extends KillDriverResponseField
2828
case object MASTER extends KillDriverResponseField
2929
case object DRIVER_ID extends KillDriverResponseField
30-
case object DRIVER_STATE extends SubmitDriverResponseField
31-
override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, DRIVER_ID, DRIVER_STATE)
30+
case object SUCCESS extends SubmitDriverResponseField
31+
override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, DRIVER_ID, SUCCESS)
3232
override val optionalFields = Seq.empty
3333
}
3434

core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import com.google.common.base.Charsets
2727

2828
import org.apache.spark.{SPARK_VERSION => sparkVersion}
2929
import org.apache.spark.deploy.SparkSubmitArguments
30+
import org.apache.spark.util.Utils
3031

3132
/**
3233
* A client that submits Spark applications using a stable REST protocol in standalone
@@ -63,6 +64,12 @@ private[spark] class StandaloneRestClient {
6364
*/
6465
private def constructSubmitRequest(args: SparkSubmitArguments): SubmitDriverRequestMessage = {
6566
import SubmitDriverRequestField._
67+
val driverMemory = Option(args.driverMemory)
68+
.map { m => Utils.memoryStringToMb(m).toString }
69+
.orNull
70+
val executorMemory = Option(args.executorMemory)
71+
.map { m => Utils.memoryStringToMb(m).toString }
72+
.orNull
6673
val message = new SubmitDriverRequestMessage()
6774
.setField(SPARK_VERSION, sparkVersion)
6875
.setField(MASTER, args.master)
@@ -72,19 +79,21 @@ private[spark] class StandaloneRestClient {
7279
.setFieldIfNotNull(JARS, args.jars)
7380
.setFieldIfNotNull(FILES, args.files)
7481
.setFieldIfNotNull(PY_FILES, args.pyFiles)
75-
.setFieldIfNotNull(DRIVER_MEMORY, args.driverMemory)
82+
.setFieldIfNotNull(DRIVER_MEMORY, driverMemory)
7683
.setFieldIfNotNull(DRIVER_CORES, args.driverCores)
7784
.setFieldIfNotNull(DRIVER_EXTRA_JAVA_OPTIONS, args.driverExtraJavaOptions)
7885
.setFieldIfNotNull(DRIVER_EXTRA_CLASS_PATH, args.driverExtraClassPath)
7986
.setFieldIfNotNull(DRIVER_EXTRA_LIBRARY_PATH, args.driverExtraLibraryPath)
8087
.setFieldIfNotNull(SUPERVISE_DRIVER, args.supervise.toString)
81-
.setFieldIfNotNull(EXECUTOR_MEMORY, args.executorMemory)
88+
.setFieldIfNotNull(EXECUTOR_MEMORY, executorMemory)
8289
.setFieldIfNotNull(TOTAL_EXECUTOR_CORES, args.totalExecutorCores)
83-
// Set each Spark property as its own field
84-
// TODO: Include environment variables?
90+
args.childArgs.zipWithIndex.foreach { case (arg, i) =>
91+
message.setFieldIfNotNull(APP_ARG(i), arg)
92+
}
8593
args.sparkProperties.foreach { case (k, v) =>
8694
message.setFieldIfNotNull(SPARK_PROPERTY(k), v)
8795
}
96+
// TODO: set environment variables?
8897
message.validate()
8998
}
9099

@@ -175,8 +184,8 @@ private[spark] class StandaloneRestClient {
175184
object StandaloneRestClient {
176185
def main(args: Array[String]): Unit = {
177186
assert(args.length > 0)
178-
val client = new StandaloneRestClient
179-
client.killDriver("spark://" + args(0), "abc_driver")
187+
//val client = new StandaloneRestClient
188+
//client.submitDriver("spark://" + args(0))
180189
println("Done.")
181190
}
182191
}

core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,19 +63,28 @@ private[spark] abstract class StandaloneRestProtocolMessage(
6363

6464
import StandaloneRestProtocolField._
6565

66-
private val fields = new mutable.HashMap[StandaloneRestProtocolField, String]
6766
private val className = Utils.getFormattedClassName(this)
67+
protected val fields = new mutable.HashMap[StandaloneRestProtocolField, String]
6868

6969
// Set the action field
7070
fields(actionField) = action.toString
7171

72+
/** Return all fields currently set in this message. */
73+
def getFields: Map[StandaloneRestProtocolField, String] = fields
74+
75+
/** Return the value of the given field. If the field is not present, return null. */
76+
def getField(key: StandaloneRestProtocolField): String = getFieldOption(key).orNull
77+
7278
/** Return the value of the given field. If the field is not present, throw an exception. */
73-
def getField(key: StandaloneRestProtocolField): String = {
74-
fields.get(key).getOrElse {
79+
def getFieldNotNull(key: StandaloneRestProtocolField): String = {
80+
getFieldOption(key).getOrElse {
7581
throw new IllegalArgumentException(s"Field $key is not set in message $className")
7682
}
7783
}
7884

85+
/** Return the value of the given field as an option. */
86+
def getFieldOption(key: StandaloneRestProtocolField): Option[String] = fields.get(key)
87+
7988
/** Assign the given value to the field, overriding any existing value. */
8089
def setField(key: StandaloneRestProtocolField, value: String): this.type = {
8190
if (key == actionField) {

core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala

Lines changed: 43 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.deploy.rest
1919

2020
import java.io.DataOutputStream
21+
import java.net.InetSocketAddress
2122
import javax.servlet.http.{HttpServletRequest, HttpServletResponse}
2223

2324
import scala.io.Source
@@ -26,25 +27,37 @@ import com.google.common.base.Charsets
2627
import org.eclipse.jetty.server.{Request, Server}
2728
import org.eclipse.jetty.server.handler.AbstractHandler
2829

29-
import org.apache.spark.{Logging, SPARK_VERSION => sparkVersion}
30-
import org.apache.spark.deploy.rest.StandaloneRestProtocolAction._
31-
import org.apache.spark.util.Utils
30+
import org.apache.spark.{SPARK_VERSION => sparkVersion, Logging}
31+
import org.apache.spark.deploy.master.Master
32+
import org.apache.spark.util.{AkkaUtils, Utils}
3233

3334
/**
3435
* A server that responds to requests submitted by the StandaloneRestClient.
3536
*/
36-
private[spark] class StandaloneRestServer(requestedPort: Int) {
37-
val server = new Server(requestedPort)
38-
server.setHandler(new StandaloneRestHandler)
37+
private[spark] class StandaloneRestServer(master: Master, host: String, requestedPort: Int) {
38+
val server = new Server(new InetSocketAddress(host, requestedPort))
39+
server.setHandler(new StandaloneRestServerHandler(master))
3940
server.start()
40-
server.join()
4141
}
4242

4343
/**
4444
* A Jetty handler that responds to requests submitted via the standalone REST protocol.
4545
*/
46-
private[spark] class StandaloneRestHandler extends AbstractHandler with Logging {
46+
private[spark] abstract class StandaloneRestHandler(master: Master)
47+
extends AbstractHandler with Logging {
4748

49+
private implicit val askTimeout = AkkaUtils.askTimeout(master.conf)
50+
51+
/** Handle a request to submit a driver. */
52+
protected def handleSubmit(request: SubmitDriverRequestMessage): SubmitDriverResponseMessage
53+
/** Handle a request to kill a driver. */
54+
protected def handleKill(request: KillDriverRequestMessage): KillDriverResponseMessage
55+
/** Handle a request for a driver's status. */
56+
protected def handleStatus(request: DriverStatusRequestMessage): DriverStatusResponseMessage
57+
58+
/**
59+
* Handle a request submitted by the StandaloneRestClient.
60+
*/
4861
override def handle(
4962
target: String,
5063
baseRequest: Request,
@@ -67,74 +80,32 @@ private[spark] class StandaloneRestHandler extends AbstractHandler with Logging
6780
}
6881
}
6982

83+
/**
84+
* Construct the appropriate response message based on the type of the request message.
85+
* If an IllegalArgumentException is thrown in the process, construct an error message.
86+
*/
7087
private def constructResponseMessage(
7188
request: StandaloneRestProtocolMessage): StandaloneRestProtocolMessage = {
7289
// If the request is sent via the StandaloneRestClient, it should have already been
7390
// validated remotely. In case this is not true, validate the request here to guard
7491
// against potential NPEs. If validation fails, return an ERROR message to the sender.
7592
try {
7693
request.validate()
94+
request match {
95+
case submit: SubmitDriverRequestMessage => handleSubmit(submit)
96+
case kill: KillDriverRequestMessage => handleKill(kill)
97+
case status: DriverStatusRequestMessage => handleStatus(status)
98+
case unexpected => handleError(
99+
s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.")
100+
}
77101
} catch {
78-
case e: IllegalArgumentException =>
79-
return handleError(e.getMessage)
80-
}
81-
request match {
82-
case submit: SubmitDriverRequestMessage => handleSubmitRequest(submit)
83-
case kill: KillDriverRequestMessage => handleKillRequest(kill)
84-
case status: DriverStatusRequestMessage => handleStatusRequest(status)
85-
case unexpected => handleError(
86-
s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.")
102+
// Propagate exception to user in an ErrorMessage. If the construction of the
103+
// ErrorMessage itself throws an exception, log the exception and ignore the request.
104+
case e: IllegalArgumentException => handleError(e.getMessage)
87105
}
88106
}
89107

90-
private def handleSubmitRequest(
91-
request: SubmitDriverRequestMessage): SubmitDriverResponseMessage = {
92-
import SubmitDriverResponseField._
93-
// TODO: Actually submit the driver
94-
val message = "Driver is submitted successfully..."
95-
val master = request.getField(SubmitDriverRequestField.MASTER)
96-
val driverId = "new_driver_id"
97-
val driverState = "SUBMITTED"
98-
new SubmitDriverResponseMessage()
99-
.setField(SPARK_VERSION, sparkVersion)
100-
.setField(MESSAGE, message)
101-
.setField(MASTER, master)
102-
.setField(DRIVER_ID, driverId)
103-
.setField(DRIVER_STATE, driverState)
104-
.validate()
105-
}
106-
107-
private def handleKillRequest(request: KillDriverRequestMessage): KillDriverResponseMessage = {
108-
import KillDriverResponseField._
109-
// TODO: Actually kill the driver
110-
val message = "Driver is killed successfully..."
111-
val master = request.getField(KillDriverRequestField.MASTER)
112-
val driverId = request.getField(KillDriverRequestField.DRIVER_ID)
113-
val driverState = "KILLED"
114-
new KillDriverResponseMessage()
115-
.setField(SPARK_VERSION, sparkVersion)
116-
.setField(MESSAGE, message)
117-
.setField(MASTER, master)
118-
.setField(DRIVER_ID, driverId)
119-
.setField(DRIVER_STATE, driverState)
120-
.validate()
121-
}
122-
123-
private def handleStatusRequest(
124-
request: DriverStatusRequestMessage): DriverStatusResponseMessage = {
125-
import DriverStatusResponseField._
126-
// TODO: Actually look up the status of the driver
127-
val master = request.getField(DriverStatusRequestField.MASTER)
128-
val driverId = request.getField(DriverStatusRequestField.DRIVER_ID)
129-
val driverState = "HEALTHY"
130-
new DriverStatusResponseMessage()
131-
.setField(SPARK_VERSION, sparkVersion)
132-
.setField(MASTER, master)
133-
.setField(DRIVER_ID, driverId)
134-
.setField(DRIVER_STATE, driverState)
135-
.validate()
136-
}
137-
108+
/** Construct an error message to signal the fact that an exception has been thrown. */
138109
private def handleError(message: String): ErrorMessage = {
139110
import ErrorField._
140111
new ErrorMessage()
@@ -144,10 +115,10 @@ private[spark] class StandaloneRestHandler extends AbstractHandler with Logging
144115
}
145116
}
146117

147-
object StandaloneRestServer {
148-
def main(args: Array[String]): Unit = {
149-
println("Hey boy I'm starting a server.")
150-
new StandaloneRestServer(6677)
151-
readLine()
152-
}
153-
}
118+
//object StandaloneRestServer {
119+
// def main(args: Array[String]): Unit = {
120+
// println("Hey boy I'm starting a server.")
121+
// new StandaloneRestServer(6677)
122+
// readLine()
123+
// }
124+
//}

0 commit comments

Comments (0)