From 53e7c0e9a208e70c374793baf693d64fe4e6b8e1 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 14 Jan 2015 16:29:52 -0800 Subject: [PATCH 01/48] Initial client, server, and all the messages This commit introduces type-safe schemas for all messages exchanged in the REST protocol. Each message is expected to contain an ACTION field that has only one possible value for each message type. Before the message is sent, we validate that all required fields are in fact present, and that the value of the action field is the correct type. The next step is to actually integrate this in standalone mode. --- .../org/apache/spark/deploy/SparkSubmit.scala | 5 +- .../rest/DriverStatusRequestMessage.scala | 47 ++++ .../rest/DriverStatusResponseMessage.scala | 51 ++++ .../spark/deploy/rest/ErrorMessage.scala | 44 ++++ .../rest/KillDriverRequestMessage.scala | 47 ++++ .../rest/KillDriverResponseMessage.scala | 48 ++++ .../deploy/rest/StandaloneRestClient.scala | 182 +++++++++++++++ .../rest/StandaloneRestProtocolMessage.scala | 217 ++++++++++++++++++ .../deploy/rest/StandaloneRestServer.scala | 153 ++++++++++++ .../rest/SubmitDriverRequestMessage.scala | 70 ++++++ .../rest/SubmitDriverResponseMessage.scala | 48 ++++ 11 files changed, 910 insertions(+), 2 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 955cbd6dab96..1c89b2452e83 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -39,7 +39,8 @@ object SparkSubmit { private val STANDALONE = 2 private val MESOS = 4 private val LOCAL = 8 - private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL + private val REST = 16 + private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | REST // Deploy modes private val CLIENT = 1 @@ -97,7 +98,7 @@ object SparkSubmit { case m if m.startsWith("spark") => STANDALONE case m if m.startsWith("mesos") => MESOS case m if m.startsWith("local") => LOCAL - case _ => printErrorAndExit("Master must start with yarn, spark, mesos, or local"); -1 + case _ => printErrorAndExit("Master must start with yarn, spark, mesos, local, or rest"); -1 } // Set the deploy mode; default is client mode diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala new file mode 100644 index 000000000000..eed9f7388ff8 --- /dev/null +++ 
b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +/** + * A field used in a DriverStatusRequestMessage. + */ +private[spark] abstract class DriverStatusRequestField extends StandaloneRestProtocolField +private[spark] object DriverStatusRequestField extends StandaloneRestProtocolFieldCompanion { + case object ACTION extends DriverStatusRequestField + case object SPARK_VERSION extends DriverStatusRequestField + case object MESSAGE extends DriverStatusRequestField + case object MASTER extends DriverStatusRequestField + case object DRIVER_ID extends DriverStatusRequestField + override val requiredFields = Seq(ACTION, SPARK_VERSION, MASTER, DRIVER_ID) + override val optionalFields = Seq(MESSAGE) +} + +/** + * A request sent to the standalone Master to query the status of a driver. + */ +private[spark] class DriverStatusRequestMessage extends StandaloneRestProtocolMessage( + StandaloneRestProtocolAction.DRIVER_STATUS_REQUEST, + DriverStatusRequestField.ACTION, + DriverStatusRequestField.requiredFields) + +private[spark] object DriverStatusRequestMessage extends StandaloneRestProtocolMessageCompanion { + protected override def newMessage(): StandaloneRestProtocolMessage = + new DriverStatusRequestMessage + protected override def fieldWithName(field: String): StandaloneRestProtocolField = + DriverStatusRequestField.withName(field) +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala new file mode 100644 index 000000000000..020f80c71ffc --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +/** + * A field used in a DriverStatusResponseMessage. 
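+ *
+ * For illustration only, a fully populated status response would serialize to
+ * JSON along these lines (every value below is a placeholder, not output from
+ * a real cluster):
+ *
+ * {{{
+ *   {
+ *     "ACTION" : "DRIVER_STATUS_RESPONSE",
+ *     "SPARK_VERSION" : "1.3.0",
+ *     "MESSAGE" : "Driver is running",
+ *     "MASTER" : "spark://host:7077",
+ *     "DRIVER_ID" : "driver-20150114",
+ *     "DRIVER_STATE" : "RUNNING",
+ *     "WORKER_ID" : "worker-20150114",
+ *     "WORKER_HOST_PORT" : "192.168.0.1:34567"
+ *   }
+ * }}}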
+ */ +private[spark] abstract class DriverStatusResponseField extends StandaloneRestProtocolField +private[spark] object DriverStatusResponseField extends StandaloneRestProtocolFieldCompanion { + case object ACTION extends DriverStatusResponseField + case object SPARK_VERSION extends DriverStatusResponseField + case object MESSAGE extends DriverStatusResponseField + case object MASTER extends DriverStatusResponseField + case object DRIVER_ID extends DriverStatusResponseField + case object DRIVER_STATE extends DriverStatusResponseField + case object WORKER_ID extends DriverStatusResponseField + case object WORKER_HOST_PORT extends DriverStatusResponseField + override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, + MASTER, DRIVER_ID, DRIVER_STATE, WORKER_ID, WORKER_HOST_PORT) + override val optionalFields = Seq.empty +} + +/** + * A message sent from the standalone Master in response to a DriverStatusRequestMessage. + */ +private[spark] class DriverStatusResponseMessage extends StandaloneRestProtocolMessage( + StandaloneRestProtocolAction.DRIVER_STATUS_RESPONSE, + DriverStatusResponseField.ACTION, + DriverStatusResponseField.requiredFields) + +private[spark] object DriverStatusResponseMessage extends StandaloneRestProtocolMessageCompanion { + protected override def newMessage(): StandaloneRestProtocolMessage = + new DriverStatusResponseMessage + protected override def fieldWithName(field: String): StandaloneRestProtocolField = + DriverStatusResponseField.withName(field) +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala new file mode 100644 index 000000000000..7d8dda73414a --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +/** + * A field used in an ErrorMessage. + */ +private[spark] abstract class ErrorField extends StandaloneRestProtocolField +private[spark] object ErrorField extends StandaloneRestProtocolFieldCompanion { + case object ACTION extends ErrorField + case object SPARK_VERSION extends ErrorField + case object MESSAGE extends ErrorField + override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE) + override val optionalFields = Seq.empty +} + +/** + * An error message exchanged in the standalone REST protocol.
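+ *
+ * The server falls back to this message whenever a request fails validation.
+ * For example, a response might look like the following (values are
+ * illustrative):
+ *
+ * {{{
+ *   {
+ *     "ACTION" : "ERROR",
+ *     "SPARK_VERSION" : "1.3.0",
+ *     "MESSAGE" : "The following fields are missing from message SubmitDriverRequestMessage: MASTER."
+ *   }
+ * }}}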
+ */ +private[spark] class ErrorMessage extends StandaloneRestProtocolMessage( + StandaloneRestProtocolAction.ERROR, + ErrorField.ACTION, + ErrorField.requiredFields) + +private[spark] object ErrorMessage extends StandaloneRestProtocolMessageCompanion { + protected override def newMessage(): StandaloneRestProtocolMessage = new ErrorMessage + protected override def fieldWithName(field: String): StandaloneRestProtocolField = + ErrorField.withName(field) +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala new file mode 100644 index 000000000000..978f1d75498e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +/** + * A field used in a KillDriverRequestMessage. + */ +private[spark] abstract class KillDriverRequestField extends StandaloneRestProtocolField +private[spark] object KillDriverRequestField extends StandaloneRestProtocolFieldCompanion { + case object ACTION extends KillDriverRequestField + case object SPARK_VERSION extends KillDriverRequestField + case object MESSAGE extends KillDriverRequestField + case object MASTER extends KillDriverRequestField + case object DRIVER_ID extends KillDriverRequestField + override val requiredFields = Seq(ACTION, SPARK_VERSION, MASTER, DRIVER_ID) + override val optionalFields = Seq(MESSAGE) +} + +/** + * A request sent to the standalone Master to kill a driver. + */ +private[spark] class KillDriverRequestMessage extends StandaloneRestProtocolMessage( + StandaloneRestProtocolAction.KILL_DRIVER_REQUEST, + KillDriverRequestField.ACTION, + KillDriverRequestField.requiredFields) + +private[spark] object KillDriverRequestMessage extends StandaloneRestProtocolMessageCompanion { + protected override def newMessage(): StandaloneRestProtocolMessage = + new KillDriverRequestMessage + protected override def fieldWithName(field: String): StandaloneRestProtocolField = + KillDriverRequestField.withName(field) +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala new file mode 100644 index 000000000000..0a1b5efc2488 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +/** + * A field used in a KillDriverResponseMessage. + */ +private[spark] abstract class KillDriverResponseField extends StandaloneRestProtocolField +private[spark] object KillDriverResponseField extends StandaloneRestProtocolFieldCompanion { + case object ACTION extends KillDriverResponseField + case object SPARK_VERSION extends KillDriverResponseField + case object MESSAGE extends KillDriverResponseField + case object MASTER extends KillDriverResponseField + case object DRIVER_ID extends KillDriverResponseField + case object DRIVER_STATE extends KillDriverResponseField + override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, DRIVER_ID, DRIVER_STATE) + override val optionalFields = Seq.empty +} + +/** + * A message sent from the standalone Master in response to a KillDriverRequestMessage. + */ +private[spark] class KillDriverResponseMessage extends StandaloneRestProtocolMessage( + StandaloneRestProtocolAction.KILL_DRIVER_RESPONSE, + KillDriverResponseField.ACTION, + KillDriverResponseField.requiredFields) + +private[spark] object KillDriverResponseMessage extends StandaloneRestProtocolMessageCompanion { + protected override def newMessage(): StandaloneRestProtocolMessage = + new KillDriverResponseMessage + protected override def fieldWithName(field: String): StandaloneRestProtocolField = + KillDriverResponseField.withName(field) +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala new file mode 100644 index 000000000000..c805d7596824 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -0,0 +1,182 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.deploy.rest + +import java.io.DataOutputStream +import java.net.URL +import java.net.HttpURLConnection + +import scala.io.Source + +import com.google.common.base.Charsets + +import org.apache.spark.{SPARK_VERSION => sparkVersion} +import org.apache.spark.deploy.SparkSubmitArguments + +/** + * A client that submits Spark applications using a stable REST protocol in standalone + * cluster mode. This client is intended to communicate with the StandaloneRestServer. + */ +private[spark] class StandaloneRestClient { + + def submitDriver(args: SparkSubmitArguments): Unit = { + validateSubmitArguments(args) + val url = getHttpUrl(args.master) + val request = constructSubmitRequest(args) + val response = sendHttp(url, request) + println(response.toJson) + } + + def killDriver(master: String, driverId: String): Unit = { + validateMaster(master) + val url = getHttpUrl(master) + val request = constructKillRequest(master, driverId) + val response = sendHttp(url, request) + println(response.toJson) + } + + def requestDriverStatus(master: String, driverId: String): Unit = { + validateMaster(master) + val url = getHttpUrl(master) + val request = constructStatusRequest(master, driverId) + val response = sendHttp(url, request) + println(response.toJson) + } + + /** + * Construct a submit driver request message. + */ + private def constructSubmitRequest(args: SparkSubmitArguments): SubmitDriverRequestMessage = { + import SubmitDriverRequestField._ + val message = new SubmitDriverRequestMessage() + .setField(SPARK_VERSION, sparkVersion) + .setField(MASTER, args.master) + .setField(APP_NAME, args.name) + .setField(APP_RESOURCE, args.primaryResource) + .setFieldIfNotNull(MAIN_CLASS, args.mainClass) + .setFieldIfNotNull(JARS, args.jars) + .setFieldIfNotNull(FILES, args.files) + .setFieldIfNotNull(PY_FILES, args.pyFiles) + .setFieldIfNotNull(DRIVER_MEMORY, args.driverMemory) + .setFieldIfNotNull(DRIVER_CORES, args.driverCores) + .setFieldIfNotNull(DRIVER_EXTRA_JAVA_OPTIONS, args.driverExtraJavaOptions) + .setFieldIfNotNull(DRIVER_EXTRA_CLASS_PATH, args.driverExtraClassPath) + .setFieldIfNotNull(DRIVER_EXTRA_LIBRARY_PATH, args.driverExtraLibraryPath) + .setFieldIfNotNull(SUPERVISE_DRIVER, args.supervise.toString) + .setFieldIfNotNull(EXECUTOR_MEMORY, args.executorMemory) + .setFieldIfNotNull(TOTAL_EXECUTOR_CORES, args.totalExecutorCores) + // Set each Spark property as its own field + // TODO: Include environment variables? + args.sparkProperties.foreach { case (k, v) => + message.setFieldIfNotNull(SPARK_PROPERTY(k), v) + } + message.validate() + } + + /** + * Construct a kill driver request message. + */ + private def constructKillRequest( + master: String, + driverId: String): KillDriverRequestMessage = { + import KillDriverRequestField._ + new KillDriverRequestMessage() + .setField(SPARK_VERSION, sparkVersion) + .setField(MASTER, master) + .setField(DRIVER_ID, driverId) + .validate() + } + + /** + * Construct a driver status request message. + */ + private def constructStatusRequest( + master: String, + driverId: String): DriverStatusRequestMessage = { + import DriverStatusRequestField._ + new DriverStatusRequestMessage() + .setField(SPARK_VERSION, sparkVersion) + .setField(MASTER, master) + .setField(DRIVER_ID, driverId) + .validate() + } + + /** + * Send the provided request in an HTTP message to the given URL. + * Return the response received from the REST server. 
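+ *
+ * The result may be any message type, including an ErrorMessage if the server
+ * rejected the request, so callers that care about the outcome should match on
+ * the concrete type. A sketch, with a placeholder URL:
+ *
+ * {{{
+ *   sendHttp(new URL("http://host:6677"), request) match {
+ *     case error: ErrorMessage => // the server could not process the request
+ *     case response => // a typed response, e.g. KillDriverResponseMessage
+ *   }
+ * }}}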
+ */ + private def sendHttp( + url: URL, + request: StandaloneRestProtocolMessage): StandaloneRestProtocolMessage = { + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("POST") + conn.setRequestProperty("Content-Type", "application/json") + conn.setRequestProperty("charset", "utf-8") + conn.setDoOutput(true) + println("Sending this JSON blob to server:\n" + request.toJson) + val content = request.toJson.getBytes(Charsets.UTF_8) + val out = new DataOutputStream(conn.getOutputStream) + out.write(content) + out.close() + val response = Source.fromInputStream(conn.getInputStream).mkString + StandaloneRestProtocolMessage.fromJson(response) + } + + /** + * Throw an exception if this is not standalone cluster mode. + */ + private def validateSubmitArguments(args: SparkSubmitArguments): Unit = { + validateMaster(args.master) + validateDeployMode(args.deployMode) + } + + /** + * Throw an exception if this is not standalone mode. + */ + private def validateMaster(master: String): Unit = { + if (!master.startsWith("spark://")) { + throw new IllegalArgumentException("This REST client is only supported in standalone mode.") + } + } + + /** + * Throw an exception if this is not cluster deploy mode. + */ + private def validateDeployMode(deployMode: String): Unit = { + if (deployMode != "cluster") { + throw new IllegalArgumentException("This REST client is only supported in cluster mode.") + } + } + + /** + * Extract the URL portion of the master address. + */ + private def getHttpUrl(master: String): URL = { + validateMaster(master) + new URL("http://" + master.stripPrefix("spark://")) + } +} + +object StandaloneRestClient { + def main(args: Array[String]): Unit = { + assert(args.length > 0) + val client = new StandaloneRestClient + client.killDriver("spark://" + args(0), "abc_driver") + println("Done.") + } +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala new file mode 100644 index 000000000000..c8ea8fd395c6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import scala.collection.Map +import scala.collection.mutable + +import org.json4s.jackson.JsonMethods._ +import org.json4s.JsonAST._ + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.util.Utils + +/** + * A field used in a StandaloneRestProtocolMessage. + * Three special fields ACTION, SPARK_VERSION, and MESSAGE are common across all messages. 
+ */ +private[spark] abstract class StandaloneRestProtocolField +private[spark] object StandaloneRestProtocolField { + /** Return whether the provided field name refers to the ACTION field. */ + def isActionField(field: String): Boolean = field == "ACTION" +} + +/** + * All possible values of the ACTION field. + */ +private[spark] object StandaloneRestProtocolAction extends Enumeration { + type StandaloneRestProtocolAction = Value + val SUBMIT_DRIVER_REQUEST, SUBMIT_DRIVER_RESPONSE = Value + val KILL_DRIVER_REQUEST, KILL_DRIVER_RESPONSE = Value + val DRIVER_STATUS_REQUEST, DRIVER_STATUS_RESPONSE = Value + val ERROR = Value +} +import StandaloneRestProtocolAction.StandaloneRestProtocolAction + +/** + * A general message exchanged in the standalone REST protocol. + * + * The message is represented by a set of fields in the form of key value pairs. + * Each message must contain an ACTION field, which should have only one possible value + * for each type of message. For compatibility with older versions of Spark, existing + * fields must not be removed or modified, though new fields can be added as necessary. + */ +private[spark] abstract class StandaloneRestProtocolMessage( + action: StandaloneRestProtocolAction, + actionField: StandaloneRestProtocolField, + requiredFields: Seq[StandaloneRestProtocolField]) { + + import StandaloneRestProtocolField._ + + private val fields = new mutable.HashMap[StandaloneRestProtocolField, String] + private val className = Utils.getFormattedClassName(this) + + // Set the action field + fields(actionField) = action.toString + + /** Return the value of the given field. If the field is not present, throw an exception. */ + def getField(key: StandaloneRestProtocolField): String = { + fields.get(key).getOrElse { + throw new IllegalArgumentException(s"Field $key is not set in message $className") + } + } + + /** Assign the given value to the field, overriding any existing value. */ + def setField(key: StandaloneRestProtocolField, value: String): this.type = { + if (key == actionField) { + throw new SparkException("Setting the ACTION field is only allowed during instantiation.") + } + fields(key) = value + this + } + + /** Assign the given value to the field only if the value is not null. */ + def setFieldIfNotNull(key: StandaloneRestProtocolField, value: String): this.type = { + if (value != null) { + setField(key, value) + } + this + } + + /** + * Validate that all required fields are set and the value of the action field is as expected. + * If any of these conditions are not met, throw an IllegalArgumentException. + */ + def validate(): this.type = { + if (!fields.contains(actionField)) { + throw new IllegalArgumentException(s"The action field is missing from message $className.") + } + if (fields(actionField) != action.toString) { + throw new IllegalArgumentException( + s"Expected action $action in message $className, but actual was ${fields(actionField)}.") + } + val missingFields = requiredFields.filterNot(fields.contains) + if (missingFields.nonEmpty) { + val missingFieldsString = missingFields.mkString(", ") + throw new IllegalArgumentException( + s"The following fields are missing from message $className: $missingFieldsString.") + } + this + } + + /** Return the JSON representation of this message. 
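+ *
+ * For example, a kill request renders as a JSON object whose keys are the
+ * string forms of the fields, with ACTION always first (the values shown are
+ * placeholders):
+ *
+ * {{{
+ *   {
+ *     "ACTION" : "KILL_DRIVER_REQUEST",
+ *     "SPARK_VERSION" : "1.3.0",
+ *     "MASTER" : "spark://host:7077",
+ *     "DRIVER_ID" : "driver-20150114"
+ *   }
+ * }}}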
*/ + def toJson: String = { + val stringFields = fields + .filter { case (_, v) => v != null } + .map { case (k, v) => (k.toString, v) } + val jsonFields = fieldsToJson(stringFields) + pretty(render(jsonFields)) + } + + /** + * Return the JSON representation of the message fields, putting ACTION first. + * This assumes that applying `org.apache.spark.util.JsonProtocol.mapFromJson` + * to the result yields the original input. + */ + private def fieldsToJson(fields: Map[String, String]): JValue = { + val jsonFields = fields.toList + .sortBy { case (k, _) => if (isActionField(k)) 0 else 1 } + .map { case (k, v) => JField(k, JString(v)) } + JObject(jsonFields) + } +} + +private[spark] object StandaloneRestProtocolMessage { + import StandaloneRestProtocolField._ + import StandaloneRestProtocolAction._ + + /** + * Construct a StandaloneRestProtocolMessage from JSON. + * This uses the ACTION field to determine the type of the message to reconstruct. + * If such a field does not exist in the JSON, throw an exception. + */ + def fromJson(json: String): StandaloneRestProtocolMessage = { + val fields = org.apache.spark.util.JsonProtocol.mapFromJson(parse(json)) + val action = fields + .flatMap { case (k, v) => if (isActionField(k)) Some(v) else None } + .headOption + .getOrElse { throw new IllegalArgumentException(s"ACTION not found in message:\n$json") } + StandaloneRestProtocolAction.withName(action) match { + case SUBMIT_DRIVER_REQUEST => SubmitDriverRequestMessage.fromFields(fields) + case SUBMIT_DRIVER_RESPONSE => SubmitDriverResponseMessage.fromFields(fields) + case KILL_DRIVER_REQUEST => KillDriverRequestMessage.fromFields(fields) + case KILL_DRIVER_RESPONSE => KillDriverResponseMessage.fromFields(fields) + case DRIVER_STATUS_REQUEST => DriverStatusRequestMessage.fromFields(fields) + case DRIVER_STATUS_RESPONSE => DriverStatusResponseMessage.fromFields(fields) + case ERROR => ErrorMessage.fromFields(fields) + } + } +} + +/** + * A trait that holds common methods for StandaloneRestProtocolField companion objects. + * + * It is necessary to keep track of all fields that belong to this object in order to + * reconstruct the fields from their names. + */ +private[spark] trait StandaloneRestProtocolFieldCompanion { + val requiredFields: Seq[StandaloneRestProtocolField] + val optionalFields: Seq[StandaloneRestProtocolField] + + /** Listing of all fields indexed by the field's string representation. */ + private lazy val allFieldsMap: Map[String, StandaloneRestProtocolField] = { + (requiredFields ++ optionalFields).map { f => (f.toString, f) }.toMap + } + + /** Return a StandaloneRestProtocolField from its string representation. */ + def withName(field: String): StandaloneRestProtocolField = { + allFieldsMap.get(field).getOrElse { + throw new IllegalArgumentException(s"Unknown field $field") + } + } +} + +/** + * A trait that holds common methods for StandaloneRestProtocolMessage companion objects. + */ +private[spark] trait StandaloneRestProtocolMessageCompanion extends Logging { + import StandaloneRestProtocolField._ + + /** Construct a new message of the relevant type. */ + protected def newMessage(): StandaloneRestProtocolMessage + + /** Return a field of the relevant type from the field's string representation. */ + protected def fieldWithName(field: String): StandaloneRestProtocolField + + /** Construct a StandaloneRestProtocolMessage from the set of fields provided. 
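+ *
+ * Field names that the concrete message type does not recognize are logged and
+ * dropped rather than failing the whole message. A sketch of a round trip,
+ * using placeholder values:
+ *
+ * {{{
+ *   val fields = Map("ACTION" -> "ERROR", "SPARK_VERSION" -> "1.3.0", "MESSAGE" -> "oops")
+ *   val message = ErrorMessage.fromFields(fields) // same path as parsing the JSON form
+ *   message.validate() // succeeds: all required ErrorField entries are present
+ * }}}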
*/ + def fromFields(fields: Map[String, String]): StandaloneRestProtocolMessage = { + val message = newMessage() + fields.foreach { case (k, v) => + try { + // The ACTION field is already set on instantiation + if (!isActionField(k)) { + message.setField(fieldWithName(k), v) + } + } catch { + case e: IllegalArgumentException => + logWarning(s"Unexpected field $k in message ${Utils.getFormattedClassName(this)}") + } + } + message + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala new file mode 100644 index 000000000000..f4294b64a653 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.io.DataOutputStream +import javax.servlet.http.{HttpServletRequest, HttpServletResponse} + +import scala.io.Source + +import com.google.common.base.Charsets +import org.eclipse.jetty.server.{Request, Server} +import org.eclipse.jetty.server.handler.AbstractHandler + +import org.apache.spark.{Logging, SPARK_VERSION => sparkVersion} +import org.apache.spark.deploy.rest.StandaloneRestProtocolAction._ +import org.apache.spark.util.Utils + +/** + * A server that responds to requests submitted by the StandaloneRestClient. + */ +private[spark] class StandaloneRestServer(requestedPort: Int) { + val server = new Server(requestedPort) + server.setHandler(new StandaloneRestHandler) + server.start() + server.join() +} + +/** + * A Jetty handler that responds to requests submitted via the standalone REST protocol. 
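+ *
+ * Each HTTP POST carries exactly one JSON-encoded message, and the handler
+ * writes exactly one JSON message back. From the client's point of view
+ * (the master URL and driver ID below are placeholders):
+ *
+ * {{{
+ *   val client = new StandaloneRestClient
+ *   client.requestDriverStatus("spark://host:7077", "driver-20150114")
+ * }}}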
+ */ +private[spark] class StandaloneRestHandler extends AbstractHandler with Logging { + + override def handle( + target: String, + baseRequest: Request, + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + try { + val requestMessageJson = Source.fromInputStream(request.getInputStream).mkString + val requestMessage = StandaloneRestProtocolMessage.fromJson(requestMessageJson) + val responseMessage = constructResponseMessage(requestMessage) + response.setContentType("application/json") + response.setCharacterEncoding("utf-8") + response.setStatus(HttpServletResponse.SC_OK) + val content = responseMessage.toJson.getBytes(Charsets.UTF_8) + val out = new DataOutputStream(response.getOutputStream) + out.write(content) + out.close() + baseRequest.setHandled(true) + } catch { + case e: Exception => logError("Exception while handling request", e) + } + } + + private def constructResponseMessage( + request: StandaloneRestProtocolMessage): StandaloneRestProtocolMessage = { + // If the request is sent via the StandaloneRestClient, it should have already been + // validated remotely. In case this is not true, validate the request here to guard + // against potential NPEs. If validation fails, return an ERROR message to the sender. + try { + request.validate() + } catch { + case e: IllegalArgumentException => + return handleError(e.getMessage) + } + request match { + case submit: SubmitDriverRequestMessage => handleSubmitRequest(submit) + case kill: KillDriverRequestMessage => handleKillRequest(kill) + case status: DriverStatusRequestMessage => handleStatusRequest(status) + case unexpected => handleError( + s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") + } + } + + private def handleSubmitRequest( + request: SubmitDriverRequestMessage): SubmitDriverResponseMessage = { + import SubmitDriverResponseField._ + // TODO: Actually submit the driver + val message = "Driver is submitted successfully..." + val master = request.getField(SubmitDriverRequestField.MASTER) + val driverId = "new_driver_id" + val driverState = "SUBMITTED" + new SubmitDriverResponseMessage() + .setField(SPARK_VERSION, sparkVersion) + .setField(MESSAGE, message) + .setField(MASTER, master) + .setField(DRIVER_ID, driverId) + .setField(DRIVER_STATE, driverState) + .validate() + } + + private def handleKillRequest(request: KillDriverRequestMessage): KillDriverResponseMessage = { + import KillDriverResponseField._ + // TODO: Actually kill the driver + val message = "Driver is killed successfully..." 
+ val master = request.getField(KillDriverRequestField.MASTER) + val driverId = request.getField(KillDriverRequestField.DRIVER_ID) + val driverState = "KILLED" + new KillDriverResponseMessage() + .setField(SPARK_VERSION, sparkVersion) + .setField(MESSAGE, message) + .setField(MASTER, master) + .setField(DRIVER_ID, driverId) + .setField(DRIVER_STATE, driverState) + .validate() + } + + private def handleStatusRequest( + request: DriverStatusRequestMessage): DriverStatusResponseMessage = { + import DriverStatusResponseField._ + // TODO: Actually look up the status of the driver + val master = request.getField(DriverStatusRequestField.MASTER) + val driverId = request.getField(DriverStatusRequestField.DRIVER_ID) + val driverState = "HEALTHY" + new DriverStatusResponseMessage() + .setField(SPARK_VERSION, sparkVersion) + .setField(MASTER, master) + .setField(DRIVER_ID, driverId) + .setField(DRIVER_STATE, driverState) + .validate() + } + + private def handleError(message: String): ErrorMessage = { + import ErrorField._ + new ErrorMessage() + .setField(SPARK_VERSION, sparkVersion) + .setField(MESSAGE, message) + .validate() + } +} + +object StandaloneRestServer { + def main(args: Array[String]): Unit = { + println("Hey boy I'm starting a server.") + new StandaloneRestServer(6677) + readLine() + } +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala new file mode 100644 index 000000000000..fd95ecb1aefb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import org.apache.spark.util.Utils + +/** + * A field used in a SubmitDriverRequestMessage. 
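+ *
+ * Most fields are fixed case objects, but each Spark property becomes a
+ * dynamically named field through the SPARK_PROPERTY case class. For example
+ * (property chosen for illustration):
+ *
+ * {{{
+ *   SPARK_PROPERTY("spark.executor.memory").toString
+ *   // returns "SPARK_PROPERTY_spark.executor.memory"
+ * }}}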
+ */ +private[spark] abstract class SubmitDriverRequestField extends StandaloneRestProtocolField +private[spark] object SubmitDriverRequestField extends StandaloneRestProtocolFieldCompanion { + case object ACTION extends SubmitDriverRequestField + case object SPARK_VERSION extends SubmitDriverRequestField + case object MESSAGE extends SubmitDriverRequestField + case object MASTER extends SubmitDriverRequestField + case object APP_NAME extends SubmitDriverRequestField + case object APP_RESOURCE extends SubmitDriverRequestField + case object MAIN_CLASS extends SubmitDriverRequestField + case object JARS extends SubmitDriverRequestField + case object FILES extends SubmitDriverRequestField + case object PY_FILES extends SubmitDriverRequestField + case object DRIVER_MEMORY extends SubmitDriverRequestField + case object DRIVER_CORES extends SubmitDriverRequestField + case object DRIVER_EXTRA_JAVA_OPTIONS extends SubmitDriverRequestField + case object DRIVER_EXTRA_CLASS_PATH extends SubmitDriverRequestField + case object DRIVER_EXTRA_LIBRARY_PATH extends SubmitDriverRequestField + case object SUPERVISE_DRIVER extends SubmitDriverRequestField + case object EXECUTOR_MEMORY extends SubmitDriverRequestField + case object TOTAL_EXECUTOR_CORES extends SubmitDriverRequestField + case class SPARK_PROPERTY(prop: String) extends SubmitDriverRequestField { + override def toString: String = Utils.getFormattedClassName(this) + "_" + prop + } + case class ENVIRONMENT_VARIABLE(envVar: String) extends SubmitDriverRequestField { + override def toString: String = Utils.getFormattedClassName(this) + "_" + envVar + } + override val requiredFields = Seq(ACTION, SPARK_VERSION, MASTER, APP_NAME, APP_RESOURCE) + override val optionalFields = Seq(MESSAGE, MAIN_CLASS, JARS, FILES, PY_FILES, DRIVER_MEMORY, + DRIVER_CORES, DRIVER_EXTRA_JAVA_OPTIONS, DRIVER_EXTRA_CLASS_PATH, DRIVER_EXTRA_LIBRARY_PATH, + SUPERVISE_DRIVER, EXECUTOR_MEMORY, TOTAL_EXECUTOR_CORES) +} + +/** + * A request sent to the standalone Master to submit a driver. + */ +private[spark] class SubmitDriverRequestMessage extends StandaloneRestProtocolMessage( + StandaloneRestProtocolAction.SUBMIT_DRIVER_REQUEST, + SubmitDriverRequestField.ACTION, + SubmitDriverRequestField.requiredFields) + +private[spark] object SubmitDriverRequestMessage extends StandaloneRestProtocolMessageCompanion { + protected override def newMessage(): StandaloneRestProtocolMessage = + new SubmitDriverRequestMessage + protected override def fieldWithName(field: String): StandaloneRestProtocolField = + SubmitDriverRequestField.withName(field) +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala new file mode 100644 index 000000000000..034c7f80de23 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +/** + * A field used in a SubmitDriverResponseMessage. + */ +private[spark] abstract class SubmitDriverResponseField extends StandaloneRestProtocolField +private[spark] object SubmitDriverResponseField extends StandaloneRestProtocolFieldCompanion { + case object ACTION extends SubmitDriverResponseField + case object SPARK_VERSION extends SubmitDriverResponseField + case object MESSAGE extends SubmitDriverResponseField + case object MASTER extends SubmitDriverResponseField + case object DRIVER_ID extends SubmitDriverResponseField + case object DRIVER_STATE extends SubmitDriverResponseField + override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, DRIVER_ID, DRIVER_STATE) + override val optionalFields = Seq.empty +} + +/** + * A message sent from the standalone Master in response to a SubmitDriverRequestMessage. + */ +private[spark] class SubmitDriverResponseMessage extends StandaloneRestProtocolMessage( + StandaloneRestProtocolAction.SUBMIT_DRIVER_RESPONSE, + SubmitDriverResponseField.ACTION, + SubmitDriverResponseField.requiredFields) + +private[spark] object SubmitDriverResponseMessage extends StandaloneRestProtocolMessageCompanion { + protected override def newMessage(): StandaloneRestProtocolMessage = + new SubmitDriverResponseMessage + protected override def fieldWithName(field: String): StandaloneRestProtocolField = + SubmitDriverResponseField.withName(field) +} From af9d9cb2933774e33278022a38196b522193a766 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 16 Jan 2015 19:12:32 -0800 Subject: [PATCH 02/48] Integrate REST protocol in standalone mode This commit embeds the REST server in the standalone Master and forces Spark submit to submit applications through the REST client. This is the first working end-to-end implementation of a stable submission interface in standalone cluster mode. --- .../apache/spark/deploy/ClientArguments.scala | 18 ++- .../org/apache/spark/deploy/SparkSubmit.scala | 11 ++ .../spark/deploy/SparkSubmitArguments.scala | 9 ++ .../apache/spark/deploy/master/Master.scala | 3 + .../rest/KillDriverResponseMessage.scala | 4 +- .../deploy/rest/StandaloneRestClient.scala | 21 ++- .../rest/StandaloneRestProtocolMessage.scala | 15 +- .../deploy/rest/StandaloneRestServer.scala | 115 +++++-------- .../rest/StandaloneRestServerHandler.scala | 153 ++++++++++++++++++ .../rest/SubmitDriverRequestMessage.scala | 37 ++++- .../rest/SubmitDriverResponseMessage.scala | 6 +- 11 files changed, 297 insertions(+), 95 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServerHandler.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 2e1e52906cee..936e7dd2d191 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -29,8 +29,7 @@ import org.apache.spark.util.MemoryParam * Command-line parser for the driver client. 
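 *
 * Invoked from the command line roughly as follows (the master URL, jar URL,
 * and driver ID are illustrative):
 *
 * {{{
 *   DriverClient launch spark://host:7077 hdfs://path/to/app.jar com.example.Main
 *   DriverClient kill spark://host:7077 driver-20150114
 * }}}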
*/ private[spark] class ClientArguments(args: Array[String]) { - val defaultCores = 1 - val defaultMemory = 512 + import ClientArguments._ var cmd: String = "" // 'launch' or 'kill' var logLevel = Level.WARN @@ -39,9 +38,9 @@ private[spark] class ClientArguments(args: Array[String]) { var master: String = "" var jarUrl: String = "" var mainClass: String = "" - var supervise: Boolean = false - var memory: Int = defaultMemory - var cores: Int = defaultCores + var supervise: Boolean = DEFAULT_SUPERVISE + var memory: Int = DEFAULT_MEMORY + var cores: Int = DEFAULT_CORES private var _driverOptions = ListBuffer[String]() def driverOptions = _driverOptions.toSeq @@ -106,9 +105,10 @@ private[spark] class ClientArguments(args: Array[String]) { |Usage: DriverClient kill | |Options: - | -c CORES, --cores CORES Number of cores to request (default: $defaultCores) - | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $defaultMemory) + | -c CORES, --cores CORES Number of cores to request (default: $DEFAULT_CORES) + | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $DEFAULT_MEMORY) | -s, --supervise Whether to restart the driver on failure + | (default: $DEFAULT_SUPERVISE) | -v, --verbose Print more debugging output """.stripMargin System.err.println(usage) @@ -117,6 +117,10 @@ private[spark] class ClientArguments(args: Array[String]) { } object ClientArguments { + private[spark] val DEFAULT_CORES = 1 + private[spark] val DEFAULT_MEMORY = 512 // MB + private[spark] val DEFAULT_SUPERVISE = false + def isValidJarUrl(s: String): Boolean = { try { val uri = new URI(s) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 1c89b2452e83..7c89f0bcaeba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -25,6 +25,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import org.apache.spark.executor.ExecutorURLClassLoader import org.apache.spark.util.Utils +import org.apache.spark.deploy.rest.StandaloneRestClient /** * Main gateway of launching a Spark application. 
@@ -72,6 +73,16 @@ object SparkSubmit { if (appArgs.verbose) { printStream.println(appArgs) } + + // In standalone cluster mode, use the new REST client to submit the application + val doingRest = appArgs.master.startsWith("spark://") && appArgs.deployMode == "cluster" + if (doingRest) { + println("Submitting driver through the REST interface.") + new StandaloneRestClient().submitDriver(appArgs) + println("Done submitting driver.") + return + } + val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 47059b08a397..310b34a92633 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -104,6 +104,15 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St .orElse(sparkProperties.get("spark.master")) .orElse(env.get("MASTER")) .orNull + driverExtraClassPath = Option(driverExtraClassPath) + .orElse(sparkProperties.get("spark.driver.extraClassPath")) + .orNull + driverExtraJavaOptions = Option(driverExtraJavaOptions) + .orElse(sparkProperties.get("spark.driver.extraJavaOptions")) + .orNull + driverExtraLibraryPath = Option(driverExtraLibraryPath) + .orElse(sparkProperties.get("spark.driver.extraLibraryPath")) + .orNull driverMemory = Option(driverMemory) .orElse(sparkProperties.get("spark.driver.memory")) .orElse(env.get("SPARK_DRIVER_MEMORY")) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 4b631ec63907..24c08373ade4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -43,6 +43,7 @@ import org.apache.spark.deploy.history.HistoryServer import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI +import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI @@ -121,6 +122,8 @@ private[spark] class Master( throw new SparkException("spark.deploy.defaultCores must be positive") } + val restServer = new StandaloneRestServer(this, host, 6677) + override def preStart() { logInfo("Starting Spark master at " + masterUrl) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala index 0a1b5efc2488..69b9a4f4bdab 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala @@ -27,8 +27,8 @@ private[spark] object KillDriverResponseField extends StandaloneRestProtocolFiel case object MESSAGE extends KillDriverResponseField case object MASTER extends KillDriverResponseField case object DRIVER_ID extends KillDriverResponseField - case object DRIVER_STATE extends KillDriverResponseField - override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, DRIVER_ID, 
DRIVER_STATE) + case object SUCCESS extends KillDriverResponseField + override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, DRIVER_ID, SUCCESS) + override val optionalFields = Seq.empty } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index c805d7596824..6059344d93b6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -27,6 +27,7 @@ import com.google.common.base.Charsets import org.apache.spark.{SPARK_VERSION => sparkVersion} import org.apache.spark.deploy.SparkSubmitArguments +import org.apache.spark.util.Utils /** * A client that submits Spark applications using a stable REST protocol in standalone @@ -63,6 +64,12 @@ private[spark] class StandaloneRestClient { */ private def constructSubmitRequest(args: SparkSubmitArguments): SubmitDriverRequestMessage = { import SubmitDriverRequestField._ + val driverMemory = Option(args.driverMemory) + .map { m => Utils.memoryStringToMb(m).toString } + .orNull + val executorMemory = Option(args.executorMemory) + .map { m => Utils.memoryStringToMb(m).toString } + .orNull val message = new SubmitDriverRequestMessage() .setField(SPARK_VERSION, sparkVersion) .setField(MASTER, args.master) @@ -72,19 +79,21 @@ private[spark] class StandaloneRestClient { .setFieldIfNotNull(JARS, args.jars) .setFieldIfNotNull(FILES, args.files) .setFieldIfNotNull(PY_FILES, args.pyFiles) - .setFieldIfNotNull(DRIVER_MEMORY, args.driverMemory) + .setFieldIfNotNull(DRIVER_MEMORY, driverMemory) .setFieldIfNotNull(DRIVER_CORES, args.driverCores) .setFieldIfNotNull(DRIVER_EXTRA_JAVA_OPTIONS, args.driverExtraJavaOptions) .setFieldIfNotNull(DRIVER_EXTRA_CLASS_PATH, args.driverExtraClassPath) .setFieldIfNotNull(DRIVER_EXTRA_LIBRARY_PATH, args.driverExtraLibraryPath) .setFieldIfNotNull(SUPERVISE_DRIVER, args.supervise.toString) - .setFieldIfNotNull(EXECUTOR_MEMORY, args.executorMemory) + .setFieldIfNotNull(EXECUTOR_MEMORY, executorMemory) .setFieldIfNotNull(TOTAL_EXECUTOR_CORES, args.totalExecutorCores) - // Set each Spark property as its own field - // TODO: Include environment variables? + args.childArgs.zipWithIndex.foreach { case (arg, i) => + message.setFieldIfNotNull(APP_ARG(i), arg) + } args.sparkProperties.foreach { case (k, v) => message.setFieldIfNotNull(SPARK_PROPERTY(k), v) } + // TODO: set environment variables? 
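+    // validate() checks that ACTION carries the expected value and that all
+    // required fields (MASTER, APP_NAME, APP_RESOURCE, among others) are set,
+    // so a malformed submission fails here with an IllegalArgumentException
+    // before any HTTP request is sent.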
message.validate() } @@ -175,8 +184,8 @@ private[spark] class StandaloneRestClient { object StandaloneRestClient { def main(args: Array[String]): Unit = { assert(args.length > 0) - val client = new StandaloneRestClient - client.killDriver("spark://" + args(0), "abc_driver") + //val client = new StandaloneRestClient + //client.submitDriver("spark://" + args(0)) println("Done.") } } \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala index c8ea8fd395c6..7945271a870f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala @@ -63,19 +63,28 @@ private[spark] abstract class StandaloneRestProtocolMessage( import StandaloneRestProtocolField._ - private val fields = new mutable.HashMap[StandaloneRestProtocolField, String] private val className = Utils.getFormattedClassName(this) + protected val fields = new mutable.HashMap[StandaloneRestProtocolField, String] // Set the action field fields(actionField) = action.toString + /** Return all fields currently set in this message. */ + def getFields: Map[StandaloneRestProtocolField, String] = fields + + /** Return the value of the given field. If the field is not present, return null. */ + def getField(key: StandaloneRestProtocolField): String = getFieldOption(key).orNull + /** Return the value of the given field. If the field is not present, throw an exception. */ - def getField(key: StandaloneRestProtocolField): String = { - fields.get(key).getOrElse { + def getFieldNotNull(key: StandaloneRestProtocolField): String = { + getFieldOption(key).getOrElse { throw new IllegalArgumentException(s"Field $key is not set in message $className") } } + /** Return the value of the given field as an option. */ + def getFieldOption(key: StandaloneRestProtocolField): Option[String] = fields.get(key) + /** Assign the given value to the field, overriding any existing value. */ def setField(key: StandaloneRestProtocolField, value: String): this.type = { if (key == actionField) { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index f4294b64a653..344a3ef89a4d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -18,6 +18,7 @@ package org.apache.spark.deploy.rest import java.io.DataOutputStream +import java.net.InetSocketAddress import javax.servlet.http.{HttpServletRequest, HttpServletResponse} import scala.io.Source @@ -26,25 +27,37 @@ import com.google.common.base.Charsets import org.eclipse.jetty.server.{Request, Server} import org.eclipse.jetty.server.handler.AbstractHandler -import org.apache.spark.{Logging, SPARK_VERSION => sparkVersion} -import org.apache.spark.deploy.rest.StandaloneRestProtocolAction._ -import org.apache.spark.util.Utils +import org.apache.spark.{SPARK_VERSION => sparkVersion, Logging} +import org.apache.spark.deploy.master.Master +import org.apache.spark.util.{AkkaUtils, Utils} /** * A server that responds to requests submitted by the StandaloneRestClient. 
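 *
 * The server starts listening as soon as it is constructed and does not block
 * the calling thread. A sketch of standalone usage, assuming an in-scope
 * Master reference and a free port (both placeholders):
 *
 * {{{
 *   val restServer = new StandaloneRestServer(master, "localhost", 6677)
 *   restServer.server.join() // optionally block until the server stops
 * }}}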
*/ -private[spark] class StandaloneRestServer(requestedPort: Int) { - val server = new Server(requestedPort) - server.setHandler(new StandaloneRestHandler) +private[spark] class StandaloneRestServer(master: Master, host: String, requestedPort: Int) { + val server = new Server(new InetSocketAddress(host, requestedPort)) + server.setHandler(new StandaloneRestServerHandler(master)) server.start() - server.join() } /** * A Jetty handler that responds to requests submitted via the standalone REST protocol. */ -private[spark] class StandaloneRestHandler extends AbstractHandler with Logging { +private[spark] abstract class StandaloneRestHandler(master: Master) + extends AbstractHandler with Logging { + private implicit val askTimeout = AkkaUtils.askTimeout(master.conf) + + /** Handle a request to submit a driver. */ + protected def handleSubmit(request: SubmitDriverRequestMessage): SubmitDriverResponseMessage + /** Handle a request to kill a driver. */ + protected def handleKill(request: KillDriverRequestMessage): KillDriverResponseMessage + /** Handle a request for a driver's status. */ + protected def handleStatus(request: DriverStatusRequestMessage): DriverStatusResponseMessage + + /** + * Handle a request submitted by the StandaloneRestClient. + */ override def handle( target: String, baseRequest: Request, @@ -67,6 +80,10 @@ private[spark] class StandaloneRestHandler extends AbstractHandler with Logging } } + /** + * Construct the appropriate response message based on the type of the request message. + * If an IllegalArgumentException is thrown in the process, construct an error message. + */ private def constructResponseMessage( request: StandaloneRestProtocolMessage): StandaloneRestProtocolMessage = { // If the request is sent via the StandaloneRestClient, it should have already been @@ -74,67 +91,21 @@ private[spark] class StandaloneRestHandler extends AbstractHandler with Logging // against potential NPEs. If validation fails, return an ERROR message to the sender. try { request.validate() + request match { + case submit: SubmitDriverRequestMessage => handleSubmit(submit) + case kill: KillDriverRequestMessage => handleKill(kill) + case status: DriverStatusRequestMessage => handleStatus(status) + case unexpected => handleError( + s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") + } } catch { - case e: IllegalArgumentException => - return handleError(e.getMessage) - } - request match { - case submit: SubmitDriverRequestMessage => handleSubmitRequest(submit) - case kill: KillDriverRequestMessage => handleKillRequest(kill) - case status: DriverStatusRequestMessage => handleStatusRequest(status) - case unexpected => handleError( - s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") + // Propagate exception to user in an ErrorMessage. If the construction of the + // ErrorMessage itself throws an exception, log the exception and ignore the request. + case e: IllegalArgumentException => handleError(e.getMessage) } } - private def handleSubmitRequest( - request: SubmitDriverRequestMessage): SubmitDriverResponseMessage = { - import SubmitDriverResponseField._ - // TODO: Actually submit the driver - val message = "Driver is submitted successfully..." 
- val master = request.getField(SubmitDriverRequestField.MASTER) - val driverId = "new_driver_id" - val driverState = "SUBMITTED" - new SubmitDriverResponseMessage() - .setField(SPARK_VERSION, sparkVersion) - .setField(MESSAGE, message) - .setField(MASTER, master) - .setField(DRIVER_ID, driverId) - .setField(DRIVER_STATE, driverState) - .validate() - } - - private def handleKillRequest(request: KillDriverRequestMessage): KillDriverResponseMessage = { - import KillDriverResponseField._ - // TODO: Actually kill the driver - val message = "Driver is killed successfully..." - val master = request.getField(KillDriverRequestField.MASTER) - val driverId = request.getField(KillDriverRequestField.DRIVER_ID) - val driverState = "KILLED" - new KillDriverResponseMessage() - .setField(SPARK_VERSION, sparkVersion) - .setField(MESSAGE, message) - .setField(MASTER, master) - .setField(DRIVER_ID, driverId) - .setField(DRIVER_STATE, driverState) - .validate() - } - - private def handleStatusRequest( - request: DriverStatusRequestMessage): DriverStatusResponseMessage = { - import DriverStatusResponseField._ - // TODO: Actually look up the status of the driver - val master = request.getField(DriverStatusRequestField.MASTER) - val driverId = request.getField(DriverStatusRequestField.DRIVER_ID) - val driverState = "HEALTHY" - new DriverStatusResponseMessage() - .setField(SPARK_VERSION, sparkVersion) - .setField(MASTER, master) - .setField(DRIVER_ID, driverId) - .setField(DRIVER_STATE, driverState) - .validate() - } - + /** Construct an error message to signal the fact that an exception has been thrown. */ private def handleError(message: String): ErrorMessage = { import ErrorField._ new ErrorMessage() @@ -144,10 +115,10 @@ private[spark] class StandaloneRestHandler extends AbstractHandler with Logging } } -object StandaloneRestServer { - def main(args: Array[String]): Unit = { - println("Hey boy I'm starting a server.") - new StandaloneRestServer(6677) - readLine() - } -} \ No newline at end of file +//object StandaloneRestServer { +// def main(args: Array[String]): Unit = { +// println("Hey boy I'm starting a server.") +// new StandaloneRestServer(6677) +// readLine() +// } +//} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServerHandler.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServerHandler.scala new file mode 100644 index 000000000000..e11698e51bf1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServerHandler.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and + * limitations under the License. 
 + */ + +package org.apache.spark.deploy.rest + +import java.io.File + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.{SPARK_VERSION => sparkVersion} +import org.apache.spark.SparkConf +import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.deploy.{Command, DriverDescription} +import org.apache.spark.deploy.ClientArguments._ +import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.master.Master + +/** + * A handler that responds to requests submitted to the standalone Master + * through the REST protocol. + */ +private[spark] class StandaloneRestServerHandler(master: Master) + extends StandaloneRestHandler(master) { + + private implicit val askTimeout = AkkaUtils.askTimeout(master.conf) + + override protected def handleSubmit( + request: SubmitDriverRequestMessage): SubmitDriverResponseMessage = { + import SubmitDriverResponseField._ + val driverDescription = buildDriverDescription(request) + val response = AkkaUtils.askWithReply[SubmitDriverResponse]( + RequestSubmitDriver(driverDescription), master.self, askTimeout) + new SubmitDriverResponseMessage() + .setField(SPARK_VERSION, sparkVersion) + .setField(MESSAGE, response.message) + .setField(MASTER, master.masterUrl) + .setField(SUCCESS, response.success.toString) + .setFieldIfNotNull(DRIVER_ID, response.driverId.orNull) + .validate() + } + + override protected def handleKill( + request: KillDriverRequestMessage): KillDriverResponseMessage = { + import KillDriverResponseField._ + val driverId = request.getFieldNotNull(KillDriverRequestField.DRIVER_ID) + val response = AkkaUtils.askWithReply[KillDriverResponse]( + RequestKillDriver(driverId), master.self, askTimeout) + new KillDriverResponseMessage() + .setField(SPARK_VERSION, sparkVersion) + .setField(MESSAGE, response.message) + .setField(MASTER, master.masterUrl) + .setField(DRIVER_ID, driverId) + .setField(SUCCESS, response.success.toString) + .validate() + } + + override protected def handleStatus( + request: DriverStatusRequestMessage): DriverStatusResponseMessage = { + import DriverStatusResponseField._ + // TODO: Actually look up the status of the driver + val master = request.getField(DriverStatusRequestField.MASTER) + val driverId = request.getField(DriverStatusRequestField.DRIVER_ID) + val driverState = "HEALTHY" + new DriverStatusResponseMessage() + .setField(SPARK_VERSION, sparkVersion) + .setField(MASTER, master) + .setField(DRIVER_ID, driverId) + .setField(DRIVER_STATE, driverState) + .validate() + } + + private def buildDriverDescription(request: SubmitDriverRequestMessage): DriverDescription = { + import SubmitDriverRequestField._ + + // Required fields + //val _master = request.getFieldNotNull(MASTER) + val appName = request.getFieldNotNull(APP_NAME) + val appResource = request.getFieldNotNull(APP_RESOURCE) + + // Since standalone cluster mode does not yet support python, + // we treat the main class as required + val mainClass = request.getFieldNotNull(MAIN_CLASS) + + // Optional fields + val jars = request.getFieldOption(JARS) + val files = request.getFieldOption(FILES) + val driverMemory = request.getFieldOption(DRIVER_MEMORY) + val driverCores = request.getFieldOption(DRIVER_CORES) + val driverExtraJavaOptions = request.getFieldOption(DRIVER_EXTRA_JAVA_OPTIONS) + val driverExtraClassPath = request.getFieldOption(DRIVER_EXTRA_CLASS_PATH) + val driverExtraLibraryPath = request.getFieldOption(DRIVER_EXTRA_LIBRARY_PATH) + val superviseDriver = request.getFieldOption(SUPERVISE_DRIVER) + val executorMemory = request.getFieldOption(EXECUTOR_MEMORY) + val totalExecutorCores = 
request.getFieldOption(TOTAL_EXECUTOR_CORES) + + // Parse special fields that take in parameters + val conf = new SparkConf(false) + val env = new mutable.HashMap[String, String] + val appArgs = new ArrayBuffer[(Int, String)] + request.getFields.foreach { case (k, v) => + println(s"> Found this field: $k = $v") + k match { + case APP_ARG(index) => appArgs += ((index, v)) + case SPARK_PROPERTY(propKey) => conf.set(propKey, v) + case ENVIRONMENT_VARIABLE(envKey) => env(envKey) = v + case _ => + } + } + + // Use the actual master URL instead of the one that refers to this REST server + // Otherwise, once the driver is launched it will contact with the wrong server + conf.set("spark.master", master.masterUrl) + conf.set("spark.app.name", appName) + conf.set("spark.jars", jars.map(_ + ",").getOrElse("") + appResource) // include app resource + files.foreach { f => conf.set("spark.files", f) } + driverExtraJavaOptions.foreach { j => conf.set("spark.driver.extraJavaOptions", j) } + driverExtraClassPath.foreach { cp => conf.set("spark.driver.extraClassPath", cp) } + driverExtraLibraryPath.foreach { lp => conf.set("spark.driver.extraLibraryPath", lp) } + executorMemory.foreach { m => conf.set("spark.executor.memory", m) } + totalExecutorCores.foreach { c => conf.set("spark.cores.max", c) } + + // Construct driver description and submit it + val actualDriverMemory = driverMemory.map(_.toInt).getOrElse(DEFAULT_MEMORY) + val actualDriverCores = driverCores.map(_.toInt).getOrElse(DEFAULT_CORES) + val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) + val actualAppArgs = appArgs.sortBy(_._1).map(_._2) // sort by index, map to value + val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) + val sparkJavaOpts = Utils.sparkJavaOpts(conf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts + val command = new Command( + "org.apache.spark.deploy.worker.DriverWrapper", + Seq("{{WORKER_URL}}", mainClass) ++ actualAppArgs, + env, extraClassPath, extraLibraryPath, javaOpts) + new DriverDescription( + appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala index fd95ecb1aefb..72f92f2c0d49 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala @@ -17,6 +17,8 @@ package org.apache.spark.deploy.rest +import scala.util.matching.Regex + import org.apache.spark.util.Utils /** @@ -39,9 +41,12 @@ private[spark] object SubmitDriverRequestField extends StandaloneRestProtocolFie case object DRIVER_EXTRA_JAVA_OPTIONS extends SubmitDriverRequestField case object DRIVER_EXTRA_CLASS_PATH extends SubmitDriverRequestField case object DRIVER_EXTRA_LIBRARY_PATH extends SubmitDriverRequestField - case object SUPERVISE_DRIVER extends SubmitDriverRequestField + case object SUPERVISE_DRIVER extends SubmitDriverRequestField // standalone cluster mode only case object EXECUTOR_MEMORY extends SubmitDriverRequestField case object TOTAL_EXECUTOR_CORES extends SubmitDriverRequestField + case class APP_ARG(index: Int) extends SubmitDriverRequestField { + 
override def toString: String = Utils.getFormattedClassName(this) + "_" + index + } case class SPARK_PROPERTY(prop: String) extends SubmitDriverRequestField { override def toString: String = Utils.getFormattedClassName(this) + "_" + prop } @@ -52,6 +57,22 @@ private[spark] object SubmitDriverRequestField extends StandaloneRestProtocolFie override val optionalFields = Seq(MESSAGE, MAIN_CLASS, JARS, FILES, PY_FILES, DRIVER_MEMORY, DRIVER_CORES, DRIVER_EXTRA_JAVA_OPTIONS, DRIVER_EXTRA_CLASS_PATH, DRIVER_EXTRA_LIBRARY_PATH, SUPERVISE_DRIVER, EXECUTOR_MEMORY, TOTAL_EXECUTOR_CORES) + + // Because certain fields taken in arguments, we cannot simply rely on the + // list of all fields to reconstruct a field from its String representation. + // Instead, we must treat these fields as special cases and match on their prefixes. + override def withName(field: String): StandaloneRestProtocolField = { + def buildRegex(obj: AnyRef): Regex = s"${Utils.getFormattedClassName(obj)}_(.*)".r + val appArg = buildRegex(APP_ARG) + val sparkProperty = buildRegex(SPARK_PROPERTY) + val environmentVariable = buildRegex(ENVIRONMENT_VARIABLE) + field match { + case appArg(f) => APP_ARG(f.toInt) + case sparkProperty(f) => SPARK_PROPERTY(f) + case environmentVariable(f) => ENVIRONMENT_VARIABLE(f) + case _ => super.withName(field) + } + } } /** @@ -60,7 +81,19 @@ private[spark] object SubmitDriverRequestField extends StandaloneRestProtocolFie private[spark] class SubmitDriverRequestMessage extends StandaloneRestProtocolMessage( StandaloneRestProtocolAction.SUBMIT_DRIVER_REQUEST, SubmitDriverRequestField.ACTION, - SubmitDriverRequestField.requiredFields) + SubmitDriverRequestField.requiredFields) { + + // Ensure continuous range of app arg indices starting from 0 + override def validate(): this.type = { + import SubmitDriverRequestField._ + val indices = fields.collect { case (a: APP_ARG, _) => a }.toSeq.sortBy(_.index).map(_.index) + val expectedIndices = (0 until indices.size).toSeq + if (indices != expectedIndices) { + throw new IllegalArgumentException(s"Malformed app arg indices: ${indices.mkString(",")}") + } + super.validate() + } +} private[spark] object SubmitDriverRequestMessage extends StandaloneRestProtocolMessageCompanion { protected override def newMessage(): StandaloneRestProtocolMessage = diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala index 034c7f80de23..e656c35ad965 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala @@ -26,10 +26,10 @@ private[spark] object SubmitDriverResponseField extends StandaloneRestProtocolFi case object SPARK_VERSION extends SubmitDriverResponseField case object MESSAGE extends SubmitDriverResponseField case object MASTER extends SubmitDriverResponseField + case object SUCCESS extends SubmitDriverResponseField case object DRIVER_ID extends SubmitDriverResponseField - case object DRIVER_STATE extends SubmitDriverResponseField - override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, DRIVER_ID, DRIVER_STATE) - override val optionalFields = Seq.empty + override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, SUCCESS) + override val optionalFields = Seq(DRIVER_ID) } /** From 6ff088dca37f3f888a94c00a5d533ef8c6a6e6f5 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 19 Jan 2015 16:24:05 
-0800 Subject: [PATCH 03/48] Rename classes to generalize REST protocol Previously the REST protocol was tightly coupled to standalone mode. This commit frees the protocol from that restriction. --- .../org/apache/spark/deploy/SparkSubmit.scala | 7 +- .../apache/spark/deploy/master/Master.scala | 5 +- .../rest/DriverStatusRequestMessage.scala | 16 +- .../rest/DriverStatusResponseMessage.scala | 16 +- .../spark/deploy/rest/ErrorMessage.scala | 16 +- .../rest/KillDriverRequestMessage.scala | 16 +- .../rest/KillDriverResponseMessage.scala | 16 +- .../deploy/rest/StandaloneRestClient.scala | 113 ++------- .../deploy/rest/StandaloneRestServer.scala | 216 +++++++++++------- .../rest/StandaloneRestServerHandler.scala | 153 ------------- .../rest/SubmitDriverRequestMessage.scala | 18 +- .../rest/SubmitDriverResponseMessage.scala | 16 +- .../spark/deploy/rest/SubmitRestClient.scala | 98 ++++++++ ....scala => SubmitRestProtocolMessage.scala} | 78 +++---- .../spark/deploy/rest/SubmitRestServer.scala | 112 +++++++++ 15 files changed, 466 insertions(+), 430 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServerHandler.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala rename core/src/main/scala/org/apache/spark/deploy/rest/{StandaloneRestProtocolMessage.scala => SubmitRestProtocolMessage.scala} (72%) create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 7c89f0bcaeba..ec3def339759 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -75,11 +75,10 @@ object SparkSubmit { } // In standalone cluster mode, use the brand new REST client to submit the application - val doingRest = appArgs.master.startsWith("spark://") && appArgs.deployMode == "cluster" - if (doingRest) { - println("Submitting driver through the REST interface.") + val isStandaloneCluster = + appArgs.master.startsWith("spark://") && appArgs.deployMode == "cluster" + if (isStandaloneCluster) { new StandaloneRestClient().submitDriver(appArgs) - println("Done submitting driver.") return } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 24c08373ade4..1bd1992c95b1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -122,7 +122,10 @@ private[spark] class Master( throw new SparkException("spark.deploy.defaultCores must be positive") } - val restServer = new StandaloneRestServer(this, host, 6677) + // Alternative application submission gateway that is stable across Spark versions + private val restServerPort = conf.getInt("spark.master.rest.port", 17077) + private val restServer = new StandaloneRestServer(this, host, restServerPort) + restServer.start() override def preStart() { logInfo("Starting Spark master at " + masterUrl) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala index eed9f7388ff8..14b77e4da9c9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala @@ -20,8 +20,8 @@ package org.apache.spark.deploy.rest /** * A field used in a DriverStatusRequestMessage. */ -private[spark] abstract class DriverStatusRequestField extends StandaloneRestProtocolField -private[spark] object DriverStatusRequestField extends StandaloneRestProtocolFieldCompanion { +private[spark] abstract class DriverStatusRequestField extends SubmitRestProtocolField +private[spark] object DriverStatusRequestField extends SubmitRestProtocolFieldCompanion { case object ACTION extends DriverStatusRequestField case object SPARK_VERSION extends DriverStatusRequestField case object MESSAGE extends DriverStatusRequestField @@ -32,16 +32,16 @@ private[spark] object DriverStatusRequestField extends StandaloneRestProtocolFie } /** - * A request sent to the standalone Master to query the status of a driver. + * A request sent to the cluster manager to query the status of a driver. */ -private[spark] class DriverStatusRequestMessage extends StandaloneRestProtocolMessage( - StandaloneRestProtocolAction.DRIVER_STATUS_REQUEST, +private[spark] class DriverStatusRequestMessage extends SubmitRestProtocolMessage( + SubmitRestProtocolAction.DRIVER_STATUS_REQUEST, DriverStatusRequestField.ACTION, DriverStatusRequestField.requiredFields) -private[spark] object DriverStatusRequestMessage extends StandaloneRestProtocolMessageCompanion { - protected override def newMessage(): StandaloneRestProtocolMessage = +private[spark] object DriverStatusRequestMessage extends SubmitRestProtocolMessageCompanion { + protected override def newMessage(): SubmitRestProtocolMessage = new DriverStatusRequestMessage - protected override def fieldWithName(field: String): StandaloneRestProtocolField = + protected override def fieldWithName(field: String): SubmitRestProtocolField = DriverStatusRequestField.withName(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala index 020f80c71ffc..8d45ceaa3ee2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala @@ -20,8 +20,8 @@ package org.apache.spark.deploy.rest /** * A field used in a DriverStatusResponseMessage. */ -private[spark] abstract class DriverStatusResponseField extends StandaloneRestProtocolField -private[spark] object DriverStatusResponseField extends StandaloneRestProtocolFieldCompanion { +private[spark] abstract class DriverStatusResponseField extends SubmitRestProtocolField +private[spark] object DriverStatusResponseField extends SubmitRestProtocolFieldCompanion { case object ACTION extends DriverStatusResponseField case object SPARK_VERSION extends DriverStatusResponseField case object MESSAGE extends DriverStatusResponseField @@ -36,16 +36,16 @@ private[spark] object DriverStatusResponseField extends StandaloneRestProtocolFi } /** - * A message sent from the standalone Master in response to a DriverStatusResponseMessage. + * A message sent from the cluster manager in response to a DriverStatusRequestMessage.
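+ *
+ * As an illustrative sketch only (field values are hypothetical), a handler
+ * might build such a response as follows:
+ * {{{
+ *   new DriverStatusResponseMessage()
+ *     .setField(DriverStatusResponseField.SPARK_VERSION, "1.3.0")
+ *     .setField(DriverStatusResponseField.MASTER, "spark://host:7077")
+ *     .setField(DriverStatusResponseField.DRIVER_ID, "driver-20150119-0000")
+ *     .setField(DriverStatusResponseField.DRIVER_STATE, "HEALTHY")
+ *     .validate()
+ * }}}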
*/ -private[spark] class DriverStatusResponseMessage extends StandaloneRestProtocolMessage( - StandaloneRestProtocolAction.DRIVER_STATUS_RESPONSE, +private[spark] class DriverStatusResponseMessage extends SubmitRestProtocolMessage( + SubmitRestProtocolAction.DRIVER_STATUS_RESPONSE, DriverStatusResponseField.ACTION, DriverStatusResponseField.requiredFields) -private[spark] object DriverStatusResponseMessage extends StandaloneRestProtocolMessageCompanion { - protected override def newMessage(): StandaloneRestProtocolMessage = +private[spark] object DriverStatusResponseMessage extends SubmitRestProtocolMessageCompanion { + protected override def newMessage(): SubmitRestProtocolMessage = new DriverStatusResponseMessage - protected override def fieldWithName(field: String): StandaloneRestProtocolField = + protected override def fieldWithName(field: String): SubmitRestProtocolField = DriverStatusResponseField.withName(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala index 7d8dda73414a..020c7c28dc36 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala @@ -20,8 +20,8 @@ package org.apache.spark.deploy.rest /** * A field used in a ErrorMessage. */ -private[spark] abstract class ErrorField extends StandaloneRestProtocolField -private[spark] object ErrorField extends StandaloneRestProtocolFieldCompanion { +private[spark] abstract class ErrorField extends SubmitRestProtocolField +private[spark] object ErrorField extends SubmitRestProtocolFieldCompanion { case object ACTION extends ErrorField case object SPARK_VERSION extends ErrorField case object MESSAGE extends ErrorField @@ -30,15 +30,15 @@ private[spark] object ErrorField extends StandaloneRestProtocolFieldCompanion { } /** - * An error message exchanged in the standalone REST protocol. + * An error message exchanged in the stable application submission protocol. */ -private[spark] class ErrorMessage extends StandaloneRestProtocolMessage( - StandaloneRestProtocolAction.ERROR, +private[spark] class ErrorMessage extends SubmitRestProtocolMessage( + SubmitRestProtocolAction.ERROR, ErrorField.ACTION, ErrorField.requiredFields) -private[spark] object ErrorMessage extends StandaloneRestProtocolMessageCompanion { - protected override def newMessage(): StandaloneRestProtocolMessage = new ErrorMessage - protected override def fieldWithName(field: String): StandaloneRestProtocolField = +private[spark] object ErrorMessage extends SubmitRestProtocolMessageCompanion { + protected override def newMessage(): SubmitRestProtocolMessage = new ErrorMessage + protected override def fieldWithName(field: String): SubmitRestProtocolField = ErrorField.withName(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala index 978f1d75498e..cdca193e6aed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala @@ -20,8 +20,8 @@ package org.apache.spark.deploy.rest /** * A field used in a KillDriverRequestMessage. 
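+ *
+ * A sketch of a well-formed kill request built from these fields (all values
+ * are hypothetical):
+ * {{{
+ *   new KillDriverRequestMessage()
+ *     .setField(KillDriverRequestField.SPARK_VERSION, "1.3.0")
+ *     .setField(KillDriverRequestField.MASTER, "spark://host:7077")
+ *     .setField(KillDriverRequestField.DRIVER_ID, "driver-20150119-0000")
+ *     .validate()
+ * }}}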
 */ -private[spark] abstract class KillDriverRequestField extends StandaloneRestProtocolField -private[spark] object KillDriverRequestField extends StandaloneRestProtocolFieldCompanion { +private[spark] abstract class KillDriverRequestField extends SubmitRestProtocolField +private[spark] object KillDriverRequestField extends SubmitRestProtocolFieldCompanion { case object ACTION extends KillDriverRequestField case object SPARK_VERSION extends KillDriverRequestField case object MESSAGE extends KillDriverRequestField @@ -32,16 +32,16 @@ private[spark] object KillDriverRequestField extends StandaloneRestProtocolField } /** - * A request sent to the standalone Master to kill a driver. + * A request sent to the cluster manager to kill a driver. */ -private[spark] class KillDriverRequestMessage extends StandaloneRestProtocolMessage( - StandaloneRestProtocolAction.KILL_DRIVER_REQUEST, +private[spark] class KillDriverRequestMessage extends SubmitRestProtocolMessage( + SubmitRestProtocolAction.KILL_DRIVER_REQUEST, KillDriverRequestField.ACTION, KillDriverRequestField.requiredFields) -private[spark] object KillDriverRequestMessage extends StandaloneRestProtocolMessageCompanion { - protected override def newMessage(): StandaloneRestProtocolMessage = +private[spark] object KillDriverRequestMessage extends SubmitRestProtocolMessageCompanion { + protected override def newMessage(): SubmitRestProtocolMessage = new KillDriverRequestMessage - protected override def fieldWithName(field: String): StandaloneRestProtocolField = + protected override def fieldWithName(field: String): SubmitRestProtocolField = KillDriverRequestField.withName(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala index 69b9a4f4bdab..60eac1b1f26d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala @@ -20,8 +20,8 @@ package org.apache.spark.deploy.rest /** * A field used in a KillDriverResponseMessage. */ -private[spark] abstract class KillDriverResponseField extends StandaloneRestProtocolField -private[spark] object KillDriverResponseField extends StandaloneRestProtocolFieldCompanion { +private[spark] abstract class KillDriverResponseField extends SubmitRestProtocolField +private[spark] object KillDriverResponseField extends SubmitRestProtocolFieldCompanion { case object ACTION extends KillDriverResponseField case object SPARK_VERSION extends KillDriverResponseField case object MESSAGE extends KillDriverResponseField @@ -33,16 +33,16 @@ private[spark] object KillDriverResponseField extends StandaloneRestProtocolFiel } /** - * A message sent from the standalone Master in response to a KillDriverResponseMessage. + * A message sent from the cluster manager in response to a KillDriverRequestMessage.
*/ -private[spark] class KillDriverResponseMessage extends StandaloneRestProtocolMessage( - StandaloneRestProtocolAction.KILL_DRIVER_RESPONSE, +private[spark] class KillDriverResponseMessage extends SubmitRestProtocolMessage( + SubmitRestProtocolAction.KILL_DRIVER_RESPONSE, KillDriverResponseField.ACTION, KillDriverResponseField.requiredFields) -private[spark] object KillDriverResponseMessage extends StandaloneRestProtocolMessageCompanion { - protected override def newMessage(): StandaloneRestProtocolMessage = +private[spark] object KillDriverResponseMessage extends SubmitRestProtocolMessageCompanion { + protected override def newMessage(): SubmitRestProtocolMessage = new KillDriverResponseMessage - protected override def fieldWithName(field: String): StandaloneRestProtocolField = + protected override def fieldWithName(field: String): SubmitRestProtocolField = KillDriverResponseField.withName(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index 6059344d93b6..cb1aba45a218 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -17,52 +17,22 @@ package org.apache.spark.deploy.rest -import java.io.DataOutputStream import java.net.URL -import java.net.HttpURLConnection - -import scala.io.Source - -import com.google.common.base.Charsets import org.apache.spark.{SPARK_VERSION => sparkVersion} import org.apache.spark.deploy.SparkSubmitArguments import org.apache.spark.util.Utils /** - * A client that submits Spark applications using a stable REST protocol in standalone - * cluster mode. This client is intended to communicate with the StandaloneRestServer. + * A client that submits Spark applications to the standalone Master using a stable + * REST protocol. This client is intended to communicate with the StandaloneRestServer, + * and currently only used in cluster mode. */ -private[spark] class StandaloneRestClient { - - def submitDriver(args: SparkSubmitArguments): Unit = { - validateSubmitArguments(args) - val url = getHttpUrl(args.master) - val request = constructSubmitRequest(args) - val response = sendHttp(url, request) - println(response.toJson) - } - - def killDriver(master: String, driverId: String): Unit = { - validateMaster(master) - val url = getHttpUrl(master) - val request = constructKillRequest(master, driverId) - val response = sendHttp(url, request) - println(response.toJson) - } - - def requestDriverStatus(master: String, driverId: String): Unit = { - validateMaster(master) - val url = getHttpUrl(master) - val request = constructStatusRequest(master, driverId) - val response = sendHttp(url, request) - println(response.toJson) - } +private[spark] class StandaloneRestClient extends SubmitRestClient { - /** - * Construct a submit driver request message. - */ - private def constructSubmitRequest(args: SparkSubmitArguments): SubmitDriverRequestMessage = { + /** Construct a submit driver request message. 
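+ *
+ * Note that memory strings are normalized to megabytes before being set;
+ * for example, a hypothetical --driver-memory 2g arrives at the server as
+ * DRIVER_MEMORY = "2048".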
*/ + override protected def constructSubmitRequest( + args: SparkSubmitArguments): SubmitDriverRequestMessage = { import SubmitDriverRequestField._ val driverMemory = Option(args.driverMemory) .map { m => Utils.memoryStringToMb(m).toString } @@ -78,7 +48,6 @@ private[spark] class StandaloneRestClient { .setFieldIfNotNull(MAIN_CLASS, args.mainClass) .setFieldIfNotNull(JARS, args.jars) .setFieldIfNotNull(FILES, args.files) - .setFieldIfNotNull(PY_FILES, args.pyFiles) .setFieldIfNotNull(DRIVER_MEMORY, driverMemory) .setFieldIfNotNull(DRIVER_CORES, args.driverCores) .setFieldIfNotNull(DRIVER_EXTRA_JAVA_OPTIONS, args.driverExtraJavaOptions) @@ -97,10 +66,8 @@ private[spark] class StandaloneRestClient { message.validate() } - /** - * Construct a kill driver request message. - */ - private def constructKillRequest( + /** Construct a kill driver request message. */ + override protected def constructKillRequest( master: String, driverId: String): KillDriverRequestMessage = { import KillDriverRequestField._ @@ -111,10 +78,8 @@ private[spark] class StandaloneRestClient { .validate() } - /** - * Construct a driver status request message. - */ - private def constructStatusRequest( + /** Construct a driver status request message. */ + override protected def constructStatusRequest( master: String, driverId: String): DriverStatusRequestMessage = { import DriverStatusRequestField._ @@ -125,67 +90,23 @@ private[spark] class StandaloneRestClient { .validate() } - /** - * Send the provided request in an HTTP message to the given URL. - * Return the response received from the REST server. - */ - private def sendHttp( - url: URL, - request: StandaloneRestProtocolMessage): StandaloneRestProtocolMessage = { - val conn = url.openConnection().asInstanceOf[HttpURLConnection] - conn.setRequestMethod("POST") - conn.setRequestProperty("Content-Type", "application/json") - conn.setRequestProperty("charset", "utf-8") - conn.setDoOutput(true) - println("Sending this JSON blob to server:\n" + request.toJson) - val content = request.toJson.getBytes(Charsets.UTF_8) - val out = new DataOutputStream(conn.getOutputStream) - out.write(content) - out.close() - val response = Source.fromInputStream(conn.getInputStream).mkString - StandaloneRestProtocolMessage.fromJson(response) - } - - /** - * Throw an exception if this is not standalone cluster mode. - */ - private def validateSubmitArguments(args: SparkSubmitArguments): Unit = { - validateMaster(args.master) - validateDeployMode(args.deployMode) - } - - /** - * Throw an exception if this is not standalone mode. - */ - private def validateMaster(master: String): Unit = { + /** Throw an exception if this is not standalone mode. */ + override protected def validateMaster(master: String): Unit = { if (!master.startsWith("spark://")) { throw new IllegalArgumentException("This REST client is only supported in standalone mode.") } } - /** - * Throw an exception if this is not cluster deploy mode. - */ - private def validateDeployMode(deployMode: String): Unit = { + /** Throw an exception if this is not cluster deploy mode. */ + override protected def validateDeployMode(deployMode: String): Unit = { if (deployMode != "cluster") { throw new IllegalArgumentException("This REST client is only supported in cluster mode.") } } - /** - * Extract the URL portion of the master address. - */ - private def getHttpUrl(master: String): URL = { + /** Extract the URL portion of the master address. 
*/ + override protected def getHttpUrl(master: String): URL = { validateMaster(master) new URL("http://" + master.stripPrefix("spark://")) } } - -object StandaloneRestClient { - def main(args: Array[String]): Unit = { - assert(args.length > 0) - //val client = new StandaloneRestClient - //client.submitDriver("spark://" + args(0)) - println("Done.") - } -} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 344a3ef89a4d..e160b8e0a0d2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -17,108 +17,164 @@ package org.apache.spark.deploy.rest -import java.io.DataOutputStream -import java.net.InetSocketAddress -import javax.servlet.http.{HttpServletRequest, HttpServletResponse} +import java.io.File -import scala.io.Source +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer -import com.google.common.base.Charsets -import org.eclipse.jetty.server.{Request, Server} -import org.eclipse.jetty.server.handler.AbstractHandler - -import org.apache.spark.{SPARK_VERSION => sparkVersion, Logging} -import org.apache.spark.deploy.master.Master +import org.apache.spark.{SPARK_VERSION => sparkVersion} +import org.apache.spark.SparkConf import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.deploy.{Command, DriverDescription} +import org.apache.spark.deploy.ClientArguments._ +import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.master.Master +import akka.actor.ActorRef /** * A server that responds to requests submitted by the StandaloneRestClient. + * This is intended to be embedded in the standalone Master. */ -private[spark] class StandaloneRestServer(master: Master, host: String, requestedPort: Int) { - val server = new Server(new InetSocketAddress(host, requestedPort)) - server.setHandler(new StandaloneRestServerHandler(master)) - server.start() +private[spark] class StandaloneRestServer( + master: Master, + host: String, + requestedPort: Int) + extends SubmitRestServer(host, requestedPort) { + override protected val handler = new StandaloneRestServerHandler(master) } /** - * A Jetty handler that responds to requests submitted via the standalone REST protocol. + * A handler that responds to requests submitted to the standalone Master + * through the REST protocol. */ -private[spark] abstract class StandaloneRestHandler(master: Master) - extends AbstractHandler with Logging { +private[spark] class StandaloneRestServerHandler( + conf: SparkConf, + masterActor: ActorRef, + masterUrl: String) + extends SubmitRestServerHandler { - private implicit val askTimeout = AkkaUtils.askTimeout(master.conf) + private implicit val askTimeout = AkkaUtils.askTimeout(conf) + + def this(master: Master) = { + this(master.conf, master.self, master.masterUrl) + } /** Handle a request to submit a driver. 
*/ - protected def handleSubmit(request: SubmitDriverRequestMessage): SubmitDriverResponseMessage + override protected def handleSubmit( + request: SubmitDriverRequestMessage): SubmitDriverResponseMessage = { + import SubmitDriverResponseField._ + val driverDescription = buildDriverDescription(request) + val response = AkkaUtils.askWithReply[SubmitDriverResponse]( + RequestSubmitDriver(driverDescription), masterActor, askTimeout) + new SubmitDriverResponseMessage() + .setField(SPARK_VERSION, sparkVersion) + .setField(MESSAGE, response.message) + .setField(MASTER, masterUrl) + .setField(SUCCESS, response.success.toString) + .setFieldIfNotNull(DRIVER_ID, response.driverId.orNull) + .validate() + } + /** Handle a request to kill a driver. */ - protected def handleKill(request: KillDriverRequestMessage): KillDriverResponseMessage - /** Handle a request for a driver's status. */ - protected def handleStatus(request: DriverStatusRequestMessage): DriverStatusResponseMessage + override protected def handleKill( + request: KillDriverRequestMessage): KillDriverResponseMessage = { + import KillDriverResponseField._ + val driverId = request.getFieldNotNull(KillDriverRequestField.DRIVER_ID) + val response = AkkaUtils.askWithReply[KillDriverResponse]( + RequestKillDriver(driverId), masterActor, askTimeout) + new KillDriverResponseMessage() + .setField(SPARK_VERSION, sparkVersion) + .setField(MESSAGE, response.message) + .setField(MASTER, masterUrl) + .setField(DRIVER_ID, driverId) + .setField(SUCCESS, response.success.toString) + .validate() + } - /** - * Handle a request submitted by the StandaloneRestClient. - */ - override def handle( - target: String, - baseRequest: Request, - request: HttpServletRequest, - response: HttpServletResponse): Unit = { - try { - val requestMessageJson = Source.fromInputStream(request.getInputStream).mkString - val requestMessage = StandaloneRestProtocolMessage.fromJson(requestMessageJson) - val responseMessage = constructResponseMessage(requestMessage) - response.setContentType("application/json") - response.setCharacterEncoding("utf-8") - response.setStatus(HttpServletResponse.SC_OK) - val content = responseMessage.toJson.getBytes(Charsets.UTF_8) - val out = new DataOutputStream(response.getOutputStream) - out.write(content) - out.close() - baseRequest.setHandled(true) - } catch { - case e: Exception => logError("Exception while handling request", e) - } + /** Handle a request for a driver's status. */ + override protected def handleStatus( + request: DriverStatusRequestMessage): DriverStatusResponseMessage = { + import DriverStatusResponseField._ + // TODO: Actually look up the status of the driver + val master = request.getField(DriverStatusRequestField.MASTER) + val driverId = request.getField(DriverStatusRequestField.DRIVER_ID) + val driverState = "HEALTHY" + new DriverStatusResponseMessage() + .setField(SPARK_VERSION, sparkVersion) + .setField(MASTER, master) + .setField(DRIVER_ID, driverId) + .setField(DRIVER_STATE, driverState) + .validate() } /** - * Construct the appropriate response message based on the type of the request message. - * If an IllegalArgumentException is thrown in the process, construct an error message. + * Build a driver description from the fields specified in the submit request. + * This currently does not consider fields used by python applications since + * python is not supported in standalone cluster mode yet. 
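+ *
+ * As a hypothetical illustration, a request carrying the fields set below
+ * yields a description whose command launches DriverWrapper with
+ * com.example.Main and sets spark.executor.memory in the driver's SparkConf:
+ * {{{
+ *   new SubmitDriverRequestMessage()
+ *     .setField(APP_NAME, "MyApp")
+ *     .setField(APP_RESOURCE, "hdfs://path/to/app.jar")
+ *     .setField(MAIN_CLASS, "com.example.Main")
+ *     .setField(SPARK_PROPERTY("spark.executor.memory"), "2g")
+ * }}}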
*/ - private def constructResponseMessage( - request: StandaloneRestProtocolMessage): StandaloneRestProtocolMessage = { - // If the request is sent via the StandaloneRestClient, it should have already been - // validated remotely. In case this is not true, validate the request here to guard - // against potential NPEs. If validation fails, return an ERROR message to the sender. - try { - request.validate() - request match { - case submit: SubmitDriverRequestMessage => handleSubmit(submit) - case kill: KillDriverRequestMessage => handleKill(kill) - case status: DriverStatusRequestMessage => handleStatus(status) - case unexpected => handleError( - s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") + private def buildDriverDescription(request: SubmitDriverRequestMessage): DriverDescription = { + import SubmitDriverRequestField._ + + // Required fields + val appName = request.getFieldNotNull(APP_NAME) + val appResource = request.getFieldNotNull(APP_RESOURCE) + + // Since standalone cluster mode does not yet support python, + // we treat the main class as required + val mainClass = request.getFieldNotNull(MAIN_CLASS) + + // Optional fields + val jars = request.getFieldOption(JARS) + val files = request.getFieldOption(FILES) + val driverMemory = request.getFieldOption(DRIVER_MEMORY) + val driverCores = request.getFieldOption(DRIVER_CORES) + val driverExtraJavaOptions = request.getFieldOption(DRIVER_EXTRA_JAVA_OPTIONS) + val driverExtraClassPath = request.getFieldOption(DRIVER_EXTRA_CLASS_PATH) + val driverExtraLibraryPath = request.getFieldOption(DRIVER_EXTRA_LIBRARY_PATH) + val superviseDriver = request.getFieldOption(SUPERVISE_DRIVER) + val executorMemory = request.getFieldOption(EXECUTOR_MEMORY) + val totalExecutorCores = request.getFieldOption(TOTAL_EXECUTOR_CORES) + + // Parse special fields that take in parameters + val conf = new SparkConf(false) + val env = new mutable.HashMap[String, String] + val appArgs = new ArrayBuffer[(Int, String)] + request.getFields.foreach { case (k, v) => + k match { + case APP_ARG(index) => appArgs += ((index, v)) + case SPARK_PROPERTY(propKey) => conf.set(propKey, v) + case ENVIRONMENT_VARIABLE(envKey) => env(envKey) = v + case _ => } - } catch { - // Propagate exception to user in an ErrorMessage. If the construction of the - // ErrorMessage itself throws an exception, log the exception and ignore the request. - case e: IllegalArgumentException => handleError(e.getMessage) } - } - /** Construct an error message to signal the fact that an exception has been thrown. 
 */ - private def handleError(message: String): ErrorMessage = { - import ErrorField._ - new ErrorMessage() - .setField(SPARK_VERSION, sparkVersion) - .setField(MESSAGE, message) - .validate() + // Use the actual master URL instead of the one that refers to this REST server + // Otherwise, once the driver is launched it will contact the wrong server + conf.set("spark.master", masterUrl) + conf.set("spark.app.name", appName) + conf.set("spark.jars", jars.map(_ + ",").getOrElse("") + appResource) // include app resource + files.foreach { f => conf.set("spark.files", f) } + driverExtraJavaOptions.foreach { j => conf.set("spark.driver.extraJavaOptions", j) } + driverExtraClassPath.foreach { cp => conf.set("spark.driver.extraClassPath", cp) } + driverExtraLibraryPath.foreach { lp => conf.set("spark.driver.extraLibraryPath", lp) } + executorMemory.foreach { m => conf.set("spark.executor.memory", m) } + totalExecutorCores.foreach { c => conf.set("spark.cores.max", c) } + + // Construct driver description and submit it + val actualDriverMemory = driverMemory.map(_.toInt).getOrElse(DEFAULT_MEMORY) + val actualDriverCores = driverCores.map(_.toInt).getOrElse(DEFAULT_CORES) + val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) + val actualAppArgs = appArgs.sortBy(_._1).map(_._2) // sort by index, map to value + val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) + val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) + val sparkJavaOpts = Utils.sparkJavaOpts(conf) + val javaOpts = sparkJavaOpts ++ extraJavaOpts + val command = new Command( + "org.apache.spark.deploy.worker.DriverWrapper", + Seq("{{WORKER_URL}}", mainClass) ++ actualAppArgs, + env, extraClassPath, extraLibraryPath, javaOpts) + new DriverDescription( + appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command) } } - -//object StandaloneRestServer { -// def main(args: Array[String]): Unit = { -// println("Hey boy I'm starting a server.") -// new StandaloneRestServer(6677) -// readLine() -// } -//} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServerHandler.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServerHandler.scala deleted file mode 100644 index e11698e51bf1..000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServerHandler.scala +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.deploy.rest - -import java.io.File - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.{SPARK_VERSION => sparkVersion} -import org.apache.spark.SparkConf -import org.apache.spark.util.{AkkaUtils, Utils} -import org.apache.spark.deploy.{Command, DriverDescription} -import org.apache.spark.deploy.ClientArguments._ -import org.apache.spark.deploy.DeployMessages._ -import org.apache.spark.deploy.master.Master - -/** - * - */ -private[spark] class StandaloneRestServerHandler(master: Master) - extends StandaloneRestHandler(master) { - - private implicit val askTimeout = AkkaUtils.askTimeout(master.conf) - - override protected def handleSubmit( - request: SubmitDriverRequestMessage): SubmitDriverResponseMessage = { - import SubmitDriverResponseField._ - val driverDescription = buildDriverDescription(request) - val response = AkkaUtils.askWithReply[SubmitDriverResponse]( - RequestSubmitDriver(driverDescription), master.self, askTimeout) - new SubmitDriverResponseMessage() - .setField(SPARK_VERSION, sparkVersion) - .setField(MESSAGE, response.message) - .setField(MASTER, master.masterUrl) - .setField(SUCCESS, response.success.toString) - .setFieldIfNotNull(DRIVER_ID, response.driverId.orNull) - .validate() - } - - override protected def handleKill( - request: KillDriverRequestMessage): KillDriverResponseMessage = { - import KillDriverResponseField._ - val driverId = request.getFieldNotNull(KillDriverRequestField.DRIVER_ID) - val response = AkkaUtils.askWithReply[KillDriverResponse]( - RequestKillDriver(driverId), master.self, askTimeout) - new KillDriverResponseMessage() - .setField(SPARK_VERSION, sparkVersion) - .setField(MESSAGE, response.message) - .setField(MASTER, master.masterUrl) - .setField(DRIVER_ID, driverId) - .setField(SUCCESS, response.success.toString) - .validate() - } - - override protected def handleStatus( - request: DriverStatusRequestMessage): DriverStatusResponseMessage = { - import DriverStatusResponseField._ - // TODO: Actually look up the status of the driver - val master = request.getField(DriverStatusRequestField.MASTER) - val driverId = request.getField(DriverStatusRequestField.DRIVER_ID) - val driverState = "HEALTHY" - new DriverStatusResponseMessage() - .setField(SPARK_VERSION, sparkVersion) - .setField(MASTER, master) - .setField(DRIVER_ID, driverId) - .setField(DRIVER_STATE, driverState) - .validate() - } - - private def buildDriverDescription(request: SubmitDriverRequestMessage): DriverDescription = { - import SubmitDriverRequestField._ - - // Required fields - //val _master = request.getFieldNotNull(MASTER) - val appName = request.getFieldNotNull(APP_NAME) - val appResource = request.getFieldNotNull(APP_RESOURCE) - - // Since standalone cluster mode does not yet support python, - // we treat the main class as required - val mainClass = request.getFieldNotNull(MAIN_CLASS) - - // Optional fields - val jars = request.getFieldOption(JARS) - val files = request.getFieldOption(FILES) - val driverMemory = request.getFieldOption(DRIVER_MEMORY) - val driverCores = request.getFieldOption(DRIVER_CORES) - val driverExtraJavaOptions = request.getFieldOption(DRIVER_EXTRA_JAVA_OPTIONS) - val driverExtraClassPath = request.getFieldOption(DRIVER_EXTRA_CLASS_PATH) - val driverExtraLibraryPath = request.getFieldOption(DRIVER_EXTRA_LIBRARY_PATH) - val superviseDriver = request.getFieldOption(SUPERVISE_DRIVER) - val executorMemory = request.getFieldOption(EXECUTOR_MEMORY) - val totalExecutorCores = 
request.getFieldOption(TOTAL_EXECUTOR_CORES) - - // Parse special fields that take in parameters - val conf = new SparkConf(false) - val env = new mutable.HashMap[String, String] - val appArgs = new ArrayBuffer[(Int, String)] - request.getFields.foreach { case (k, v) => - println(s"> Found this field: $k = $v") - k match { - case APP_ARG(index) => appArgs += ((index, v)) - case SPARK_PROPERTY(propKey) => conf.set(propKey, v) - case ENVIRONMENT_VARIABLE(envKey) => env(envKey) = v - case _ => - } - } - - // Use the actual master URL instead of the one that refers to this REST server - // Otherwise, once the driver is launched it will contact with the wrong server - conf.set("spark.master", master.masterUrl) - conf.set("spark.app.name", appName) - conf.set("spark.jars", jars.map(_ + ",").getOrElse("") + appResource) // include app resource - files.foreach { f => conf.set("spark.files", f) } - driverExtraJavaOptions.foreach { j => conf.set("spark.driver.extraJavaOptions", j) } - driverExtraClassPath.foreach { cp => conf.set("spark.driver.extraClassPath", cp) } - driverExtraLibraryPath.foreach { lp => conf.set("spark.driver.extraLibraryPath", lp) } - executorMemory.foreach { m => conf.set("spark.executor.memory", m) } - totalExecutorCores.foreach { c => conf.set("spark.cores.max", c) } - - // Construct driver description and submit it - val actualDriverMemory = driverMemory.map(_.toInt).getOrElse(DEFAULT_MEMORY) - val actualDriverCores = driverCores.map(_.toInt).getOrElse(DEFAULT_CORES) - val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) - val actualAppArgs = appArgs.sortBy(_._1).map(_._2) // sort by index, map to value - val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) - val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) - val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) - val sparkJavaOpts = Utils.sparkJavaOpts(conf) - val javaOpts = sparkJavaOpts ++ extraJavaOpts - val command = new Command( - "org.apache.spark.deploy.worker.DriverWrapper", - Seq("{{WORKER_URL}}", mainClass) ++ actualAppArgs, - env, extraClassPath, extraLibraryPath, javaOpts) - new DriverDescription( - appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command) - } -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala index 72f92f2c0d49..e55cb69ed112 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala @@ -24,8 +24,8 @@ import org.apache.spark.util.Utils /** * A field used in a SubmitDriverRequestMessage. 
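+ *
+ * Parameterized fields are serialized with their parameter appended as a
+ * suffix. A hypothetical round trip (names from this file):
+ * {{{
+ *   SPARK_PROPERTY("spark.cores.max").toString     // "SPARK_PROPERTY_spark.cores.max"
+ *   SubmitDriverRequestField.withName("APP_ARG_0")  // APP_ARG(0)
+ * }}}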
*/ -private[spark] abstract class SubmitDriverRequestField extends StandaloneRestProtocolField -private[spark] object SubmitDriverRequestField extends StandaloneRestProtocolFieldCompanion { +private[spark] abstract class SubmitDriverRequestField extends SubmitRestProtocolField +private[spark] object SubmitDriverRequestField extends SubmitRestProtocolFieldCompanion { case object ACTION extends SubmitDriverRequestField case object SPARK_VERSION extends SubmitDriverRequestField case object MESSAGE extends SubmitDriverRequestField @@ -61,7 +61,7 @@ private[spark] object SubmitDriverRequestField extends StandaloneRestProtocolFie // Because certain fields taken in arguments, we cannot simply rely on the // list of all fields to reconstruct a field from its String representation. // Instead, we must treat these fields as special cases and match on their prefixes. - override def withName(field: String): StandaloneRestProtocolField = { + override def withName(field: String): SubmitRestProtocolField = { def buildRegex(obj: AnyRef): Regex = s"${Utils.getFormattedClassName(obj)}_(.*)".r val appArg = buildRegex(APP_ARG) val sparkProperty = buildRegex(SPARK_PROPERTY) @@ -76,10 +76,10 @@ private[spark] object SubmitDriverRequestField extends StandaloneRestProtocolFie } /** - * A request sent to the standalone Master to submit a driver. + * A request sent to the cluster manager to submit a driver. */ -private[spark] class SubmitDriverRequestMessage extends StandaloneRestProtocolMessage( - StandaloneRestProtocolAction.SUBMIT_DRIVER_REQUEST, +private[spark] class SubmitDriverRequestMessage extends SubmitRestProtocolMessage( + SubmitRestProtocolAction.SUBMIT_DRIVER_REQUEST, SubmitDriverRequestField.ACTION, SubmitDriverRequestField.requiredFields) { @@ -95,9 +95,9 @@ private[spark] class SubmitDriverRequestMessage extends StandaloneRestProtocolMe } } -private[spark] object SubmitDriverRequestMessage extends StandaloneRestProtocolMessageCompanion { - protected override def newMessage(): StandaloneRestProtocolMessage = +private[spark] object SubmitDriverRequestMessage extends SubmitRestProtocolMessageCompanion { + protected override def newMessage(): SubmitRestProtocolMessage = new SubmitDriverRequestMessage - protected override def fieldWithName(field: String): StandaloneRestProtocolField = + protected override def fieldWithName(field: String): SubmitRestProtocolField = SubmitDriverRequestField.withName(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala index e656c35ad965..a877ec4fccff 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala @@ -20,8 +20,8 @@ package org.apache.spark.deploy.rest /** * A field used in a SubmitDriverResponseMessage. 
*/ -private[spark] abstract class SubmitDriverResponseField extends StandaloneRestProtocolField -private[spark] object SubmitDriverResponseField extends StandaloneRestProtocolFieldCompanion { +private[spark] abstract class SubmitDriverResponseField extends SubmitRestProtocolField +private[spark] object SubmitDriverResponseField extends SubmitRestProtocolFieldCompanion { case object ACTION extends SubmitDriverResponseField case object SPARK_VERSION extends SubmitDriverResponseField case object MESSAGE extends SubmitDriverResponseField @@ -33,16 +33,16 @@ private[spark] object SubmitDriverResponseField extends StandaloneRestProtocolFi } /** - * A message sent from the standalone Master in response to a SubmitDriverRequestMessage. + * A message sent from the cluster manager in response to a SubmitDriverRequestMessage. */ -private[spark] class SubmitDriverResponseMessage extends StandaloneRestProtocolMessage( - StandaloneRestProtocolAction.SUBMIT_DRIVER_RESPONSE, +private[spark] class SubmitDriverResponseMessage extends SubmitRestProtocolMessage( + SubmitRestProtocolAction.SUBMIT_DRIVER_RESPONSE, SubmitDriverResponseField.ACTION, SubmitDriverResponseField.requiredFields) -private[spark] object SubmitDriverResponseMessage extends StandaloneRestProtocolMessageCompanion { - protected override def newMessage(): StandaloneRestProtocolMessage = +private[spark] object SubmitDriverResponseMessage extends SubmitRestProtocolMessageCompanion { + protected override def newMessage(): SubmitRestProtocolMessage = new SubmitDriverResponseMessage - protected override def fieldWithName(field: String): StandaloneRestProtocolField = + protected override def fieldWithName(field: String): SubmitRestProtocolField = SubmitDriverResponseField.withName(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala new file mode 100644 index 000000000000..4bd0d022e607 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.io.DataOutputStream +import java.net.{HttpURLConnection, URL} + +import scala.io.Source + +import com.google.common.base.Charsets + +import org.apache.spark.Logging +import org.apache.spark.deploy.SparkSubmitArguments + +/** + * An abstract client that submits Spark applications using a stable REST protocol. + * This client is intended to communicate with the SubmitRestServer. + */ +private[spark] abstract class SubmitRestClient extends Logging { + + /** Request that the REST server submits a driver specified by the provided arguments. 
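+ *
+ * A hypothetical invocation (StandaloneRestClient is the only concrete
+ * implementation in this patch):
+ * {{{
+ *   new StandaloneRestClient().submitDriver(args)
+ * }}}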
*/ + def submitDriver(args: SparkSubmitArguments): Unit = { + validateSubmitArguments(args) + val url = getHttpUrl(args.master) + val request = constructSubmitRequest(args) + logInfo(s"Submitting a request to launch a driver in ${args.master}.") + sendHttp(url, request) + } + + /** Request that the REST server kills the specified driver. */ + def killDriver(master: String, driverId: String): Unit = { + validateMaster(master) + val url = getHttpUrl(master) + val request = constructKillRequest(master, driverId) + logInfo(s"Submitting a request to kill driver $driverId in $master.") + sendHttp(url, request) + } + + /** Request the status of the specified driver from the REST server. */ + def requestDriverStatus(master: String, driverId: String): Unit = { + validateMaster(master) + val url = getHttpUrl(master) + val request = constructStatusRequest(master, driverId) + logInfo(s"Submitting a request for the status of driver $driverId in $master.") + sendHttp(url, request) + } + + /** Return the HTTP URL of the REST server that corresponds to the given master URL. */ + protected def getHttpUrl(master: String): URL + + // Construct the appropriate type of message based on the request type + protected def constructSubmitRequest(args: SparkSubmitArguments): SubmitDriverRequestMessage + protected def constructKillRequest(master: String, driverId: String): KillDriverRequestMessage + protected def constructStatusRequest(master: String, driverId: String): DriverStatusRequestMessage + + // If the provided arguments are not as expected, throw an exception + protected def validateMaster(master: String): Unit + protected def validateDeployMode(deployMode: String): Unit + protected def validateSubmitArguments(args: SparkSubmitArguments): Unit = { + validateMaster(args.master) + validateDeployMode(args.deployMode) + } + + /** + * Send the provided request in an HTTP message to the given URL. + * Return the response received from the REST server. + */ + private def sendHttp(url: URL, request: SubmitRestProtocolMessage): SubmitRestProtocolMessage = { + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("POST") + conn.setRequestProperty("Content-Type", "application/json") + conn.setRequestProperty("charset", "utf-8") + conn.setDoOutput(true) + val requestJson = request.toJson + logDebug(s"Sending the following request to the REST server:\n$requestJson") + val out = new DataOutputStream(conn.getOutputStream) + out.write(requestJson.getBytes(Charsets.UTF_8)) + out.close() + val responseJson = Source.fromInputStream(conn.getInputStream).mkString + logDebug(s"Response from the REST server:\n$responseJson") + SubmitRestProtocolMessage.fromJson(responseJson) + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala similarity index 72% rename from core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala rename to core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index 7945271a870f..77d38c40ae80 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -27,11 +27,11 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.util.Utils /** - * A field used in a StandaloneRestProtocolMessage. + * A field used in a SubmitRestProtocolMessage. 
* Three special fields ACTION, SPARK_VERSION, and MESSAGE are common across all messages. */ -private[spark] abstract class StandaloneRestProtocolField -private[spark] object StandaloneRestProtocolField { +private[spark] abstract class SubmitRestProtocolField +private[spark] object SubmitRestProtocolField { /** Return whether the provided field name refers to the ACTION field. */ def isActionField(field: String): Boolean = field == "ACTION" } @@ -39,54 +39,54 @@ private[spark] object StandaloneRestProtocolField { /** * All possible values of the ACTION field. */ -private[spark] object StandaloneRestProtocolAction extends Enumeration { - type StandaloneRestProtocolAction = Value +private[spark] object SubmitRestProtocolAction extends Enumeration { + type SubmitRestProtocolAction = Value val SUBMIT_DRIVER_REQUEST, SUBMIT_DRIVER_RESPONSE = Value val KILL_DRIVER_REQUEST, KILL_DRIVER_RESPONSE = Value val DRIVER_STATUS_REQUEST, DRIVER_STATUS_RESPONSE = Value val ERROR = Value } -import StandaloneRestProtocolAction.StandaloneRestProtocolAction +import SubmitRestProtocolAction.SubmitRestProtocolAction /** - * A general message exchanged in the standalone REST protocol. + * A general message exchanged in the stable application submission REST protocol. * * The message is represented by a set of fields in the form of key value pairs. * Each message must contain an ACTION field, which should have only one possible value * for each type of message. For compatibility with older versions of Spark, existing * fields must not be removed or modified, though new fields can be added as necessary. */ -private[spark] abstract class StandaloneRestProtocolMessage( - action: StandaloneRestProtocolAction, - actionField: StandaloneRestProtocolField, - requiredFields: Seq[StandaloneRestProtocolField]) { +private[spark] abstract class SubmitRestProtocolMessage( + action: SubmitRestProtocolAction, + actionField: SubmitRestProtocolField, + requiredFields: Seq[SubmitRestProtocolField]) { - import StandaloneRestProtocolField._ + import SubmitRestProtocolField._ private val className = Utils.getFormattedClassName(this) - protected val fields = new mutable.HashMap[StandaloneRestProtocolField, String] + protected val fields = new mutable.HashMap[SubmitRestProtocolField, String] // Set the action field fields(actionField) = action.toString /** Return all fields currently set in this message. */ - def getFields: Map[StandaloneRestProtocolField, String] = fields + def getFields: Map[SubmitRestProtocolField, String] = fields /** Return the value of the given field. If the field is not present, return null. */ - def getField(key: StandaloneRestProtocolField): String = getFieldOption(key).orNull + def getField(key: SubmitRestProtocolField): String = getFieldOption(key).orNull /** Return the value of the given field. If the field is not present, throw an exception. */ - def getFieldNotNull(key: StandaloneRestProtocolField): String = { + def getFieldNotNull(key: SubmitRestProtocolField): String = { getFieldOption(key).getOrElse { throw new IllegalArgumentException(s"Field $key is not set in message $className") } } /** Return the value of the given field as an option. */ - def getFieldOption(key: StandaloneRestProtocolField): Option[String] = fields.get(key) + def getFieldOption(key: SubmitRestProtocolField): Option[String] = fields.get(key) /** Assign the given value to the field, overriding any existing value. 
*/ - def setField(key: StandaloneRestProtocolField, value: String): this.type = { + def setField(key: SubmitRestProtocolField, value: String): this.type = { if (key == actionField) { throw new SparkException("Setting the ACTION field is only allowed during instantiation.") } @@ -95,7 +95,7 @@ private[spark] abstract class StandaloneRestProtocolMessage( } /** Assign the given value to the field only if the value is not null. */ - def setFieldIfNotNull(key: StandaloneRestProtocolField, value: String): this.type = { + def setFieldIfNotNull(key: SubmitRestProtocolField, value: String): this.type = { if (value != null) { setField(key, value) } @@ -145,22 +145,22 @@ private[spark] abstract class StandaloneRestProtocolMessage( } } -private[spark] object StandaloneRestProtocolMessage { - import StandaloneRestProtocolField._ - import StandaloneRestProtocolAction._ +private[spark] object SubmitRestProtocolMessage { + import SubmitRestProtocolField._ + import SubmitRestProtocolAction._ /** - * Construct a StandaloneRestProtocolMessage from JSON. + * Construct a SubmitRestProtocolMessage from JSON. * This uses the ACTION field to determine the type of the message to reconstruct. * If such a field does not exist in the JSON, throw an exception. */ - def fromJson(json: String): StandaloneRestProtocolMessage = { + def fromJson(json: String): SubmitRestProtocolMessage = { val fields = org.apache.spark.util.JsonProtocol.mapFromJson(parse(json)) val action = fields .flatMap { case (k, v) => if (isActionField(k)) Some(v) else None } .headOption .getOrElse { throw new IllegalArgumentException(s"ACTION not found in message:\n$json") } - StandaloneRestProtocolAction.withName(action) match { + SubmitRestProtocolAction.withName(action) match { case SUBMIT_DRIVER_REQUEST => SubmitDriverRequestMessage.fromFields(fields) case SUBMIT_DRIVER_RESPONSE => SubmitDriverResponseMessage.fromFields(fields) case KILL_DRIVER_REQUEST => KillDriverRequestMessage.fromFields(fields) @@ -173,22 +173,22 @@ private[spark] object StandaloneRestProtocolMessage { } /** - * A trait that holds common methods for StandaloneRestProtocolField companion objects. + * A trait that holds common methods for SubmitRestProtocolField companion objects. * * It is necessary to keep track of all fields that belong to this object in order to * reconstruct the fields from their names. */ -private[spark] trait StandaloneRestProtocolFieldCompanion { - val requiredFields: Seq[StandaloneRestProtocolField] - val optionalFields: Seq[StandaloneRestProtocolField] +private[spark] trait SubmitRestProtocolFieldCompanion { + val requiredFields: Seq[SubmitRestProtocolField] + val optionalFields: Seq[SubmitRestProtocolField] /** Listing of all fields indexed by the field's string representation. */ - private lazy val allFieldsMap: Map[String, StandaloneRestProtocolField] = { + private lazy val allFieldsMap: Map[String, SubmitRestProtocolField] = { (requiredFields ++ optionalFields).map { f => (f.toString, f) }.toMap } - /** Return a StandaloneRestProtocolField from its string representation. */ - def withName(field: String): StandaloneRestProtocolField = { + /** Return a SubmitRestProtocolField from its string representation. */ + def withName(field: String): SubmitRestProtocolField = { allFieldsMap.get(field).getOrElse { throw new IllegalArgumentException(s"Unknown field $field") } @@ -196,19 +196,19 @@ private[spark] trait StandaloneRestProtocolFieldCompanion { } /** - * A trait that holds common methods for StandaloneRestProtocolMessage companion objects. 
+ * A trait that holds common methods for SubmitRestProtocolMessage companion objects. */ -private[spark] trait StandaloneRestProtocolMessageCompanion extends Logging { - import StandaloneRestProtocolField._ +private[spark] trait SubmitRestProtocolMessageCompanion extends Logging { + import SubmitRestProtocolField._ /** Construct a new message of the relevant type. */ - protected def newMessage(): StandaloneRestProtocolMessage + protected def newMessage(): SubmitRestProtocolMessage /** Return a field of the relevant type from the field's string representation. */ - protected def fieldWithName(field: String): StandaloneRestProtocolField + protected def fieldWithName(field: String): SubmitRestProtocolField - /** Construct a StandaloneRestProtocolMessage from the set of fields provided. */ - def fromFields(fields: Map[String, String]): StandaloneRestProtocolMessage = { + /** Construct a SubmitRestProtocolMessage from the set of fields provided. */ + def fromFields(fields: Map[String, String]): SubmitRestProtocolMessage = { val message = newMessage() fields.foreach { case (k, v) => try { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala new file mode 100644 index 000000000000..533597c9275d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import java.io.DataOutputStream +import java.net.InetSocketAddress +import javax.servlet.http.{HttpServletRequest, HttpServletResponse} + +import scala.io.Source + +import com.google.common.base.Charsets +import org.eclipse.jetty.server.{Request, Server} +import org.eclipse.jetty.server.handler.AbstractHandler + +import org.apache.spark.{SPARK_VERSION => sparkVersion, Logging} +import org.apache.spark.util.Utils + +/** + * A server that responds to requests submitted by the SubmitRestClient. + */ +private[spark] abstract class SubmitRestServer(host: String, requestedPort: Int) { + protected val handler: SubmitRestServerHandler + + /** Start the server. */ + def start(): Unit = { + val server = new Server(new InetSocketAddress(host, requestedPort)) + server.setHandler(handler) + server.start() + } +} + +/** + * A handler that responds to requests submitted via the submit REST protocol. + * This represents the main handler used in the SubmitRestServer.
+ */ +private[spark] abstract class SubmitRestServerHandler extends AbstractHandler with Logging { + protected def handleSubmit(request: SubmitDriverRequestMessage): SubmitDriverResponseMessage + protected def handleKill(request: KillDriverRequestMessage): KillDriverResponseMessage + protected def handleStatus(request: DriverStatusRequestMessage): DriverStatusResponseMessage + + /** Handle a request submitted by the SubmitRestClient. */ + override def handle( + target: String, + baseRequest: Request, + request: HttpServletRequest, + response: HttpServletResponse): Unit = { + try { + val requestMessageJson = Source.fromInputStream(request.getInputStream).mkString + val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) + val responseMessage = constructResponseMessage(requestMessage) + response.setContentType("application/json") + response.setCharacterEncoding("utf-8") + response.setStatus(HttpServletResponse.SC_OK) + val content = responseMessage.toJson.getBytes(Charsets.UTF_8) + val out = new DataOutputStream(response.getOutputStream) + out.write(content) + out.close() + baseRequest.setHandled(true) + } catch { + case e: Exception => logError("Exception while handling request", e) + } + } + + /** + * Construct the appropriate response message based on the type of the request message. + * If an IllegalArgumentException is thrown in the process, construct an error message. + */ + private def constructResponseMessage( + request: SubmitRestProtocolMessage): SubmitRestProtocolMessage = { + // If the request is sent via the SubmitRestClient, it should have already been + // validated remotely. In case this is not true, validate the request here to guard + // against potential NPEs. If validation fails, return an ERROR message to the sender. + try { + request.validate() + request match { + case submit: SubmitDriverRequestMessage => handleSubmit(submit) + case kill: KillDriverRequestMessage => handleKill(kill) + case status: DriverStatusRequestMessage => handleStatus(status) + case unexpected => handleError( + s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") + } + } catch { + // Propagate exception to user in an ErrorMessage. If the construction of the + // ErrorMessage itself throws an exception, log the exception and ignore the request. + case e: IllegalArgumentException => handleError(e.getMessage) + } + } + + /** Construct an error message to signal that an exception has been thrown. */ + private def handleError(message: String): ErrorMessage = { + import ErrorField._ + new ErrorMessage() + .setField(SPARK_VERSION, sparkVersion) + .setField(MESSAGE, message) + .validate() + } +} From 484bd2172b847433c989d7c450fbbc99dddb1f56 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 19 Jan 2015 17:02:25 -0800 Subject: [PATCH 04/48] Specify an ordering for fields in SubmitDriverRequestMessage Previously, APP_ARGs, SPARK_PROPERTYs, and ENVIRONMENT_VARIABLEs would appear in the JSON at random places. Now they are grouped together at the end of the JSON blob.
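To illustrate the ordering, here is a minimal, self-contained Scala sketch of the sort-key scheme this commit introduces. The string field names mirror the parameterized fields' toString output (e.g. APP_ARG_0), while the weights and sample values are only assumptions for the example:

object FieldOrderingSketch {
  // ACTION < SPARK_VERSION < * < APP_ARG < SPARK_PROPERTY < ENVIRONMENT_VARIABLE < MESSAGE
  def sortKey(field: String): Int = field match {
    case "ACTION" => 0
    case "SPARK_VERSION" => 1
    case f if f.startsWith("APP_ARG_") => 10 + f.stripPrefix("APP_ARG_").toInt
    case f if f.startsWith("SPARK_PROPERTY_") => 100
    case f if f.startsWith("ENVIRONMENT_VARIABLE_") => 1000
    case "MESSAGE" => Int.MaxValue
    case _ => 2
  }

  def main(args: Array[String]): Unit = {
    val fields = Seq(
      "MESSAGE" -> "ok",
      "SPARK_PROPERTY_spark.cores.max" -> "4",
      "APP_ARG_1" -> "bar",
      "ACTION" -> "SUBMIT_DRIVER_REQUEST",
      "APP_ARG_0" -> "foo",
      "SPARK_VERSION" -> "1.3.0")
    // Seq.sortBy is stable, so fields with equal weights keep their relative order
    fields.sortBy { case (k, _) => sortKey(k) }
      .foreach { case (k, v) => println(s"$k -> $v") }
  }
}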
--- .../rest/SubmitDriverRequestMessage.scala | 18 ++++++++++ .../rest/SubmitRestProtocolMessage.scala | 34 +++++++++++-------- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala index e55cb69ed112..30c203e003f1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala @@ -83,6 +83,8 @@ private[spark] class SubmitDriverRequestMessage extends SubmitRestProtocolMessag SubmitDriverRequestField.ACTION, SubmitDriverRequestField.requiredFields) { + import SubmitDriverRequestField._ + // Ensure continuous range of app arg indices starting from 0 override def validate(): this.type = { import SubmitDriverRequestField._ @@ -93,6 +95,22 @@ private[spark] class SubmitDriverRequestMessage extends SubmitRestProtocolMessag } super.validate() } + + // List the fields in the following order: + // ACTION < SPARK_VERSION < * < APP_ARG < SPARK_PROPERTY < ENVIRONMENT_VARIABLE < MESSAGE + protected override def sortedFields: Seq[(SubmitRestProtocolField, String)] = { + fields.toSeq.sortBy { case (k, _) => + k match { + case ACTION => 0 + case SPARK_VERSION => 1 + case APP_ARG(index) => 10 + index + case SPARK_PROPERTY(propKey) => 100 + case ENVIRONMENT_VARIABLE(envKey) => 1000 + case MESSAGE => Int.MaxValue + case _ => 2 + } + } + } } private[spark] object SubmitDriverRequestMessage extends SubmitRestProtocolMessageCompanion { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index 77d38c40ae80..db5a42d7b17d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -24,7 +24,7 @@ import org.json4s.jackson.JsonMethods._ import org.json4s.JsonAST._ import org.apache.spark.{Logging, SparkException} -import org.apache.spark.util.Utils +import org.apache.spark.util.{JsonProtocol, Utils} /** * A field used in a SubmitRestProtocolMessage. @@ -32,8 +32,9 @@ import org.apache.spark.util.Utils */ private[spark] abstract class SubmitRestProtocolField private[spark] object SubmitRestProtocolField { - /** Return whether the provided field name refers to the ACTION field. */ def isActionField(field: String): Boolean = field == "ACTION" + def isSparkVersionField(field: String): Boolean = field == "SPARK_VERSION" + def isMessageField(field: String): Boolean = field == "MESSAGE" } /** @@ -125,23 +126,26 @@ private[spark] abstract class SubmitRestProtocolMessage( /** Return the JSON representation of this message. */ def toJson: String = { - val stringFields = fields + val jsonFields = sortedFields .filter { case (_, v) => v != null } - .map { case (k, v) => (k.toString, v) } - val jsonFields = fieldsToJson(stringFields) - pretty(render(jsonFields)) + .map { case (k, v) => JField(k.toString, JString(v)) } + .toList + pretty(render(JObject(jsonFields))) } /** - * Return the JSON representation of the message fields, putting ACTION first. - * This assumes that applying `org.apache.spark.util.JsonProtocol.mapFromJson` - * to the result yields the original input. 
+ * Return a list of (field, value) pairs with the following ordering: + * ACTION < SPARK_VERSION < * < MESSAGE */ - private def fieldsToJson(fields: Map[String, String]): JValue = { - val jsonFields = fields.toList - .sortBy { case (k, _) => if (isActionField(k)) 0 else 1 } - .map { case (k, v) => JField(k, JString(v)) } - JObject(jsonFields) + protected def sortedFields: Seq[(SubmitRestProtocolField, String)] = { + fields.toSeq.sortBy { case (k, _) => + k.toString match { + case x if isActionField(x) => 0 + case x if isSparkVersionField(x) => 1 + case x if isMessageField(x) => Int.MaxValue + case _ => 2 + } + } } } @@ -155,7 +159,7 @@ private[spark] object SubmitRestProtocolMessage { * If such a field does not exist in the JSON, throw an exception. */ def fromJson(json: String): SubmitRestProtocolMessage = { - val fields = org.apache.spark.util.JsonProtocol.mapFromJson(parse(json)) + val fields = JsonProtocol.mapFromJson(parse(json)) val action = fields .flatMap { case (k, v) => if (isActionField(k)) Some(v) else None } .headOption From e958caec3e2f5bbd1dd2cf6dc0e02a3ad27c2699 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 20 Jan 2015 15:29:56 -0800 Subject: [PATCH 05/48] Supported nested values in messages This is applicable to application arguments, Spark properties, and environment variables, all of which were previously handled through parameterized fields that were cumbersome to parse. Since JSON naturally supports nesting, we should take advantage of it too. This commit refactors the code that converts the messages to and from JSON so that subclasses can easily override the conversion behavior without duplicating code. --- .../rest/DriverStatusRequestMessage.scala | 12 +- .../rest/DriverStatusResponseMessage.scala | 18 +-- .../spark/deploy/rest/ErrorMessage.scala | 9 +- .../rest/KillDriverRequestMessage.scala | 12 +- .../rest/KillDriverResponseMessage.scala | 14 +- .../deploy/rest/StandaloneRestClient.scala | 13 +- .../deploy/rest/StandaloneRestServer.scala | 54 +++---- .../rest/SubmitDriverRequestMessage.scala | 128 +++++++++-------- .../rest/SubmitDriverResponseMessage.scala | 12 +- .../rest/SubmitRestProtocolMessage.scala | 133 +++++++++++------- .../spark/deploy/rest/SubmitRestServer.scala | 11 +- 11 files changed, 223 insertions(+), 191 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala index 14b77e4da9c9..f5f36401e205 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala @@ -21,7 +21,8 @@ package org.apache.spark.deploy.rest * A field used in a DriverStatusRequestMessage.
*/ private[spark] abstract class DriverStatusRequestField extends SubmitRestProtocolField -private[spark] object DriverStatusRequestField extends SubmitRestProtocolFieldCompanion { +private[spark] object DriverStatusRequestField + extends SubmitRestProtocolFieldCompanion[DriverStatusRequestField] { case object ACTION extends DriverStatusRequestField case object SPARK_VERSION extends DriverStatusRequestField case object MESSAGE extends DriverStatusRequestField @@ -39,9 +40,8 @@ private[spark] class DriverStatusRequestMessage extends SubmitRestProtocolMessag DriverStatusRequestField.ACTION, DriverStatusRequestField.requiredFields) -private[spark] object DriverStatusRequestMessage extends SubmitRestProtocolMessageCompanion { - protected override def newMessage(): SubmitRestProtocolMessage = - new DriverStatusRequestMessage - protected override def fieldWithName(field: String): SubmitRestProtocolField = - DriverStatusRequestField.withName(field) +private[spark] object DriverStatusRequestMessage + extends SubmitRestProtocolMessageCompanion[DriverStatusRequestMessage] { + protected override def newMessage() = new DriverStatusRequestMessage + protected override def fieldWithName(field: String) = DriverStatusRequestField.withName(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala index 8d45ceaa3ee2..392deefbc0e8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala @@ -21,15 +21,16 @@ package org.apache.spark.deploy.rest * A field used in a DriverStatusResponseMessage. */ private[spark] abstract class DriverStatusResponseField extends SubmitRestProtocolField -private[spark] object DriverStatusResponseField extends SubmitRestProtocolFieldCompanion { +private[spark] object DriverStatusResponseField + extends SubmitRestProtocolFieldCompanion[DriverStatusResponseField] { case object ACTION extends DriverStatusResponseField case object SPARK_VERSION extends DriverStatusResponseField case object MESSAGE extends DriverStatusResponseField case object MASTER extends DriverStatusResponseField case object DRIVER_ID extends DriverStatusResponseField - case object DRIVER_STATE extends SubmitDriverResponseField - case object WORKER_ID extends SubmitDriverResponseField - case object WORKER_HOST_PORT extends SubmitDriverResponseField + case object DRIVER_STATE extends DriverStatusResponseField + case object WORKER_ID extends DriverStatusResponseField + case object WORKER_HOST_PORT extends DriverStatusResponseField override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, DRIVER_ID, DRIVER_STATE, WORKER_ID, WORKER_HOST_PORT) override val optionalFields = Seq.empty @@ -43,9 +44,8 @@ private[spark] class DriverStatusResponseMessage extends SubmitRestProtocolMessa DriverStatusResponseField.ACTION, DriverStatusResponseField.requiredFields) -private[spark] object DriverStatusResponseMessage extends SubmitRestProtocolMessageCompanion { - protected override def newMessage(): SubmitRestProtocolMessage = - new DriverStatusResponseMessage - protected override def fieldWithName(field: String): SubmitRestProtocolField = - DriverStatusResponseField.withName(field) +private[spark] object DriverStatusResponseMessage + extends SubmitRestProtocolMessageCompanion[DriverStatusResponseMessage] { + protected override def newMessage() = new 
DriverStatusResponseMessage + protected override def fieldWithName(field: String) = DriverStatusResponseField.withName(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala index 020c7c28dc36..17e2b75a6d4b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala @@ -21,7 +21,7 @@ package org.apache.spark.deploy.rest * A field used in a ErrorMessage. */ private[spark] abstract class ErrorField extends SubmitRestProtocolField -private[spark] object ErrorField extends SubmitRestProtocolFieldCompanion { +private[spark] object ErrorField extends SubmitRestProtocolFieldCompanion[ErrorField] { case object ACTION extends ErrorField case object SPARK_VERSION extends ErrorField case object MESSAGE extends ErrorField @@ -37,8 +37,7 @@ private[spark] class ErrorMessage extends SubmitRestProtocolMessage( ErrorField.ACTION, ErrorField.requiredFields) -private[spark] object ErrorMessage extends SubmitRestProtocolMessageCompanion { - protected override def newMessage(): SubmitRestProtocolMessage = new ErrorMessage - protected override def fieldWithName(field: String): SubmitRestProtocolField = - ErrorField.withName(field) +private[spark] object ErrorMessage extends SubmitRestProtocolMessageCompanion[ErrorMessage] { + protected override def newMessage() = new ErrorMessage + protected override def fieldWithName(field: String) = ErrorField.withName(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala index cdca193e6aed..913c31c9d2af 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala @@ -21,7 +21,8 @@ package org.apache.spark.deploy.rest * A field used in a KillDriverRequestMessage. 
*/ private[spark] abstract class KillDriverRequestField extends SubmitRestProtocolField -private[spark] object KillDriverRequestField extends SubmitRestProtocolFieldCompanion { +private[spark] object KillDriverRequestField + extends SubmitRestProtocolFieldCompanion[KillDriverRequestField] { case object ACTION extends KillDriverRequestField case object SPARK_VERSION extends KillDriverRequestField case object MESSAGE extends KillDriverRequestField @@ -39,9 +40,8 @@ private[spark] class KillDriverRequestMessage extends SubmitRestProtocolMessage( KillDriverRequestField.ACTION, KillDriverRequestField.requiredFields) -private[spark] object KillDriverRequestMessage extends SubmitRestProtocolMessageCompanion { - protected override def newMessage(): SubmitRestProtocolMessage = - new KillDriverRequestMessage - protected override def fieldWithName(field: String): SubmitRestProtocolField = - KillDriverRequestField.withName(field) +private[spark] object KillDriverRequestMessage + extends SubmitRestProtocolMessageCompanion[KillDriverRequestMessage] { + protected override def newMessage() = new KillDriverRequestMessage + protected override def fieldWithName(field: String) = KillDriverRequestField.withName(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala index 60eac1b1f26d..56aa7023e153 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala @@ -21,13 +21,14 @@ package org.apache.spark.deploy.rest * A field used in a KillDriverResponseMessage. */ private[spark] abstract class KillDriverResponseField extends SubmitRestProtocolField -private[spark] object KillDriverResponseField extends SubmitRestProtocolFieldCompanion { +private[spark] object KillDriverResponseField + extends SubmitRestProtocolFieldCompanion[KillDriverResponseField] { case object ACTION extends KillDriverResponseField case object SPARK_VERSION extends KillDriverResponseField case object MESSAGE extends KillDriverResponseField case object MASTER extends KillDriverResponseField case object DRIVER_ID extends KillDriverResponseField - case object SUCCESS extends SubmitDriverResponseField + case object SUCCESS extends KillDriverResponseField override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, DRIVER_ID, SUCCESS) override val optionalFields = Seq.empty } @@ -40,9 +41,8 @@ private[spark] class KillDriverResponseMessage extends SubmitRestProtocolMessage KillDriverResponseField.ACTION, KillDriverResponseField.requiredFields) -private[spark] object KillDriverResponseMessage extends SubmitRestProtocolMessageCompanion { - protected override def newMessage(): SubmitRestProtocolMessage = - new KillDriverResponseMessage - protected override def fieldWithName(field: String): SubmitRestProtocolField = - KillDriverResponseField.withName(field) +private[spark] object KillDriverResponseMessage + extends SubmitRestProtocolMessageCompanion[KillDriverResponseMessage] { + protected override def newMessage() = new KillDriverResponseMessage + protected override def fieldWithName(field: String) = KillDriverResponseField.withName(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index cb1aba45a218..91e62548c6cd 100644 --- 
a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -24,9 +24,8 @@ import org.apache.spark.deploy.SparkSubmitArguments import org.apache.spark.util.Utils /** - * A client that submits Spark applications to the standalone Master using a stable - * REST protocol. This client is intended to communicate with the StandaloneRestServer, - * and currently only used in cluster mode. + * A client that submits Spark applications to the standalone Master using a stable REST protocol. + * This client is intended to communicate with the StandaloneRestServer. Cluster mode only. */ private[spark] class StandaloneRestClient extends SubmitRestClient { @@ -56,12 +55,8 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { .setFieldIfNotNull(SUPERVISE_DRIVER, args.supervise.toString) .setFieldIfNotNull(EXECUTOR_MEMORY, executorMemory) .setFieldIfNotNull(TOTAL_EXECUTOR_CORES, args.totalExecutorCores) - args.childArgs.zipWithIndex.foreach { case (arg, i) => - message.setFieldIfNotNull(APP_ARG(i), arg) - } - args.sparkProperties.foreach { case (k, v) => - message.setFieldIfNotNull(SPARK_PROPERTY(k), v) - } + args.childArgs.foreach(message.appendAppArg) + args.sparkProperties.foreach { case (k, v) => message.setSparkProperty(k, v) } // TODO: set environment variables? message.validate() } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index e160b8e0a0d2..54cdcf2e0d26 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -33,7 +33,7 @@ import akka.actor.ActorRef /** * A server that responds to requests submitted by the StandaloneRestClient. - * This is intended to be embedded in the standalone Master. + * This is intended to be embedded in the standalone Master. Cluster mode only. */ private[spark] class StandaloneRestServer( master: Master, @@ -44,8 +44,7 @@ private[spark] class StandaloneRestServer( } /** - * A handler that responds to requests submitted to the standalone Master - * through the REST protocol. + * A handler for requests submitted to the standalone Master through the REST protocol. */ private[spark] class StandaloneRestServerHandler( conf: SparkConf, @@ -53,7 +52,7 @@ private[spark] class StandaloneRestServerHandler( masterUrl: String) extends SubmitRestServerHandler { - private implicit val askTimeout = AkkaUtils.askTimeout(conf) + private val askTimeout = AkkaUtils.askTimeout(conf) def this(master: Master) = { this(master.conf, master.self, master.masterUrl) @@ -109,18 +108,15 @@ private[spark] class StandaloneRestServerHandler( /** * Build a driver description from the fields specified in the submit request. - * This currently does not consider fields used by python applications since + * This does not currently consider fields used by python applications since * python is not supported in standalone cluster mode yet. 
*/ private def buildDriverDescription(request: SubmitDriverRequestMessage): DriverDescription = { import SubmitDriverRequestField._ - // Required fields + // Required fields, including the main class because python is not yet supported val appName = request.getFieldNotNull(APP_NAME) val appResource = request.getFieldNotNull(APP_RESOURCE) - - // Since standalone cluster mode does not yet support python, - // we treat the main class as required val mainClass = request.getFieldNotNull(MAIN_CLASS) // Optional fields @@ -134,25 +130,20 @@ private[spark] class StandaloneRestServerHandler( val superviseDriver = request.getFieldOption(SUPERVISE_DRIVER) val executorMemory = request.getFieldOption(EXECUTOR_MEMORY) val totalExecutorCores = request.getFieldOption(TOTAL_EXECUTOR_CORES) + val appArgs = request.getAppArgs + val sparkProperties = request.getSparkProperties + val environmentVariables = request.getEnvironmentVariables - // Parse special fields that take in parameters + // Translate all fields to the relevant Spark properties val conf = new SparkConf(false) - val env = new mutable.HashMap[String, String] - val appArgs = new ArrayBuffer[(Int, String)] - request.getFields.foreach { case (k, v) => - k match { - case APP_ARG(index) => appArgs += ((index, v)) - case SPARK_PROPERTY(propKey) => conf.set(propKey, v) - case ENVIRONMENT_VARIABLE(envKey) => env(envKey) = v - case _ => - } - } - - // Use the actual master URL instead of the one that refers to this REST server - // Otherwise, once the driver is launched it will contact with the wrong server - conf.set("spark.master", masterUrl) - conf.set("spark.app.name", appName) - conf.set("spark.jars", jars.map(_ + ",").getOrElse("") + appResource) // include app resource + .setAll(sparkProperties) + // Use the actual master URL instead of the one that refers to this REST server + // Otherwise, once the driver is launched it will contact the wrong server + .set("spark.master", masterUrl) + .set("spark.app.name", appName) + // Include main app resource on the executor classpath + // The corresponding behavior in client mode is handled in SparkSubmit + .set("spark.jars", jars.map(_ + ",").getOrElse("") + appResource) files.foreach { f => conf.set("spark.files", f) } driverExtraJavaOptions.foreach { j => conf.set("spark.driver.extraJavaOptions", j) } driverExtraClassPath.foreach { cp => conf.set("spark.driver.extraClassPath", cp) } @@ -161,10 +152,6 @@ private[spark] class StandaloneRestServerHandler( totalExecutorCores.foreach { c => conf.set("spark.cores.max", c) } // Construct driver description and submit it - val actualDriverMemory = driverMemory.map(_.toInt).getOrElse(DEFAULT_MEMORY) - val actualDriverCores = driverCores.map(_.toInt).getOrElse(DEFAULT_CORES) - val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) - val actualAppArgs = appArgs.sortBy(_._1).map(_._2) // sort by index, map to value val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator)) val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator)) val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty) @@ -172,8 +159,11 @@ private[spark] class StandaloneRestServerHandler( val javaOpts = sparkJavaOpts ++ extraJavaOpts val command = new Command( "org.apache.spark.deploy.worker.DriverWrapper", - Seq("{{WORKER_URL}}", mainClass) ++ actualAppArgs, - env, extraClassPath, extraLibraryPath, javaOpts) + Seq("{{WORKER_URL}}", mainClass) ++ appArgs, // args to
the DriverWrapper + environmentVariables, extraClassPath, extraLibraryPath, javaOpts) + val actualDriverMemory = driverMemory.map(_.toInt).getOrElse(DEFAULT_MEMORY) + val actualDriverCores = driverCores.map(_.toInt).getOrElse(DEFAULT_CORES) + val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) new DriverDescription( appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala index 30c203e003f1..0ff68c1584f6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala @@ -17,15 +17,19 @@ package org.apache.spark.deploy.rest -import scala.util.matching.Regex +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer -import org.apache.spark.util.Utils +import org.json4s.JsonAST._ + +import org.apache.spark.util.JsonProtocol /** * A field used in a SubmitDriverRequestMessage. */ private[spark] abstract class SubmitDriverRequestField extends SubmitRestProtocolField -private[spark] object SubmitDriverRequestField extends SubmitRestProtocolFieldCompanion { +private[spark] object SubmitDriverRequestField + extends SubmitRestProtocolFieldCompanion[SubmitDriverRequestField] { case object ACTION extends SubmitDriverRequestField case object SPARK_VERSION extends SubmitDriverRequestField case object MESSAGE extends SubmitDriverRequestField @@ -44,35 +48,14 @@ private[spark] object SubmitDriverRequestField extends SubmitRestProtocolFieldCo case object SUPERVISE_DRIVER extends SubmitDriverRequestField // standalone cluster mode only case object EXECUTOR_MEMORY extends SubmitDriverRequestField case object TOTAL_EXECUTOR_CORES extends SubmitDriverRequestField - case class APP_ARG(index: Int) extends SubmitDriverRequestField { - override def toString: String = Utils.getFormattedClassName(this) + "_" + index - } - case class SPARK_PROPERTY(prop: String) extends SubmitDriverRequestField { - override def toString: String = Utils.getFormattedClassName(this) + "_" + prop - } - case class ENVIRONMENT_VARIABLE(envVar: String) extends SubmitDriverRequestField { - override def toString: String = Utils.getFormattedClassName(this) + "_" + envVar - } + case object APP_ARGS extends SubmitDriverRequestField + case object SPARK_PROPERTIES extends SubmitDriverRequestField + case object ENVIRONMENT_VARIABLES extends SubmitDriverRequestField override val requiredFields = Seq(ACTION, SPARK_VERSION, MASTER, APP_NAME, APP_RESOURCE) override val optionalFields = Seq(MESSAGE, MAIN_CLASS, JARS, FILES, PY_FILES, DRIVER_MEMORY, DRIVER_CORES, DRIVER_EXTRA_JAVA_OPTIONS, DRIVER_EXTRA_CLASS_PATH, DRIVER_EXTRA_LIBRARY_PATH, - SUPERVISE_DRIVER, EXECUTOR_MEMORY, TOTAL_EXECUTOR_CORES) - - // Because certain fields taken in arguments, we cannot simply rely on the - // list of all fields to reconstruct a field from its String representation. - // Instead, we must treat these fields as special cases and match on their prefixes.
- override def withName(field: String): SubmitRestProtocolField = { - def buildRegex(obj: AnyRef): Regex = s"${Utils.getFormattedClassName(obj)}_(.*)".r - val appArg = buildRegex(APP_ARG) - val sparkProperty = buildRegex(SPARK_PROPERTY) - val environmentVariable = buildRegex(ENVIRONMENT_VARIABLE) - field match { - case appArg(f) => APP_ARG(f.toInt) - case sparkProperty(f) => SPARK_PROPERTY(f) - case environmentVariable(f) => ENVIRONMENT_VARIABLE(f) - case _ => super.withName(field) - } - } + SUPERVISE_DRIVER, EXECUTOR_MEMORY, TOTAL_EXECUTOR_CORES, APP_ARGS, SPARK_PROPERTIES, + ENVIRONMENT_VARIABLES) } /** @@ -85,37 +68,66 @@ private[spark] class SubmitDriverRequestMessage extends SubmitRestProtocolMessag import SubmitDriverRequestField._ - // Ensure continuous range of app arg indices starting from 0 - override def validate(): this.type = { - import SubmitDriverRequestField._ - val indices = fields.collect { case (a: APP_ARG, _) => a }.toSeq.sortBy(_.index).map(_.index) - val expectedIndices = (0 until indices.size).toSeq - if (indices != expectedIndices) { - throw new IllegalArgumentException(s"Malformed app arg indices: ${indices.mkString(",")}") - } - super.validate() - } + private val appArgs = new ArrayBuffer[String] + private val sparkProperties = new mutable.HashMap[String, String] + private val environmentVariables = new mutable.HashMap[String, String] - // List the fields in the following order: - // ACTION < SPARK_VERSION < * < APP_ARG < SPARK_PROPERTY < ENVIRONMENT_VARIABLE < MESSAGE - protected override def sortedFields: Seq[(SubmitRestProtocolField, String)] = { - fields.toSeq.sortBy { case (k, _) => - k match { - case ACTION => 0 - case SPARK_VERSION => 1 - case APP_ARG(index) => 10 + index - case SPARK_PROPERTY(propKey) => 100 - case ENVIRONMENT_VARIABLE(envKey) => 1000 - case MESSAGE => Int.MaxValue - case _ => 2 - } - } + // Special field setters + def appendAppArg(arg: String): Unit = { appArgs += arg } + def setSparkProperty(k: String, v: String): Unit = { sparkProperties(k) = v } + def setEnvironmentVariable(k: String, v: String): Unit = { environmentVariables(k) = v } + + // Special field getters + def getAppArgs: Seq[String] = appArgs.clone() + def getSparkProperties: Map[String, String] = sparkProperties.toMap + def getEnvironmentVariables: Map[String, String] = environmentVariables.toMap + + // Include app args, spark properties, and environment variables in the JSON object + override def toJsonObject: JObject = { + val otherFields = super.toJsonObject.obj + val appArgsJson = JArray(appArgs.map(JString).toList) + val sparkPropertiesJson = JsonProtocol.mapToJson(sparkProperties) + val environmentVariablesJson = JsonProtocol.mapToJson(environmentVariables) + val allFields = otherFields ++ List( + (APP_ARGS.toString, appArgsJson), + (SPARK_PROPERTIES.toString, sparkPropertiesJson), + (ENVIRONMENT_VARIABLES.toString, environmentVariablesJson) + ) + JObject(allFields) } } -private[spark] object SubmitDriverRequestMessage extends SubmitRestProtocolMessageCompanion { - protected override def newMessage(): SubmitRestProtocolMessage = - new SubmitDriverRequestMessage - protected override def fieldWithName(field: String): SubmitRestProtocolField = - SubmitDriverRequestField.withName(field) +private[spark] object SubmitDriverRequestMessage + extends SubmitRestProtocolMessageCompanion[SubmitDriverRequestMessage] { + + import SubmitDriverRequestField._ + + protected override def newMessage() = new SubmitDriverRequestMessage + protected override def fieldWithName(field: String) = 
SubmitDriverRequestField.withName(field) + + /** + * Process the given field and value appropriately based on the type of the field. + * This handles certain nested values in addition to flat values. + */ + override def handleField( + message: SubmitDriverRequestMessage, + field: SubmitRestProtocolField, + value: JValue): Unit = { + (field, value) match { + case (APP_ARGS, JArray(args)) => + args.map(_.asInstanceOf[JString].s).foreach { arg => + message.appendAppArg(arg) + } + case (SPARK_PROPERTIES, props: JObject) => + JsonProtocol.mapFromJson(props).foreach { case (k, v) => + message.setSparkProperty(k, v) + } + case (ENVIRONMENT_VARIABLES, envVars: JObject) => + JsonProtocol.mapFromJson(envVars).foreach { case (envKey, envValue) => + message.setEnvironmentVariable(envKey, envValue) + } + // All other fields are assumed to have flat values + case _ => super.handleField(message, field, value) + } + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala index a877ec4fccff..bf3c03d2e333 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala @@ -21,7 +21,8 @@ package org.apache.spark.deploy.rest * A field used in a SubmitDriverResponseMessage. */ private[spark] abstract class SubmitDriverResponseField extends SubmitRestProtocolField -private[spark] object SubmitDriverResponseField extends SubmitRestProtocolFieldCompanion { +private[spark] object SubmitDriverResponseField + extends SubmitRestProtocolFieldCompanion[SubmitDriverResponseField] { case object ACTION extends SubmitDriverResponseField case object SPARK_VERSION extends SubmitDriverResponseField case object MESSAGE extends SubmitDriverResponseField @@ -40,9 +41,8 @@ private[spark] class SubmitDriverResponseMessage extends SubmitRestProtocolMessa SubmitDriverResponseField.ACTION, SubmitDriverResponseField.requiredFields) -private[spark] object SubmitDriverResponseMessage extends SubmitRestProtocolMessageCompanion { - protected override def newMessage(): SubmitRestProtocolMessage = - new SubmitDriverResponseMessage - protected override def fieldWithName(field: String): SubmitRestProtocolField = - SubmitDriverResponseField.withName(field) +private[spark] object SubmitDriverResponseMessage + extends SubmitRestProtocolMessageCompanion[SubmitDriverResponseMessage] { + protected override def newMessage() = new SubmitDriverResponseMessage + protected override def fieldWithName(field: String) = SubmitDriverResponseField.withName(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index db5a42d7b17d..496db227f210 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -24,7 +24,7 @@ import org.json4s.jackson.JsonMethods._ import org.json4s.JsonAST._ import org.apache.spark.{Logging, SparkException} -import org.apache.spark.util.{JsonProtocol, Utils} +import org.apache.spark.util.Utils /** * A field used in a SubmitRestProtocolMessage. 
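For a concrete picture of the nested shape this change produces, here is a small json4s sketch (illustrative values, not the patch code itself) of a submit request after the change, with APP_ARGS as a JSON array and SPARK_PROPERTIES as a nested object instead of flat parameterized keys:

import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

object NestedMessageJsonSketch {
  def main(args: Array[String]): Unit = {
    // App args become a JSON array; Spark properties a nested JSON object
    val json =
      ("ACTION" -> "SUBMIT_DRIVER_REQUEST") ~
      ("SPARK_VERSION" -> "1.3.0") ~
      ("APP_ARGS" -> List("--input", "/tmp/data")) ~
      ("SPARK_PROPERTIES" ->
        (("spark.executor.memory" -> "2g") ~ ("spark.cores.max" -> "4")))
    // Print the message in the same pretty-rendered form the protocol uses
    println(pretty(render(json)))
  }
}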
@@ -64,8 +64,8 @@ private[spark] abstract class SubmitRestProtocolMessage( import SubmitRestProtocolField._ - private val className = Utils.getFormattedClassName(this) - protected val fields = new mutable.HashMap[SubmitRestProtocolField, String] + private val fields = new mutable.HashMap[SubmitRestProtocolField, String] + val className = Utils.getFormattedClassName(this) // Set the action field fields(actionField) = action.toString @@ -125,27 +125,27 @@ private[spark] abstract class SubmitRestProtocolMessage( } /** Return the JSON representation of this message. */ - def toJson: String = { - val jsonFields = sortedFields - .filter { case (_, v) => v != null } - .map { case (k, v) => JField(k.toString, JString(v)) } - .toList - pretty(render(JObject(jsonFields))) - } + def toJson: String = pretty(render(toJsonObject)) /** - * Return a list of (field, value) pairs with the following ordering: - * ACTION < SPARK_VERSION < * < MESSAGE + * Return a JObject that represents the JSON form of this message. + * This orders the fields by ACTION (first) < SPARK_VERSION < MESSAGE < * (last) + * and ignores fields with null values. */ - protected def sortedFields: Seq[(SubmitRestProtocolField, String)] = { - fields.toSeq.sortBy { case (k, _) => + protected def toJsonObject: JObject = { + val sortedFields = fields.toSeq.sortBy { case (k, _) => k.toString match { case x if isActionField(x) => 0 case x if isSparkVersionField(x) => 1 - case x if isMessageField(x) => Int.MaxValue - case _ => 2 + case x if isMessageField(x) => 2 + case _ => 3 } } + val jsonFields = sortedFields + .filter { case (_, v) => v != null } + .map { case (k, v) => JField(k.toString, JString(v)) } + .toList + JObject(jsonFields) } } @@ -153,27 +153,41 @@ private[spark] object SubmitRestProtocolMessage { import SubmitRestProtocolField._ import SubmitRestProtocolAction._ + /** Construct a SubmitRestProtocolMessage from its JSON representation. */ + def fromJson(json: String): SubmitRestProtocolMessage = { + fromJsonObject(parse(json).asInstanceOf[JObject]) + } + /** - * Construct a SubmitRestProtocolMessage from JSON. + * Construct a SubmitRestProtocolMessage from the given JSON object. * This uses the ACTION field to determine the type of the message to reconstruct. - * If such a field does not exist in the JSON, throw an exception. 
*/ - def fromJson(json: String): SubmitRestProtocolMessage = { - val fields = JsonProtocol.mapFromJson(parse(json)) - val action = fields - .flatMap { case (k, v) => if (isActionField(k)) Some(v) else None } - .headOption - .getOrElse { throw new IllegalArgumentException(s"ACTION not found in message:\n$json") } + protected def fromJsonObject(jsonObject: JObject): SubmitRestProtocolMessage = { + val action = getAction(jsonObject) SubmitRestProtocolAction.withName(action) match { - case SUBMIT_DRIVER_REQUEST => SubmitDriverRequestMessage.fromFields(fields) - case SUBMIT_DRIVER_RESPONSE => SubmitDriverResponseMessage.fromFields(fields) - case KILL_DRIVER_REQUEST => KillDriverRequestMessage.fromFields(fields) - case KILL_DRIVER_RESPONSE => KillDriverResponseMessage.fromFields(fields) - case DRIVER_STATUS_REQUEST => DriverStatusRequestMessage.fromFields(fields) - case DRIVER_STATUS_RESPONSE => DriverStatusResponseMessage.fromFields(fields) - case ERROR => ErrorMessage.fromFields(fields) + case SUBMIT_DRIVER_REQUEST => SubmitDriverRequestMessage.fromJsonObject(jsonObject) + case SUBMIT_DRIVER_RESPONSE => SubmitDriverResponseMessage.fromJsonObject(jsonObject) + case KILL_DRIVER_REQUEST => KillDriverRequestMessage.fromJsonObject(jsonObject) + case KILL_DRIVER_RESPONSE => KillDriverResponseMessage.fromJsonObject(jsonObject) + case DRIVER_STATUS_REQUEST => DriverStatusRequestMessage.fromJsonObject(jsonObject) + case DRIVER_STATUS_RESPONSE => DriverStatusResponseMessage.fromJsonObject(jsonObject) + case ERROR => ErrorMessage.fromJsonObject(jsonObject) } } + + /** + * Extract the value of the ACTION field in the JSON object. + * If such a field does not exist in the JSON, throw an exception. + */ + private def getAction(jsonObject: JObject): String = { + jsonObject.obj + .collect { case JField(k, JString(v)) if isActionField(k) => v } + .headOption + .getOrElse { + throw new IllegalArgumentException( + "ACTION not found in message:\n" + pretty(render(jsonObject))) + } + } } /** @@ -182,17 +196,17 @@ private[spark] object SubmitRestProtocolMessage { * It is necessary to keep track of all fields that belong to this object in order to * reconstruct the fields from their names. */ -private[spark] trait SubmitRestProtocolFieldCompanion { - val requiredFields: Seq[SubmitRestProtocolField] - val optionalFields: Seq[SubmitRestProtocolField] +private[spark] trait SubmitRestProtocolFieldCompanion[FieldType <: SubmitRestProtocolField] { + val requiredFields: Seq[FieldType] + val optionalFields: Seq[FieldType] /** Listing of all fields indexed by the field's string representation. */ - private lazy val allFieldsMap: Map[String, SubmitRestProtocolField] = { + private lazy val allFieldsMap: Map[String, FieldType] = { (requiredFields ++ optionalFields).map { f => (f.toString, f) }.toMap } /** Return a SubmitRestProtocolField from its string representation. */ - def withName(field: String): SubmitRestProtocolField = { + def withName(field: String): FieldType = { allFieldsMap.get(field).getOrElse { throw new IllegalArgumentException(s"Unknown field $field") } @@ -202,29 +216,50 @@ private[spark] trait SubmitRestProtocolFieldCompanion { /** * A trait that holds common methods for SubmitRestProtocolMessage companion objects. */ -private[spark] trait SubmitRestProtocolMessageCompanion extends Logging { +private[spark] trait SubmitRestProtocolMessageCompanion[MessageType <: SubmitRestProtocolMessage] + extends Logging { + import SubmitRestProtocolField._ /** Construct a new message of the relevant type. 
*/ - protected def newMessage(): SubmitRestProtocolMessage + protected def newMessage(): MessageType /** Return a field of the relevant type from the field's string representation. */ protected def fieldWithName(field: String): SubmitRestProtocolField - /** Construct a SubmitRestProtocolMessage from the set of fields provided. */ - def fromFields(fields: Map[String, String]): SubmitRestProtocolMessage = { + /** + * Process the given field and value appropriately based on the type of the field. + * The default behavior only considers fields that have flat values and ignores other fields. + * If the subclass uses fields with nested values, it should override this method appropriately. + */ + protected def handleField( + message: MessageType, + field: SubmitRestProtocolField, + value: JValue): Unit = { + value match { + case JString(s) => message.setField(field, s) + case _ => logWarning( + s"Unexpected value for field $field in message ${message.className}:\n$value") + } + } + + /** Construct a SubmitRestProtocolMessage from the given JSON object. */ + def fromJsonObject(jsonObject: JObject): MessageType = { val message = newMessage() - fields.foreach { case (k, v) => - try { - // The ACTION field is already set on instantiation - if (!isActionField(k)) { - message.setField(fieldWithName(k), v) + val fields = jsonObject.obj + .map { case JField(k, v) => (k, v) } + // The ACTION field is already handled on instantiation + .filter { case (k, _) => !isActionField(k) } + .flatMap { case (k, v) => + try { + Some((fieldWithName(k), v)) + } catch { + case e: IllegalArgumentException => + logWarning(s"Unexpected field $k in message ${Utils.getFormattedClassName(this)}") + None } - } catch { - case e: IllegalArgumentException => - logWarning(s"Unexpected field $k in message ${Utils.getFormattedClassName(this)}") } - } + fields.foreach { case (k, v) => handleField(message, k, v) } message } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala index 533597c9275d..6ad4be4f26ae 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala @@ -45,7 +45,7 @@ private[spark] abstract class SubmitRestServer(host: String, requestedPort: Int) } /** - * A handler that responds to requests submitted via the submit REST protocol. + * A handler for requests submitted via the stable REST protocol for submitting applications. * This represents the main handler used in the SubmitRestServer. */ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler with Logging { @@ -78,7 +78,7 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi /** * Construct the appropriate response message based on the type of the request message. - * If an IllegalArgumentException is thrown in the process, construct an error message. + * If an IllegalArgumentException is thrown in the process, construct an error message instead. */ private def constructResponseMessage( request: SubmitRestProtocolMessage): SubmitRestProtocolMessage = { @@ -95,8 +95,9 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") } } catch { - // Propagate exception to user in an ErrorMessage. 
If the construction of the - // ErrorMessage itself throws an exception, log the exception and ignore the request. + // Propagate exception to user in an ErrorMessage. + // Note that the construction of the error message itself may throw an exception. + // In this case, let the higher level caller take care of this request. case e: IllegalArgumentException => handleError(e.getMessage) } } From 544de1dd5d4d2bb35180775ddf655bedee4d44ae Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 21 Jan 2015 14:02:33 -0800 Subject: [PATCH 06/48] Major clean ups in code and comments This involves refactoring SparkSubmit a little to put the code that launches the REST client in the right place. This commit also adds port retry logic in the REST server, which was previously missing. --- .../org/apache/spark/deploy/SparkSubmit.scala | 78 +++++++++---------- .../apache/spark/deploy/master/Master.scala | 1 + .../rest/DriverStatusRequestMessage.scala | 11 +-- .../rest/DriverStatusResponseMessage.scala | 11 +-- .../spark/deploy/rest/ErrorMessage.scala | 12 +-- .../rest/KillDriverRequestMessage.scala | 11 +-- .../rest/KillDriverResponseMessage.scala | 11 +-- .../deploy/rest/StandaloneRestClient.scala | 16 ++-- .../deploy/rest/StandaloneRestServer.scala | 18 ++--- .../rest/SubmitDriverRequestMessage.scala | 18 +++-- .../rest/SubmitDriverResponseMessage.scala | 11 +-- .../spark/deploy/rest/SubmitRestClient.scala | 3 +- .../rest/SubmitRestProtocolMessage.scala | 57 ++++++-------- .../spark/deploy/rest/SubmitRestServer.scala | 36 ++++++--- 14 files changed, 151 insertions(+), 143 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index ec3def339759..4aba0feefdf2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -73,30 +73,24 @@ object SparkSubmit { if (appArgs.verbose) { printStream.println(appArgs) } - - // In standalone cluster mode, use the brand new REST client to submit the application - val isStandaloneCluster = - appArgs.master.startsWith("spark://") && appArgs.deployMode == "cluster" - if (isStandaloneCluster) { - new StandaloneRestClient().submitDriver(appArgs) - return - } - - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) - launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose) + launch(appArgs) } /** - * @return a tuple containing - * (1) the arguments for the child process, - * (2) a list of classpath entries for the child, - * (3) a list of system properties and env vars, and - * (4) the main class for the child + * Launch the application using the provided parameters. + * + * This runs in two steps. First, we prepare the launch environment by setting up + * the appropriate classpath, system properties, and application arguments for + * running the child main class based on the cluster manager and the deploy mode. + * Second, we use this launch environment to invoke the main method of the child + * main class. + * + * Note that standalone cluster mode is an exception in that we do not invoke the + * main method of a child class. Instead, we pass the submit parameters directly to + * a REST client, which will submit the application using the stable REST protocol. 
*/ - private[spark] def createLaunchEnv(args: SparkSubmitArguments) - : (ArrayBuffer[String], ArrayBuffer[String], Map[String, String], String) = { - - // Values to return + private[spark] def launch(args: SparkSubmitArguments): Unit = { + // Environment needed to launch the child main class val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() val sysProps = new HashMap[String, String]() @@ -198,8 +192,6 @@ object SparkSubmit { // Standalone cluster only OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"), - OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"), - OptionAssigner(args.driverCores, STANDALONE, CLUSTER, clOption = "--cores"), // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), @@ -228,6 +220,9 @@ object SparkSubmit { sysProp = "spark.files") ) + val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER + val isStandaloneCluster = clusterManager == STANDALONE && deployMode == CLUSTER + // In client mode, launch the application main class directly // In addition, add the main application jar and any added jars (if any) to the classpath if (deployMode == CLIENT) { @@ -239,7 +234,6 @@ object SparkSubmit { if (args.childArgs != null) { childArgs ++= args.childArgs } } - // Map all arguments to command-line options or system properties for our chosen mode for (opt <- options) { if (opt.value != null && @@ -253,7 +247,6 @@ object SparkSubmit { // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" // For python files, the primary resource is already distributed as a regular file - val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER if (!isYarnCluster && !args.isPython) { var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { @@ -262,19 +255,6 @@ object SparkSubmit { sysProps.put("spark.jars", jars.mkString(",")) } - // In standalone-cluster mode, use Client as a wrapper around the user class - if (clusterManager == STANDALONE && deployMode == CLUSTER) { - childMainClass = "org.apache.spark.deploy.Client" - if (args.supervise) { - childArgs += "--supervise" - } - childArgs += "launch" - childArgs += (args.master, args.primaryResource, args.mainClass) - if (args.childArgs != null) { - childArgs ++= args.childArgs - } - } - // In yarn-cluster mode, use yarn.Client as a wrapper around the user class if (isYarnCluster) { childMainClass = "org.apache.spark.deploy.yarn.Client" @@ -294,7 +274,7 @@ object SparkSubmit { // Ignore invalid spark.driver.host in cluster modes. if (deployMode == CLUSTER) { - sysProps -= ("spark.driver.host") + sysProps -= "spark.driver.host" } // Resolve paths in certain spark properties @@ -320,10 +300,28 @@ object SparkSubmit { sysProps("spark.submit.pyFiles") = formattedPyFiles } - (childArgs, childClasspath, sysProps, childMainClass) + // In standalone cluster mode, use the stable application submission REST protocol. + // Otherwise, just call the main method of the child class. + if (isStandaloneCluster) { + // NOTE: since we mutate the values of some configs in this method, we must update the + // corresponding fields in the original SparkSubmitArguments to reflect these changes. 
+ args.sparkProperties.clear() + args.sparkProperties ++= sysProps + sysProps.get("spark.jars").foreach { args.jars = _ } + sysProps.get("spark.files").foreach { args.files = _ } + new StandaloneRestClient().submitDriver(args) + } else { + runMain(childArgs, childClasspath, sysProps, childMainClass) + } } - private def launch( + /** + * Run the main method of the child class using the provided launch environment. + * + * Depending on the deploy mode, cluster manager, and the type of the application, + * this main class may not necessarily be the one provided by the user. + */ + private def runMain( childArgs: ArrayBuffer[String], childClasspath: ArrayBuffer[String], sysProps: Map[String, String], diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 1bd1992c95b1..210b3802dc34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -180,6 +180,7 @@ private[spark] class Master( recoveryCompletionTask.cancel() } webUi.stop() + restServer.stop() masterMetricsSystem.stop() applicationMetricsSystem.stop() persistenceEngine.close() diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala index f5f36401e205..57f79554151e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala @@ -33,15 +33,16 @@ private[spark] object DriverStatusRequestField } /** - * A request sent to the cluster manager to query the status of a driver. + * A request sent to the cluster manager to query the status of a driver + * in the stable application submission REST protocol. */ private[spark] class DriverStatusRequestMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.DRIVER_STATUS_REQUEST, - DriverStatusRequestField.ACTION, - DriverStatusRequestField.requiredFields) + SubmitRestProtocolAction.DRIVER_STATUS_REQUEST, + DriverStatusRequestField.ACTION, + DriverStatusRequestField.requiredFields) private[spark] object DriverStatusRequestMessage extends SubmitRestProtocolMessageCompanion[DriverStatusRequestMessage] { protected override def newMessage() = new DriverStatusRequestMessage - protected override def fieldWithName(field: String) = DriverStatusRequestField.withName(field) + protected override def fieldFromString(field: String) = DriverStatusRequestField.fromString(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala index 392deefbc0e8..42c64dc60175 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala @@ -37,15 +37,16 @@ private[spark] object DriverStatusResponseField } /** - * A message sent from the cluster manager in response to a DriverStatusResponseMessage. + * A message sent from the cluster manager in response to a DriverStatusRequestMessage + * in the stable application submission REST protocol. 
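+ *
+ * An illustrative rendering of such a response on the wire (values are hypothetical,
+ * and some fields are omitted for brevity):
+ * {{{
+ *   {
+ *     "ACTION" : "DRIVER_STATUS_RESPONSE",
+ *     "SPARK_VERSION" : "1.3.0",
+ *     "MASTER" : "spark://host:7077",
+ *     "DRIVER_ID" : "driver-20150121152153-0000",
+ *     "DRIVER_STATE" : "RUNNING"
+ *   }
+ * }}}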
*/ private[spark] class DriverStatusResponseMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.DRIVER_STATUS_RESPONSE, - DriverStatusResponseField.ACTION, - DriverStatusResponseField.requiredFields) + SubmitRestProtocolAction.DRIVER_STATUS_RESPONSE, + DriverStatusResponseField.ACTION, + DriverStatusResponseField.requiredFields) private[spark] object DriverStatusResponseMessage extends SubmitRestProtocolMessageCompanion[DriverStatusResponseMessage] { protected override def newMessage() = new DriverStatusResponseMessage - protected override def fieldWithName(field: String) = DriverStatusResponseField.withName(field) + protected override def fieldFromString(field: String) = DriverStatusResponseField.fromString(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala index 17e2b75a6d4b..04a298d98a34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.rest /** - * A field used in a ErrorMessage. + * A field used in an ErrorMessage. */ private[spark] abstract class ErrorField extends SubmitRestProtocolField private[spark] object ErrorField extends SubmitRestProtocolFieldCompanion[ErrorField] { @@ -30,14 +30,14 @@ private[spark] object ErrorField extends SubmitRestProtocolFieldCompanion[ErrorF } /** - * An error message exchanged in the stable application submission protocol. + * An error message exchanged in the stable application submission REST protocol. */ private[spark] class ErrorMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.ERROR, - ErrorField.ACTION, - ErrorField.requiredFields) + SubmitRestProtocolAction.ERROR, + ErrorField.ACTION, + ErrorField.requiredFields) private[spark] object ErrorMessage extends SubmitRestProtocolMessageCompanion[ErrorMessage] { protected override def newMessage() = new ErrorMessage - protected override def fieldWithName(field: String) = ErrorField.withName(field) + protected override def fieldFromString(field: String) = ErrorField.fromString(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala index 913c31c9d2af..3245058ce4ba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala @@ -33,15 +33,16 @@ private[spark] object KillDriverRequestField } /** - * A request sent to the cluster manager to kill a driver. + * A request sent to the cluster manager to kill a driver + * in the stable application submission REST protocol. 
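+ *
+ * A sketch of how a client assembles such a message (mirroring the construction in
+ * StandaloneRestClient; the driver ID is hypothetical):
+ * {{{
+ *   new KillDriverRequestMessage()
+ *     .setField(SPARK_VERSION, "1.3.0")
+ *     .setField(MASTER, "spark://host:7077")
+ *     .setField(DRIVER_ID, "driver-20150121152153-0000")
+ * }}}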
*/ private[spark] class KillDriverRequestMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.KILL_DRIVER_REQUEST, - KillDriverRequestField.ACTION, - KillDriverRequestField.requiredFields) + SubmitRestProtocolAction.KILL_DRIVER_REQUEST, + KillDriverRequestField.ACTION, + KillDriverRequestField.requiredFields) private[spark] object KillDriverRequestMessage extends SubmitRestProtocolMessageCompanion[KillDriverRequestMessage] { protected override def newMessage() = new KillDriverRequestMessage - protected override def fieldWithName(field: String) = KillDriverRequestField.withName(field) + protected override def fieldFromString(field: String) = KillDriverRequestField.fromString(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala index 56aa7023e153..92db6cfa2d64 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala @@ -34,15 +34,16 @@ private[spark] object KillDriverResponseField } /** - * A message sent from the cluster manager in response to a KillDriverResponseMessage. + * A message sent from the cluster manager in response to a KillDriverRequestMessage + * in the stable application submission REST protocol. */ private[spark] class KillDriverResponseMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.KILL_DRIVER_RESPONSE, - KillDriverResponseField.ACTION, - KillDriverResponseField.requiredFields) + SubmitRestProtocolAction.KILL_DRIVER_RESPONSE, + KillDriverResponseField.ACTION, + KillDriverResponseField.requiredFields) private[spark] object KillDriverResponseMessage extends SubmitRestProtocolMessageCompanion[KillDriverResponseMessage] { protected override def newMessage() = new KillDriverResponseMessage - protected override def fieldWithName(field: String) = KillDriverResponseField.withName(field) + protected override def fieldFromString(field: String) = KillDriverResponseField.fromString(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index 91e62548c6cd..03eaa93f0d33 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -24,7 +24,7 @@ import org.apache.spark.deploy.SparkSubmitArguments import org.apache.spark.util.Utils /** - * A client that submits Spark applications to the standalone Master using a stable REST protocol. + * A client that submits applications to the standalone Master using the stable REST protocol. * This client is intended to communicate with the StandaloneRestServer. Cluster mode only. 
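 *
 * A minimal usage sketch, assuming a cluster-mode SparkSubmitArguments instance:
 * {{{
 *   new StandaloneRestClient().submitDriver(args)
 * }}}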
*/ private[spark] class StandaloneRestClient extends SubmitRestClient { @@ -33,12 +33,8 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { override protected def constructSubmitRequest( args: SparkSubmitArguments): SubmitDriverRequestMessage = { import SubmitDriverRequestField._ - val driverMemory = Option(args.driverMemory) - .map { m => Utils.memoryStringToMb(m).toString } - .orNull - val executorMemory = Option(args.executorMemory) - .map { m => Utils.memoryStringToMb(m).toString } - .orNull + val dm = Option(args.driverMemory).map { m => Utils.memoryStringToMb(m).toString }.orNull + val em = Option(args.executorMemory).map { m => Utils.memoryStringToMb(m).toString }.orNull val message = new SubmitDriverRequestMessage() .setField(SPARK_VERSION, sparkVersion) .setField(MASTER, args.master) @@ -47,17 +43,17 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { .setFieldIfNotNull(MAIN_CLASS, args.mainClass) .setFieldIfNotNull(JARS, args.jars) .setFieldIfNotNull(FILES, args.files) - .setFieldIfNotNull(DRIVER_MEMORY, driverMemory) + .setFieldIfNotNull(DRIVER_MEMORY, dm) .setFieldIfNotNull(DRIVER_CORES, args.driverCores) .setFieldIfNotNull(DRIVER_EXTRA_JAVA_OPTIONS, args.driverExtraJavaOptions) .setFieldIfNotNull(DRIVER_EXTRA_CLASS_PATH, args.driverExtraClassPath) .setFieldIfNotNull(DRIVER_EXTRA_LIBRARY_PATH, args.driverExtraLibraryPath) .setFieldIfNotNull(SUPERVISE_DRIVER, args.supervise.toString) - .setFieldIfNotNull(EXECUTOR_MEMORY, executorMemory) + .setFieldIfNotNull(EXECUTOR_MEMORY, em) .setFieldIfNotNull(TOTAL_EXECUTOR_CORES, args.totalExecutorCores) args.childArgs.foreach(message.appendAppArg) args.sparkProperties.foreach { case (k, v) => message.setSparkProperty(k, v) } - // TODO: set environment variables? + // TODO: send special environment variables? message.validate() } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 54cdcf2e0d26..7916029517cc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -19,8 +19,7 @@ package org.apache.spark.deploy.rest import java.io.File -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer +import akka.actor.ActorRef import org.apache.spark.{SPARK_VERSION => sparkVersion} import org.apache.spark.SparkConf @@ -29,22 +28,19 @@ import org.apache.spark.deploy.{Command, DriverDescription} import org.apache.spark.deploy.ClientArguments._ import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import akka.actor.ActorRef /** * A server that responds to requests submitted by the StandaloneRestClient. * This is intended to be embedded in the standalone Master. Cluster mode only. */ -private[spark] class StandaloneRestServer( - master: Master, - host: String, - requestedPort: Int) - extends SubmitRestServer(host, requestedPort) { +private[spark] class StandaloneRestServer(master: Master, host: String, requestedPort: Int) + extends SubmitRestServer(host, requestedPort, master.conf) { override protected val handler = new StandaloneRestServerHandler(master) } /** - * A handler for requests submitted to the standalone Master through the REST protocol. + * A handler for requests submitted to the standalone Master + * via the stable application submission REST protocol. 
*/ private[spark] class StandaloneRestServerHandler( conf: SparkConf, @@ -141,9 +137,7 @@ private[spark] class StandaloneRestServerHandler( // Otherwise, once the driver is launched it will contact with the wrong server .set("spark.master", masterUrl) .set("spark.app.name", appName) - // Include main app resource on the executor classpath - // The corresponding behavior in client mode is handled in SparkSubmit - .set("spark.jars", jars.map(_ + ",").getOrElse("") + appResource) + jars.foreach { j => conf.set("spark.jars", j) } files.foreach { f => conf.set("spark.files", f) } driverExtraJavaOptions.foreach { j => conf.set("spark.driver.extraJavaOptions", j) } driverExtraClassPath.foreach { cp => conf.set("spark.driver.extraClassPath", cp) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala index 0ff68c1584f6..47f97b4fdc77 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala @@ -29,7 +29,7 @@ import org.apache.spark.util.JsonProtocol */ private[spark] abstract class SubmitDriverRequestField extends SubmitRestProtocolField private[spark] object SubmitDriverRequestField - extends SubmitRestProtocolFieldCompanion[SubmitRestProtocolField] { + extends SubmitRestProtocolFieldCompanion[SubmitDriverRequestField] { case object ACTION extends SubmitDriverRequestField case object SPARK_VERSION extends SubmitDriverRequestField case object MESSAGE extends SubmitDriverRequestField @@ -59,12 +59,13 @@ private[spark] object SubmitDriverRequestField } /** - * A request sent to the cluster manager to submit a driver. + * A request sent to the cluster manager to submit a driver + * in the stable application submission REST protocol. 
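+ *
+ * A sketch of how such a request is assembled (hypothetical values; see
+ * StandaloneRestClient.constructSubmitRequest for the real construction):
+ * {{{
+ *   new SubmitDriverRequestMessage()
+ *     .setField(SPARK_VERSION, "1.3.0")
+ *     .setField(MASTER, "spark://host:7077")
+ *     .setFieldIfNotNull(MAIN_CLASS, "org.apache.spark.examples.SparkPi")
+ *     .setFieldIfNotNull(DRIVER_MEMORY, "512")
+ *     .validate()
+ * }}}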
*/ private[spark] class SubmitDriverRequestMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.SUBMIT_DRIVER_REQUEST, - SubmitDriverRequestField.ACTION, - SubmitDriverRequestField.requiredFields) { + SubmitRestProtocolAction.SUBMIT_DRIVER_REQUEST, + SubmitDriverRequestField.ACTION, + SubmitDriverRequestField.requiredFields) { import SubmitDriverRequestField._ @@ -72,17 +73,18 @@ private[spark] class SubmitDriverRequestMessage extends SubmitRestProtocolMessag private val sparkProperties = new mutable.HashMap[String, String] private val environmentVariables = new mutable.HashMap[String, String] - // Special field setters + // Setters for special fields def appendAppArg(arg: String): Unit = { appArgs += arg } def setSparkProperty(k: String, v: String): Unit = { sparkProperties(k) = v } def setEnvironmentVariable(k: String, v: String): Unit = { environmentVariables(k) = v } - // Special field getters + // Getters for special fields def getAppArgs: Seq[String] = appArgs.clone() def getSparkProperties: Map[String, String] = sparkProperties.toMap def getEnvironmentVariables: Map[String, String] = environmentVariables.toMap // Include app args, spark properties, and environment variables in the JSON object + // The order imposed here is as follows: * < APP_ARGS < SPARK_PROPERTIES < ENVIRONMENT_VARIABLES override def toJsonObject: JObject = { val otherFields = super.toJsonObject.obj val appArgsJson = JArray(appArgs.map(JString).toList) @@ -103,7 +105,7 @@ private[spark] object SubmitDriverRequestMessage import SubmitDriverRequestField._ protected override def newMessage() = new SubmitDriverRequestMessage - protected override def fieldWithName(field: String) = SubmitDriverRequestField.withName(field) + protected override def fieldFromString(field: String) = SubmitDriverRequestField.fromString(field) /** * Process the given field and value appropriately based on the type of the field. diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala index bf3c03d2e333..70670fd6c9c7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala @@ -34,15 +34,16 @@ private[spark] object SubmitDriverResponseField } /** - * A message sent from the cluster manager in response to a SubmitDriverRequestMessage. + * A message sent from the cluster manager in response to a SubmitDriverRequestMessage + * in the stable application submission REST protocol. 
*/ private[spark] class SubmitDriverResponseMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.SUBMIT_DRIVER_RESPONSE, - SubmitDriverResponseField.ACTION, - SubmitDriverResponseField.requiredFields) + SubmitRestProtocolAction.SUBMIT_DRIVER_RESPONSE, + SubmitDriverResponseField.ACTION, + SubmitDriverResponseField.requiredFields) private[spark] object SubmitDriverResponseMessage extends SubmitRestProtocolMessageCompanion[SubmitDriverResponseMessage] { protected override def newMessage() = new SubmitDriverResponseMessage - protected override def fieldWithName(field: String) = SubmitDriverResponseField.withName(field) + protected override def fieldFromString(field: String) = SubmitDriverResponseField.fromString(field) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala index 4bd0d022e607..b3e0d9e02fab 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala @@ -28,7 +28,7 @@ import org.apache.spark.Logging import org.apache.spark.deploy.SparkSubmitArguments /** - * An abstract client that submits Spark applications using a stable REST protocol. + * An abstract client that submits applications using the stable REST protocol. * This client is intended to communicate with the SubmitRestServer. */ private[spark] abstract class SubmitRestClient extends Logging { @@ -78,6 +78,7 @@ private[spark] abstract class SubmitRestClient extends Logging { /** * Send the provided request in an HTTP message to the given URL. + * This assumes both the request and the response use the JSON format. * Return the response received from the REST server. */ private def sendHttp(url: URL, request: SubmitRestProtocolMessage): SubmitRestProtocolMessage = { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index 496db227f210..6419520743eb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -38,7 +38,7 @@ private[spark] object SubmitRestProtocolField { } /** - * All possible values of the ACTION field. + * All possible values of the ACTION field in a SubmitRestProtocolMessage. */ private[spark] object SubmitRestProtocolAction extends Enumeration { type SubmitRestProtocolAction = Value @@ -53,9 +53,9 @@ import SubmitRestProtocolAction.SubmitRestProtocolAction * A general message exchanged in the stable application submission REST protocol. * * The message is represented by a set of fields in the form of key value pairs. - * Each message must contain an ACTION field, which should have only one possible value - * for each type of message. For compatibility with older versions of Spark, existing - * fields must not be removed or modified, though new fields can be added as necessary. + * Each message must contain an ACTION field, which fully specifies the type of the message. + * For compatibility with older versions of Spark, existing fields must not be removed or + * modified, though new fields can be added as necessary. 
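+ *
+ * On the wire, each message is rendered as a JSON object whose keys are the string
+ * representations of its fields. For example, a kill request might look like the
+ * following (values are hypothetical):
+ * {{{
+ *   {
+ *     "ACTION" : "KILL_DRIVER_REQUEST",
+ *     "SPARK_VERSION" : "1.3.0",
+ *     "MASTER" : "spark://host:7077",
+ *     "DRIVER_ID" : "driver-20150121152153-0000"
+ *   }
+ * }}}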
*/ private[spark] abstract class SubmitRestProtocolMessage( action: SubmitRestProtocolAction, @@ -104,8 +104,8 @@ private[spark] abstract class SubmitRestProtocolMessage( } /** - * Validate that all required fields are set and the value of the action field is as expected. - * If any of these conditions are not met, throw an IllegalArgumentException. + * Validate that all required fields are set and the value of the ACTION field is as expected. + * If any of these conditions are not met, throw an exception. */ def validate(): this.type = { if (!fields.contains(actionField)) { @@ -153,17 +153,16 @@ private[spark] object SubmitRestProtocolMessage { import SubmitRestProtocolField._ import SubmitRestProtocolAction._ - /** Construct a SubmitRestProtocolMessage from its JSON representation. */ - def fromJson(json: String): SubmitRestProtocolMessage = { - fromJsonObject(parse(json).asInstanceOf[JObject]) - } - /** - * Construct a SubmitRestProtocolMessage from the given JSON object. + * Construct a SubmitRestProtocolMessage from its JSON representation. * This uses the ACTION field to determine the type of the message to reconstruct. + * If such a field does not exist, throw an exception. */ - protected def fromJsonObject(jsonObject: JObject): SubmitRestProtocolMessage = { - val action = getAction(jsonObject) + def fromJson(json: String): SubmitRestProtocolMessage = { + val jsonObject = parse(json).asInstanceOf[JObject] + val action = getAction(jsonObject).getOrElse { + throw new IllegalArgumentException(s"ACTION not found in message:\n$json") + } SubmitRestProtocolAction.withName(action) match { case SUBMIT_DRIVER_REQUEST => SubmitDriverRequestMessage.fromJsonObject(jsonObject) case SUBMIT_DRIVER_RESPONSE => SubmitDriverResponseMessage.fromJsonObject(jsonObject) @@ -177,36 +176,30 @@ private[spark] object SubmitRestProtocolMessage { /** * Extract the value of the ACTION field in the JSON object. - * If such a field does not exist in the JSON, throw an exception. */ - private def getAction(jsonObject: JObject): String = { + private def getAction(jsonObject: JObject): Option[String] = { jsonObject.obj .collect { case JField(k, JString(v)) if isActionField(k) => v } .headOption - .getOrElse { - throw new IllegalArgumentException( - "ACTION not found in message:\n" + pretty(render(jsonObject))) - } } } /** - * A trait that holds common methods for SubmitRestProtocolField companion objects. - * - * It is necessary to keep track of all fields that belong to this object in order to - * reconstruct the fields from their names. + * Common methods used by companion objects of SubmitRestProtocolField's subclasses. + * This keeps track of all fields that belong to this object in order to reconstruct + * the fields from their names. */ private[spark] trait SubmitRestProtocolFieldCompanion[FieldType <: SubmitRestProtocolField] { val requiredFields: Seq[FieldType] val optionalFields: Seq[FieldType] - /** Listing of all fields indexed by the field's string representation. */ + // Listing of all fields indexed by the field's string representation private lazy val allFieldsMap: Map[String, FieldType] = { (requiredFields ++ optionalFields).map { f => (f.toString, f) }.toMap } - /** Return a SubmitRestProtocolField from its string representation. */ - def withName(field: String): FieldType = { + /** Return the appropriate SubmitRestProtocolField from its string representation. 
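+ * For example, DriverStatusRequestField.fromString("DRIVER_ID") returns
+ * DriverStatusRequestField.DRIVER_ID, while an unknown name raises an
+ * IllegalArgumentException.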
*/ + def fromString(field: String): FieldType = { allFieldsMap.get(field).getOrElse { throw new IllegalArgumentException(s"Unknown field $field") } @@ -214,7 +207,7 @@ private[spark] trait SubmitRestProtocolFieldCompanion[FieldType <: SubmitRestPro } /** - * A trait that holds common methods for SubmitRestProtocolMessage companion objects. + * Common methods used by companion objects of SubmitRestProtocolMessage's subclasses. */ private[spark] trait SubmitRestProtocolMessageCompanion[MessageType <: SubmitRestProtocolMessage] extends Logging { @@ -225,11 +218,11 @@ private[spark] trait SubmitRestProtocolMessageCompanion[MessageType <: SubmitRes protected def newMessage(): MessageType /** Return a field of the relevant type from the field's string representation. */ - protected def fieldWithName(field: String): SubmitRestProtocolField + protected def fieldFromString(field: String): SubmitRestProtocolField /** - * Process the given field and value appropriately based on the type of the field. - * The default behavior only considers fields that have flat values and ignores other fields. + * Populate the given field and value in the provided message. + * The default behavior only handles fields that have flat values and ignores other fields. * If the subclass uses fields with nested values, it should override this method appropriately. */ protected def handleField( @@ -252,7 +245,7 @@ private[spark] trait SubmitRestProtocolMessageCompanion[MessageType <: SubmitRes .filter { case (k, _) => !isActionField(k) } .flatMap { case (k, v) => try { - Some((fieldWithName(k), v)) + Some((fieldFromString(k), v)) } catch { case e: IllegalArgumentException => logWarning(s"Unexpected field $k in message ${Utils.getFormattedClassName(this)}") diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala index 6ad4be4f26ae..980d6089b676 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala @@ -27,25 +27,40 @@ import com.google.common.base.Charsets import org.eclipse.jetty.server.{Request, Server} import org.eclipse.jetty.server.handler.AbstractHandler -import org.apache.spark.{SPARK_VERSION => sparkVersion, Logging} +import org.apache.spark.{SPARK_VERSION => sparkVersion, Logging, SparkConf} import org.apache.spark.util.Utils /** - * A server that responds to requests submitted by the SubmitRestClient. + * An abstract server that responds to requests submitted by the SubmitRestClient + * in the stable application submission REST protocol. */ -private[spark] abstract class SubmitRestServer(host: String, requestedPort: Int) { +private[spark] abstract class SubmitRestServer(host: String, requestedPort: Int, conf: SparkConf) + extends Logging { + protected val handler: SubmitRestServerHandler + private var _server: Option[Server] = None - /** Start the server. 
*/ def start(): Unit = { + val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, conf) + _server = Some(server) + logInfo(s"Started REST server for submitting applications on port $boundPort") + } + + def stop(): Unit = { + _server.foreach(_.stop()) + } + + private def doStart(startPort: Int): (Server, Int) = { - val server = new Server(new InetSocketAddress(host, requestedPort)) + val server = new Server(new InetSocketAddress(host, startPort)) server.setHandler(handler) server.start() + val boundPort = server.getConnectors()(0).getLocalPort + (server, boundPort) } } /** - * A handler for requests submitted via the stable REST protocol for submitting applications. + * An abstract handler for requests submitted via the stable application submission REST protocol. * This represents the main handler used in the SubmitRestServer. */ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler with Logging { @@ -53,7 +68,10 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi protected def handleKill(request: KillDriverRequestMessage): KillDriverResponseMessage protected def handleStatus(request: DriverStatusRequestMessage): DriverStatusResponseMessage - /** Handle a request submitted by the SubmitRestClient. */ + /** + * Handle a request submitted by the SubmitRestClient. + * This assumes both the request and the response use the JSON format. + */ override def handle( target: String, baseRequest: Request, @@ -82,9 +100,9 @@ private[spark] abstract class SubmitRestServerHandler wi */ private def constructResponseMessage( request: SubmitRestProtocolMessage): SubmitRestProtocolMessage = { - // If the request is sent via the SubmitRestClient, it should have already been - // validated remotely. In case this is not true, validate the request here to guard - // against potential NPEs. If validation fails, return an ERROR message to the sender. + // If the request is sent via the SubmitRestClient, it should have already been validated + // remotely. In case this is not true, validate the request here again to guard against + // potential NPEs. If validation fails, send an error message back to the sender. try { request.validate() request match { From 120ab9d33484ccc20f45aee6272daeca5ebcc878 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 21 Jan 2015 14:55:35 -0800 Subject: [PATCH 07/48] Support kill and request driver status through SparkSubmit --- .../org/apache/spark/deploy/SparkSubmit.scala | 34 ++++++++++- .../spark/deploy/SparkSubmitArguments.scala | 61 +++++++++++++++++-- 2 files changed, 88 insertions(+), 7 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 4aba0feefdf2..1e3c7c2f1bb1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -27,6 +27,15 @@ import org.apache.spark.executor.ExecutorURLClassLoader import org.apache.spark.util.Utils import org.apache.spark.deploy.rest.StandaloneRestClient +/** + * Whether to submit, kill, or request the status of an application. + * The latter two operations are currently supported only for standalone cluster mode. + */ +private[spark] object Action extends Enumeration { + type Action = Value + val SUBMIT, KILL, REQUEST_STATUS = Value +} + /** * Main gateway of launching a Spark application. 
* @@ -73,11 +82,30 @@ object SparkSubmit { if (appArgs.verbose) { printStream.println(appArgs) } - launch(appArgs) + appArgs.action match { + case Action.SUBMIT => submit(appArgs) + case Action.KILL => kill(appArgs) + case Action.REQUEST_STATUS => requestStatus(appArgs) + } + } + + /** + * Kill an existing driver using the stable REST protocol. Standalone cluster mode only. + */ + private[spark] def kill(args: SparkSubmitArguments): Unit = { + new StandaloneRestClient().killDriver(args.master, args.driverToKill) + } + + /** + * Request the status of an existing driver using the stable REST protocol. + * Standalone cluster mode only. + */ + private[spark] def requestStatus(args: SparkSubmitArguments): Unit = { + new StandaloneRestClient().requestDriverStatus(args.master, args.driverToRequestStatusFor) } /** - * Launch the application using the provided parameters. + * Submit the application using the provided parameters. * * This runs in two steps. First, we prepare the launch environment by setting up * the appropriate classpath, system properties, and application arguments for @@ -89,7 +117,7 @@ object SparkSubmit { * main method of a child class. Instead, we pass the submit parameters directly to * a REST client, which will submit the application using the stable REST protocol. */ - private[spark] def launch(args: SparkSubmitArguments): Unit = { + private[spark] def submit(args: SparkSubmitArguments): Unit = { // Environment needed to launch the child main class val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 310b34a92633..48f0ea39a095 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -23,6 +23,7 @@ import java.util.jar.JarFile import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark.util.Utils +import org.apache.spark.deploy.Action.Action /** * Parses and encapsulates arguments from the spark-submit script. @@ -39,8 +40,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var driverExtraClassPath: String = null var driverExtraLibraryPath: String = null var driverExtraJavaOptions: String = null - var driverCores: String = null - var supervise: Boolean = false var queue: String = null var numExecutors: String = null var files: String = null @@ -55,6 +54,23 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var pyFiles: String = null val sparkProperties: HashMap[String, String] = new HashMap[String, String]() + // Standalone cluster mode only + var supervise: Boolean = false + var driverCores: String = null + var driverToKill: String = null + var driverToRequestStatusFor: String = null + + def action: Action = { + (driverToKill, driverToRequestStatusFor) match { + case (null, null) => Action.SUBMIT + case (_, null) => Action.KILL + case (null, _) => Action.REQUEST_STATUS + case _ => SparkSubmit.printErrorAndExit( + "Requested to both kill and request status for a driver. Choose only one.") + null // never reached + } + } + /** Default properties present in the currently defined defaults file. 
*/ lazy val defaultSparkProperties: HashMap[String, String] = { val defaultProperties = new HashMap[String, String]() @@ -79,7 +95,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St // Use `sparkProperties` map along with env vars to fill in any missing parameters loadEnvironmentArguments() - checkRequiredArguments() + validateArguments() /** * Merge values from the default properties file with those specified through --conf. @@ -171,7 +187,15 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } /** Ensure that required fields exist. Call this only once all defaults are loaded. */ - private def checkRequiredArguments(): Unit = { + private def validateArguments(): Unit = { + action match { + case Action.SUBMIT => validateSubmitArguments() + case Action.KILL => validateKillArguments() + case Action.REQUEST_STATUS => validateStatusRequestArguments() + } + } + + private def validateSubmitArguments(): Unit = { if (args.length == 0) { printUsageAndExit(-1) } @@ -206,6 +230,25 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } } + private def validateKillArguments(): Unit = { + if (!master.startsWith("spark://") || deployMode != "cluster") { + SparkSubmit.printErrorAndExit("Killing drivers is only supported in standalone cluster mode") + } + if (driverToKill == null) { + SparkSubmit.printErrorAndExit("Please specify a driver to kill") + } + } + + private def validateStatusRequestArguments(): Unit = { + if (!master.startsWith("spark://") || deployMode != "cluster") { + SparkSubmit.printErrorAndExit( + "Requesting driver statuses is only supported in standalone cluster mode") + } + if (driverToRequestStatusFor == null) { + SparkSubmit.printErrorAndExit("Please specify a driver to request status for") + } + } + override def toString = { s"""Parsed arguments: | master $master @@ -312,6 +355,14 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St propertiesFile = value parse(tail) + case ("--kill") :: value :: tail => + driverToKill = value + parse(tail) + + case ("--status") :: value :: tail => + driverToRequestStatusFor = value + parse(tail) + case ("--supervise") :: tail => supervise = true parse(tail) @@ -410,6 +461,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St | Spark standalone with cluster deploy mode only: | --driver-cores NUM Cores for driver (Default: 1). | --supervise If given, restarts the driver on failure. + | --kill DRIVER_ID If given, kills the driver specified. + | --status DRIVER_ID If given, requests the status of the driver specified. | | Spark standalone and Mesos only: | --total-executor-cores NUM Total cores for all executors. From b44e103b78b36fadd887dc4b894027a03069b1f7 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 21 Jan 2015 17:19:33 -0800 Subject: [PATCH 08/48] Implement status requests + fix validation behavior This commit makes the StandaloneRestServer actually handle status requests. The polling behavior that previously lived in o.a.s.deploy.Client is also implemented in the StandaloneRestClient, with a few amendments. Additionally, the validation behavior was confusing before this commit. Previously the error message would seem to indicate that the user constructed a malformed message even if the message was constructed on the server side. This commit ensures that the error message is different for these two cases. 
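For example, with this change the status of a submitted driver can be queried
end-to-end through spark-submit (the driver ID below is hypothetical):

    ./bin/spark-submit --master spark://localhost:7077 --deploy-mode cluster \
      --status driver-20150121152153-0000

The StandaloneRestClient translates this into a DriverStatusRequestMessage, and
the StandaloneRestServerHandler now asks the Master actor for the actual driver
state instead of returning a hard-coded placeholder.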
--- .../org/apache/spark/deploy/SparkSubmit.scala | 3 +- .../rest/DriverStatusRequestMessage.scala | 2 +- .../rest/DriverStatusResponseMessage.scala | 9 +-- .../spark/deploy/rest/ErrorMessage.scala | 2 +- .../rest/KillDriverRequestMessage.scala | 2 +- .../rest/KillDriverResponseMessage.scala | 2 +- .../deploy/rest/StandaloneRestClient.scala | 61 ++++++++++++++++++- .../deploy/rest/StandaloneRestServer.scala | 21 ++++--- .../rest/SubmitDriverRequestMessage.scala | 2 +- .../rest/SubmitDriverResponseMessage.scala | 2 +- .../spark/deploy/rest/SubmitRestClient.scala | 46 ++++++++------ .../spark/deploy/rest/SubmitRestServer.scala | 37 ++++++----- 12 files changed, 130 insertions(+), 59 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 1e3c7c2f1bb1..30b982822dba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -49,8 +49,7 @@ object SparkSubmit { private val STANDALONE = 2 private val MESOS = 4 private val LOCAL = 8 - private val REST = 16 - private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL | REST + private val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL // Deploy modes private val CLIENT = 1 diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala index 57f79554151e..d435687606e3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala @@ -44,5 +44,5 @@ private[spark] class DriverStatusRequestMessage extends SubmitRestProtocolMessag private[spark] object DriverStatusRequestMessage extends SubmitRestProtocolMessageCompanion[DriverStatusRequestMessage] { protected override def newMessage() = new DriverStatusRequestMessage - protected override def fieldFromString(field: String) = DriverStatusRequestField.fromString(field) + protected override def fieldFromString(f: String) = DriverStatusRequestField.fromString(f) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala index 42c64dc60175..a0264643890a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala @@ -28,12 +28,13 @@ private[spark] object DriverStatusResponseField case object MESSAGE extends DriverStatusResponseField case object MASTER extends DriverStatusResponseField case object DRIVER_ID extends DriverStatusResponseField + case object SUCCESS extends DriverStatusResponseField + // Standalone specific fields case object DRIVER_STATE extends DriverStatusResponseField case object WORKER_ID extends DriverStatusResponseField case object WORKER_HOST_PORT extends DriverStatusResponseField - override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, - MASTER, DRIVER_ID, DRIVER_STATE, WORKER_ID, WORKER_HOST_PORT) - override val optionalFields = Seq.empty + override val requiredFields = Seq(ACTION, SPARK_VERSION, MASTER, DRIVER_ID, SUCCESS) + override val optionalFields = Seq(MESSAGE, DRIVER_STATE, WORKER_ID, WORKER_HOST_PORT) } /** @@ -48,5 +49,5 @@ private[spark] class DriverStatusResponseMessage extends 
SubmitRestProtocolMessa private[spark] object DriverStatusResponseMessage extends SubmitRestProtocolMessageCompanion[DriverStatusResponseMessage] { protected override def newMessage() = new DriverStatusResponseMessage - protected override def fieldFromString(field: String) = DriverStatusResponseField.fromString(field) + protected override def fieldFromString(f: String) = DriverStatusResponseField.fromString(f) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala index 04a298d98a34..aefd7b60d32a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala @@ -39,5 +39,5 @@ private[spark] class ErrorMessage extends SubmitRestProtocolMessage( private[spark] object ErrorMessage extends SubmitRestProtocolMessageCompanion[ErrorMessage] { protected override def newMessage() = new ErrorMessage - protected override def fieldFromString(field: String) = ErrorField.fromString(field) + protected override def fieldFromString(f: String) = ErrorField.fromString(f) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala index 3245058ce4ba..3353bfba5a69 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala @@ -44,5 +44,5 @@ private[spark] class KillDriverRequestMessage extends SubmitRestProtocolMessage( private[spark] object KillDriverRequestMessage extends SubmitRestProtocolMessageCompanion[KillDriverRequestMessage] { protected override def newMessage() = new KillDriverRequestMessage - protected override def fieldFromString(field: String) = KillDriverRequestField.fromString(field) + protected override def fieldFromString(f: String) = KillDriverRequestField.fromString(f) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala index 92db6cfa2d64..974fcb9936fc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala @@ -45,5 +45,5 @@ private[spark] class KillDriverResponseMessage extends SubmitRestProtocolMessage private[spark] object KillDriverResponseMessage extends SubmitRestProtocolMessageCompanion[KillDriverResponseMessage] { protected override def newMessage() = new KillDriverResponseMessage - protected override def fieldFromString(field: String) = KillDriverResponseField.fromString(field) + protected override def fieldFromString(f: String) = KillDriverResponseField.fromString(f) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index 03eaa93f0d33..43164ae3a4c8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -28,6 +28,58 @@ import org.apache.spark.util.Utils * This client is intended to communicate with the StandaloneRestServer. Cluster mode only. 
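 *
 * Besides submission, the client can kill a driver and query its status, e.g.
 * (a sketch; the driver ID is hypothetical):
 * {{{
 *   val client = new StandaloneRestClient
 *   client.killDriver("spark://host:7077", "driver-20150121152153-0000")
 *   client.requestDriverStatus("spark://host:7077", "driver-20150121152153-0000")
 * }}}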
*/ private[spark] class StandaloneRestClient extends SubmitRestClient { + import StandaloneRestClient._ + + /** + * Request that the REST server submit a driver specified by the provided arguments. + * + * If the driver was successfully submitted, this polls the status of the driver that was + * just submitted and reports it to the user. Otherwise, if the submission was unsuccessful, + * this reports failure and logs an error message provided by the REST server. + */ + override def submitDriver(args: SparkSubmitArguments): SubmitDriverResponseMessage = { + import SubmitDriverResponseField._ + val submitResponse = super.submitDriver(args).asInstanceOf[SubmitDriverResponseMessage] + val submitSuccess = submitResponse.getFieldNotNull(SUCCESS).toBoolean + if (submitSuccess) { + val driverId = submitResponse.getFieldNotNull(DRIVER_ID) + logInfo(s"Driver successfully submitted as $driverId. Polling driver state...") + pollSubmittedDriverStatus(args.master, driverId) + } else { + val submitMessage = submitResponse.getFieldNotNull(MESSAGE) + logError(s"Application submission failed: $submitMessage") + } + submitResponse + } + + /** + * Poll the status of the driver that was just submitted and report it. + * This retries up to a fixed number of times until giving up. + */ + private def pollSubmittedDriverStatus(master: String, driverId: String): Unit = { + import DriverStatusResponseField._ + (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ => + val statusResponse = requestDriverStatus(master, driverId) + .asInstanceOf[DriverStatusResponseMessage] + val statusSuccess = statusResponse.getFieldNotNull(SUCCESS).toBoolean + if (statusSuccess) { + val driverState = statusResponse.getFieldNotNull(DRIVER_STATE) + val workerId = statusResponse.getFieldOption(WORKER_ID) + val workerHostPort = statusResponse.getFieldOption(WORKER_HOST_PORT) + val exception = statusResponse.getFieldOption(MESSAGE) + logInfo(s"State of driver $driverId is now $driverState.") + // Log worker node, if present + (workerId, workerHostPort) match { + case (Some(id), Some(hp)) => logInfo(s"Driver is running on worker $id at $hp.") + case _ => + } + // Log exception stack trace, if present + exception.foreach { e => logError(e) } + return + } + // Pause between retries so the Master has a chance to register the driver + Thread.sleep(REPORT_DRIVER_STATUS_INTERVAL) + } + logError(s"Error: Master did not recognize driver $driverId.") + } /** Construct a submit driver request message. */ override protected def constructSubmitRequest( args: SparkSubmitArguments): SubmitDriverRequestMessage = { import SubmitDriverRequestField._ @@ -54,7 +106,7 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { args.childArgs.foreach(message.appendAppArg) args.sparkProperties.foreach { case (k, v) => message.setSparkProperty(k, v) } // TODO: send special environment variables? - message.validate() + message } /** Construct a kill driver request message. */ @@ -66,7 +118,6 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { .setField(SPARK_VERSION, sparkVersion) .setField(MASTER, master) .setField(DRIVER_ID, driverId) - .validate() } /** Construct a driver status request message. */ @@ -78,7 +129,6 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { .setField(SPARK_VERSION, sparkVersion) .setField(MASTER, master) .setField(DRIVER_ID, driverId) - .validate() } /** Throw an exception if this is not standalone mode. 
*/ @@ -101,3 +151,8 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { new URL("http://" + master.stripPrefix("spark://")) } } + +private object StandaloneRestClient { + val REPORT_DRIVER_STATUS_INTERVAL = 1000 + val REPORT_DRIVER_STATUS_MAX_TRIES = 10 +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 7916029517cc..563ee1c25144 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -67,7 +67,6 @@ private[spark] class StandaloneRestServerHandler( .setField(MASTER, masterUrl) .setField(SUCCESS, response.success.toString) .setFieldIfNotNull(DRIVER_ID, response.driverId.orNull) - .validate() } /** Handle a request to kill a driver. */ @@ -83,23 +82,29 @@ private[spark] class StandaloneRestServerHandler( .setField(MASTER, masterUrl) .setField(DRIVER_ID, driverId) .setField(SUCCESS, response.success.toString) - .validate() } /** Handle a request for a driver's status. */ override protected def handleStatus( request: DriverStatusRequestMessage): DriverStatusResponseMessage = { import DriverStatusResponseField._ - // TODO: Actually look up the status of the driver - val master = request.getField(DriverStatusRequestField.MASTER) val driverId = request.getField(DriverStatusRequestField.DRIVER_ID) - val driverState = "HEALTHY" + val response = AkkaUtils.askWithReply[DriverStatusResponse]( + RequestDriverStatus(driverId), masterActor, askTimeout) + // Format exception nicely, if it exists + val message = response.exception.map { e => + val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") + s"Exception from the cluster:\n$e\n$stackTraceString" + } new DriverStatusResponseMessage() .setField(SPARK_VERSION, sparkVersion) - .setField(MASTER, master) + .setField(MASTER, masterUrl) .setField(DRIVER_ID, driverId) - .setField(DRIVER_STATE, driverState) - .validate() + .setField(SUCCESS, response.found.toString) + .setFieldIfNotNull(DRIVER_STATE, response.state.map(_.toString).orNull) + .setFieldIfNotNull(WORKER_ID, response.workerId.orNull) + .setFieldIfNotNull(WORKER_HOST_PORT, response.workerHostPort.orNull) + .setFieldIfNotNull(MESSAGE, message.orNull) } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala index 47f97b4fdc77..1ce867febcf9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala @@ -105,7 +105,7 @@ private[spark] object SubmitDriverRequestMessage import SubmitDriverRequestField._ protected override def newMessage() = new SubmitDriverRequestMessage - protected override def fieldFromString(field: String) = SubmitDriverRequestField.fromString(field) + protected override def fieldFromString(f: String) = SubmitDriverRequestField.fromString(f) /** * Process the given field and value appropriately based on the type of the field. 
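For reference, an override of handleField for a field with a nested value might look
roughly like this (an illustrative sketch, not the exact code in this patch; it assumes
json4s's JArray/JString and SubmitDriverRequestField._ are in scope, as they are in
this file):

  protected override def handleField(
      message: SubmitDriverRequestMessage,
      field: SubmitRestProtocolField,
      value: JValue): Unit = {
    (field, value) match {
      // APP_ARGS is serialized as a JSON array rather than a flat string,
      // so it cannot go through the default setField path
      case (APP_ARGS, JArray(args)) =>
        args.collect { case JString(arg) => arg }.foreach(message.appendAppArg)
      case _ => super.handleField(message, field, value)
    }
  }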
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala index 70670fd6c9c7..455170766037 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala @@ -45,5 +45,5 @@ private[spark] class SubmitDriverResponseMessage extends SubmitRestProtocolMessa private[spark] object SubmitDriverResponseMessage extends SubmitRestProtocolMessageCompanion[SubmitDriverResponseMessage] { protected override def newMessage() = new SubmitDriverResponseMessage - protected override def fieldFromString(field: String) = SubmitDriverResponseField.fromString(field) + protected override def fieldFromString(f: String) = SubmitDriverResponseField.fromString(f) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala index b3e0d9e02fab..513c17deee89 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala @@ -17,14 +17,14 @@ package org.apache.spark.deploy.rest -import java.io.DataOutputStream +import java.io.{DataOutputStream, FileNotFoundException} import java.net.{HttpURLConnection, URL} import scala.io.Source import com.google.common.base.Charsets -import org.apache.spark.Logging +import org.apache.spark.{Logging, SparkException} import org.apache.spark.deploy.SparkSubmitArguments /** @@ -33,8 +33,8 @@ import org.apache.spark.deploy.SparkSubmitArguments */ private[spark] abstract class SubmitRestClient extends Logging { - /** Request that the REST server submits a driver specified by the provided arguments. */ - def submitDriver(args: SparkSubmitArguments): Unit = { + /** Request that the REST server submit a driver specified by the provided arguments. */ + def submitDriver(args: SparkSubmitArguments): SubmitRestProtocolMessage = { validateSubmitArguments(args) val url = getHttpUrl(args.master) val request = constructSubmitRequest(args) @@ -42,8 +42,8 @@ private[spark] abstract class SubmitRestClient extends Logging { sendHttp(url, request) } - /** Request that the REST server kills the specified driver. */ - def killDriver(master: String, driverId: String): Unit = { + /** Request that the REST server kill the specified driver. */ + def killDriver(master: String, driverId: String): SubmitRestProtocolMessage = { validateMaster(master) val url = getHttpUrl(master) val request = constructKillRequest(master, driverId) @@ -52,7 +52,7 @@ private[spark] abstract class SubmitRestClient extends Logging { } /** Request the status of the specified driver from the REST server. */ - def requestDriverStatus(master: String, driverId: String): Unit = { + def requestDriverStatus(master: String, driverId: String): SubmitRestProtocolMessage = { validateMaster(master) val url = getHttpUrl(master) val request = constructStatusRequest(master, driverId) @@ -82,18 +82,24 @@ private[spark] abstract class SubmitRestClient extends Logging { * Return the response received from the REST server. 
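+ * Note that the request is validated locally before anything is written to the wire,
+ * so a malformed message fails fast on the client rather than coming back as a server error.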
*/ private def sendHttp(url: URL, request: SubmitRestProtocolMessage): SubmitRestProtocolMessage = { - val conn = url.openConnection().asInstanceOf[HttpURLConnection] - conn.setRequestMethod("POST") - conn.setRequestProperty("Content-Type", "application/json") - conn.setRequestProperty("charset", "utf-8") - conn.setDoOutput(true) - val requestJson = request.toJson - logDebug(s"Sending the following request to the REST server:\n$requestJson") - val out = new DataOutputStream(conn.getOutputStream) - out.write(requestJson.getBytes(Charsets.UTF_8)) - out.close() - val responseJson = Source.fromInputStream(conn.getInputStream).mkString - logDebug(s"Response from the REST server:\n$responseJson") - SubmitRestProtocolMessage.fromJson(responseJson) + try { + val conn = url.openConnection().asInstanceOf[HttpURLConnection] + conn.setRequestMethod("POST") + conn.setRequestProperty("Content-Type", "application/json") + conn.setRequestProperty("charset", "utf-8") + conn.setDoOutput(true) + request.validate() + val requestJson = request.toJson + logDebug(s"Sending the following request to the REST server:\n$requestJson") + val out = new DataOutputStream(conn.getOutputStream) + out.write(requestJson.getBytes(Charsets.UTF_8)) + out.close() + val responseJson = Source.fromInputStream(conn.getInputStream).mkString + logDebug(s"Response from the REST server:\n$responseJson") + SubmitRestProtocolMessage.fromJson(responseJson) + } catch { + case e: FileNotFoundException => + throw new SparkException(s"Unable to connect to REST server $url", e) + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala index 980d6089b676..c659dfddbf6a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala @@ -100,23 +100,29 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi */ private def constructResponseMessage( request: SubmitRestProtocolMessage): SubmitRestProtocolMessage = { - // If the request is sent via the SubmitRestClient, it should have already been validated - // remotely. In case this is not true, validate the request here again to guard against - // potential NPEs. If validation fails, send an error message back to the sender. - try { - request.validate() - request match { - case submit: SubmitDriverRequestMessage => handleSubmit(submit) - case kill: KillDriverRequestMessage => handleKill(kill) - case status: DriverStatusRequestMessage => handleStatus(status) - case unexpected => handleError( - s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") + // Validate the request message to ensure that it is correctly constructed. If the request + // is sent via the SubmitRestClient, it should have already been validated remotely. In case + // this is not true, do it again here to guard against potential NPEs. If validation fails, + // send an error message back to the sender. 
+    val response =
+      try {
+        request.validate()
+        request match {
+          case submit: SubmitDriverRequestMessage => handleSubmit(submit)
+          case kill: KillDriverRequestMessage => handleKill(kill)
+          case status: DriverStatusRequestMessage => handleStatus(status)
+          case unexpected => handleError(
+            s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.")
+        }
+      } catch {
+        case e: IllegalArgumentException => handleError(e.getMessage)
+      }
+    // Validate the response message to ensure that it is correctly constructed. If it is not,
+    // propagate the exception back to the client and signal that it is a server error.
+    try {
+      response.validate()
+      response
     } catch {
-      // Propagate exception to user in an ErrorMessage.
-      // Note that the construction of the error message itself may throw an exception.
-      // In this case, let the higher level caller take care of this request.
-      case e: IllegalArgumentException => handleError(e.getMessage)
+      case e: IllegalArgumentException => handleError(s"Internal server error: ${e.getMessage}")
+    }
   }
@@ -126,6 +132,5 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi
     new ErrorMessage()
       .setField(SPARK_VERSION, sparkVersion)
       .setField(MESSAGE, message)
-      .validate()
   }
 }

From 51c5ca6d8ef448f7b6181c684fa3ee3794f0d6b8 Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Wed, 21 Jan 2015 17:43:43 -0800
Subject: [PATCH 09/48] Distinguish client and server side Spark versions

Otherwise it's a little ambiguous what we mean by SPARK_VERSION.
---
 .../spark/deploy/rest/DriverStatusRequestMessage.scala | 4 ++--
 .../deploy/rest/DriverStatusResponseMessage.scala | 4 ++--
 .../org/apache/spark/deploy/rest/ErrorMessage.scala | 7 ++++---
 .../spark/deploy/rest/KillDriverRequestMessage.scala | 4 ++--
 .../spark/deploy/rest/KillDriverResponseMessage.scala | 4 ++--
 .../spark/deploy/rest/StandaloneRestClient.scala | 6 +++---
 .../spark/deploy/rest/StandaloneRestServer.scala | 6 +++---
 .../spark/deploy/rest/SubmitDriverRequestMessage.scala | 4 ++--
 .../deploy/rest/SubmitDriverResponseMessage.scala | 4 ++--
 .../spark/deploy/rest/SubmitRestProtocolMessage.scala | 10 +++++++---
 .../apache/spark/deploy/rest/SubmitRestServer.scala | 2 +-
 11 files changed, 30 insertions(+), 25 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala
index d435687606e3..e6b513bd4b1c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala
@@ -24,11 +24,11 @@ private[spark] abstract class DriverStatusRequestField extends SubmitRestProtoco
 private[spark] object DriverStatusRequestField
   extends SubmitRestProtocolFieldCompanion[DriverStatusRequestField] {
   case object ACTION extends DriverStatusRequestField
-  case object SPARK_VERSION extends DriverStatusRequestField
+  case object CLIENT_SPARK_VERSION extends DriverStatusRequestField
   case object MESSAGE extends DriverStatusRequestField
   case object MASTER extends DriverStatusRequestField
   case object DRIVER_ID extends DriverStatusRequestField
-  override val requiredFields = Seq(ACTION, SPARK_VERSION, MASTER, DRIVER_ID)
+  override val requiredFields = Seq(ACTION, CLIENT_SPARK_VERSION, MASTER, DRIVER_ID)
   override val optionalFields = Seq(MESSAGE)
 }

diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala
b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala index a0264643890a..7a65f31c711f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala @@ -24,7 +24,7 @@ private[spark] abstract class DriverStatusResponseField extends SubmitRestProtoc private[spark] object DriverStatusResponseField extends SubmitRestProtocolFieldCompanion[DriverStatusResponseField] { case object ACTION extends DriverStatusResponseField - case object SPARK_VERSION extends DriverStatusResponseField + case object SERVER_SPARK_VERSION extends DriverStatusResponseField case object MESSAGE extends DriverStatusResponseField case object MASTER extends DriverStatusResponseField case object DRIVER_ID extends DriverStatusResponseField @@ -33,7 +33,7 @@ private[spark] object DriverStatusResponseField case object DRIVER_STATE extends DriverStatusResponseField case object WORKER_ID extends DriverStatusResponseField case object WORKER_HOST_PORT extends DriverStatusResponseField - override val requiredFields = Seq(ACTION, SPARK_VERSION, MASTER, DRIVER_ID, SUCCESS) + override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, MASTER, DRIVER_ID, SUCCESS) override val optionalFields = Seq(MESSAGE, DRIVER_STATE, WORKER_ID, WORKER_HOST_PORT) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala index aefd7b60d32a..88d33462dd44 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala @@ -23,14 +23,15 @@ package org.apache.spark.deploy.rest private[spark] abstract class ErrorField extends SubmitRestProtocolField private[spark] object ErrorField extends SubmitRestProtocolFieldCompanion[ErrorField] { case object ACTION extends ErrorField - case object SPARK_VERSION extends ErrorField + case object SERVER_SPARK_VERSION extends ErrorField case object MESSAGE extends ErrorField - override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE) + override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, MESSAGE) override val optionalFields = Seq.empty } /** - * An error message exchanged in the stable application submission REST protocol. + * An error message sent from the cluster manager + * in the stable application submission REST protocol. 
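+ * (for example, in reply to a request that fails validation or carries an unrecognized action)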
*/ private[spark] class ErrorMessage extends SubmitRestProtocolMessage( SubmitRestProtocolAction.ERROR, diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala index 3353bfba5a69..ae3c62d496b7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala @@ -24,11 +24,11 @@ private[spark] abstract class KillDriverRequestField extends SubmitRestProtocolF private[spark] object KillDriverRequestField extends SubmitRestProtocolFieldCompanion[KillDriverRequestField] { case object ACTION extends KillDriverRequestField - case object SPARK_VERSION extends KillDriverRequestField + case object CLIENT_SPARK_VERSION extends KillDriverRequestField case object MESSAGE extends KillDriverRequestField case object MASTER extends KillDriverRequestField case object DRIVER_ID extends KillDriverRequestField - override val requiredFields = Seq(ACTION, SPARK_VERSION, MASTER, DRIVER_ID) + override val requiredFields = Seq(ACTION, CLIENT_SPARK_VERSION, MASTER, DRIVER_ID) override val optionalFields = Seq(MESSAGE) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala index 974fcb9936fc..b5dc4ee557cb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala @@ -24,12 +24,12 @@ private[spark] abstract class KillDriverResponseField extends SubmitRestProtocol private[spark] object KillDriverResponseField extends SubmitRestProtocolFieldCompanion[KillDriverResponseField] { case object ACTION extends KillDriverResponseField - case object SPARK_VERSION extends KillDriverResponseField + case object SERVER_SPARK_VERSION extends KillDriverResponseField case object MESSAGE extends KillDriverResponseField case object MASTER extends KillDriverResponseField case object DRIVER_ID extends KillDriverResponseField case object SUCCESS extends KillDriverResponseField - override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, DRIVER_ID, SUCCESS) + override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, MESSAGE, MASTER, DRIVER_ID, SUCCESS) override val optionalFields = Seq.empty } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index 43164ae3a4c8..4f5c31c080fb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -88,7 +88,7 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { val dm = Option(args.driverMemory).map { m => Utils.memoryStringToMb(m).toString }.orNull val em = Option(args.executorMemory).map { m => Utils.memoryStringToMb(m).toString }.orNull val message = new SubmitDriverRequestMessage() - .setField(SPARK_VERSION, sparkVersion) + .setField(CLIENT_SPARK_VERSION, sparkVersion) .setField(MASTER, args.master) .setField(APP_NAME, args.name) .setField(APP_RESOURCE, args.primaryResource) @@ -115,7 +115,7 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { driverId: String): KillDriverRequestMessage = { import KillDriverRequestField._ new 
KillDriverRequestMessage() - .setField(SPARK_VERSION, sparkVersion) + .setField(CLIENT_SPARK_VERSION, sparkVersion) .setField(MASTER, master) .setField(DRIVER_ID, driverId) } @@ -126,7 +126,7 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { driverId: String): DriverStatusRequestMessage = { import DriverStatusRequestField._ new DriverStatusRequestMessage() - .setField(SPARK_VERSION, sparkVersion) + .setField(CLIENT_SPARK_VERSION, sparkVersion) .setField(MASTER, master) .setField(DRIVER_ID, driverId) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 563ee1c25144..5a5afcc22833 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -62,7 +62,7 @@ private[spark] class StandaloneRestServerHandler( val response = AkkaUtils.askWithReply[SubmitDriverResponse]( RequestSubmitDriver(driverDescription), masterActor, askTimeout) new SubmitDriverResponseMessage() - .setField(SPARK_VERSION, sparkVersion) + .setField(SERVER_SPARK_VERSION, sparkVersion) .setField(MESSAGE, response.message) .setField(MASTER, masterUrl) .setField(SUCCESS, response.success.toString) @@ -77,7 +77,7 @@ private[spark] class StandaloneRestServerHandler( val response = AkkaUtils.askWithReply[KillDriverResponse]( RequestKillDriver(driverId), masterActor, askTimeout) new KillDriverResponseMessage() - .setField(SPARK_VERSION, sparkVersion) + .setField(SERVER_SPARK_VERSION, sparkVersion) .setField(MESSAGE, response.message) .setField(MASTER, masterUrl) .setField(DRIVER_ID, driverId) @@ -97,7 +97,7 @@ private[spark] class StandaloneRestServerHandler( s"Exception from the cluster:\n$e\n$stackTraceString" } new DriverStatusResponseMessage() - .setField(SPARK_VERSION, sparkVersion) + .setField(SERVER_SPARK_VERSION, sparkVersion) .setField(MASTER, masterUrl) .setField(DRIVER_ID, driverId) .setField(SUCCESS, response.found.toString) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala index 1ce867febcf9..b4c29d171b73 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala @@ -31,7 +31,7 @@ private[spark] abstract class SubmitDriverRequestField extends SubmitRestProtoco private[spark] object SubmitDriverRequestField extends SubmitRestProtocolFieldCompanion[SubmitDriverRequestField] { case object ACTION extends SubmitDriverRequestField - case object SPARK_VERSION extends SubmitDriverRequestField + case object CLIENT_SPARK_VERSION extends SubmitDriverRequestField case object MESSAGE extends SubmitDriverRequestField case object MASTER extends SubmitDriverRequestField case object APP_NAME extends SubmitDriverRequestField @@ -51,7 +51,7 @@ private[spark] object SubmitDriverRequestField case object APP_ARGS extends SubmitDriverRequestField case object SPARK_PROPERTIES extends SubmitDriverRequestField case object ENVIRONMENT_VARIABLES extends SubmitDriverRequestField - override val requiredFields = Seq(ACTION, SPARK_VERSION, MASTER, APP_NAME, APP_RESOURCE) + override val requiredFields = Seq(ACTION, CLIENT_SPARK_VERSION, MASTER, APP_NAME, APP_RESOURCE) override val optionalFields = Seq(MESSAGE, MAIN_CLASS, JARS, FILES, PY_FILES, 
DRIVER_MEMORY, DRIVER_CORES, DRIVER_EXTRA_JAVA_OPTIONS, DRIVER_EXTRA_CLASS_PATH, DRIVER_EXTRA_LIBRARY_PATH, SUPERVISE_DRIVER, EXECUTOR_MEMORY, TOTAL_EXECUTOR_CORES, APP_ARGS, SPARK_PROPERTIES, diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala index 455170766037..7b3524b10f6c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala @@ -24,12 +24,12 @@ private[spark] abstract class SubmitDriverResponseField extends SubmitRestProtoc private[spark] object SubmitDriverResponseField extends SubmitRestProtocolFieldCompanion[SubmitDriverResponseField] { case object ACTION extends SubmitDriverResponseField - case object SPARK_VERSION extends SubmitDriverResponseField + case object SERVER_SPARK_VERSION extends SubmitDriverResponseField case object MESSAGE extends SubmitDriverResponseField case object MASTER extends SubmitDriverResponseField case object SUCCESS extends SubmitDriverResponseField case object DRIVER_ID extends SubmitDriverResponseField - override val requiredFields = Seq(ACTION, SPARK_VERSION, MESSAGE, MASTER, SUCCESS) + override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, MESSAGE, MASTER, SUCCESS) override val optionalFields = Seq(DRIVER_ID) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index 6419520743eb..2969886f9094 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -28,12 +28,16 @@ import org.apache.spark.util.Utils /** * A field used in a SubmitRestProtocolMessage. - * Three special fields ACTION, SPARK_VERSION, and MESSAGE are common across all messages. + * There are a few special fields: + * - ACTION entirely specifies the type of the message and is required in all messages + * - MESSAGE contains arbitrary messages and is common, but not required, in all messages + * - CLIENT_SPARK_VERSION is required in all messages sent from the client + * - SERVER_SPARK_VERSION is required in all messages sent from the server */ private[spark] abstract class SubmitRestProtocolField private[spark] object SubmitRestProtocolField { def isActionField(field: String): Boolean = field == "ACTION" - def isSparkVersionField(field: String): Boolean = field == "SPARK_VERSION" + def isSparkVersionField(field: String): Boolean = field.endsWith("_SPARK_VERSION") def isMessageField(field: String): Boolean = field == "MESSAGE" } @@ -129,7 +133,7 @@ private[spark] abstract class SubmitRestProtocolMessage( /** * Return a JObject that represents the JSON form of this message. - * This orders the fields by ACTION (first) < SPARK_VERSION < MESSAGE < * (last) + * This orders the fields by ACTION (first) < SERVER_SPARK_VERSION < MESSAGE < * (last) * and ignores fields with null values. 
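+ * For example, an error message with illustrative values serializes roughly as:
+ *   {"ACTION" : "ERROR", "SERVER_SPARK_VERSION" : "1.3.0", "MESSAGE" : "Malformed request"}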
*/ protected def toJsonObject: JObject = { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala index c659dfddbf6a..1cf02c0efd3b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala @@ -130,7 +130,7 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi private def handleError(message: String): ErrorMessage = { import ErrorField._ new ErrorMessage() - .setField(SPARK_VERSION, sparkVersion) + .setField(SERVER_SPARK_VERSION, sparkVersion) .setField(MESSAGE, message) } } From 9e21b7294ed6fd4cc62691f52ccf2665dc18536c Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 23 Jan 2015 11:39:48 -0800 Subject: [PATCH 10/48] Action -> SparkSubmitAction (minor) --- .../org/apache/spark/deploy/SparkSubmit.scala | 10 +++++----- .../spark/deploy/SparkSubmitArguments.scala | 16 ++++++++-------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 30b982822dba..842fec8ea952 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -31,8 +31,8 @@ import org.apache.spark.deploy.rest.StandaloneRestClient * Whether to submit, kill, or request the status of an application. * The latter two operations are currently supported only for standalone cluster mode. */ -private[spark] object Action extends Enumeration { - type Action = Value +private[spark] object SparkSubmitAction extends Enumeration { + type SparkSubmitAction = Value val SUBMIT, KILL, REQUEST_STATUS = Value } @@ -82,9 +82,9 @@ object SparkSubmit { printStream.println(appArgs) } appArgs.action match { - case Action.SUBMIT => submit(appArgs) - case Action.KILL => kill(appArgs) - case Action.REQUEST_STATUS => requestStatus(appArgs) + case SparkSubmitAction.SUBMIT => submit(appArgs) + case SparkSubmitAction.KILL => kill(appArgs) + case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 48f0ea39a095..bce20c2b92e7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -22,8 +22,8 @@ import java.util.jar.JarFile import scala.collection.mutable.{ArrayBuffer, HashMap} +import org.apache.spark.deploy.SparkSubmitAction.SparkSubmitAction import org.apache.spark.util.Utils -import org.apache.spark.deploy.Action.Action /** * Parses and encapsulates arguments from the spark-submit script. @@ -60,11 +60,11 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var driverToKill: String = null var driverToRequestStatusFor: String = null - def action: Action = { + def action: SparkSubmitAction = { (driverToKill, driverToRequestStatusFor) match { - case (null, null) => Action.SUBMIT - case (_, null) => Action.KILL - case (null, _) => Action.REQUEST_STATUS + case (null, null) => SparkSubmitAction.SUBMIT + case (_, null) => SparkSubmitAction.KILL + case (null, _) => SparkSubmitAction.REQUEST_STATUS case _ => SparkSubmit.printErrorAndExit( "Requested to both kill and request status for a driver. 
Choose only one.") null // never reached @@ -189,9 +189,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St /** Ensure that required fields exists. Call this only once all defaults are loaded. */ private def validateArguments(): Unit = { action match { - case Action.SUBMIT => validateSubmitArguments() - case Action.KILL => validateKillArguments() - case Action.REQUEST_STATUS => validateStatusRequestArguments() + case SparkSubmitAction.SUBMIT => validateSubmitArguments() + case SparkSubmitAction.KILL => validateKillArguments() + case SparkSubmitAction.REQUEST_STATUS => validateStatusRequestArguments() } } From 63c05b3f403626fe8ff4ce16a00047fb7335890c Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 23 Jan 2015 11:48:50 -0800 Subject: [PATCH 11/48] Remove MASTER as a field (minor) --- .../apache/spark/deploy/rest/DriverStatusRequestMessage.scala | 3 +-- .../apache/spark/deploy/rest/DriverStatusResponseMessage.scala | 3 +-- .../apache/spark/deploy/rest/KillDriverRequestMessage.scala | 3 +-- .../apache/spark/deploy/rest/KillDriverResponseMessage.scala | 3 +-- .../org/apache/spark/deploy/rest/StandaloneRestClient.scala | 3 --- .../org/apache/spark/deploy/rest/StandaloneRestServer.scala | 3 --- .../apache/spark/deploy/rest/SubmitDriverRequestMessage.scala | 3 +-- .../apache/spark/deploy/rest/SubmitDriverResponseMessage.scala | 3 +-- 8 files changed, 6 insertions(+), 18 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala index e6b513bd4b1c..bdb3c9399251 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala @@ -26,9 +26,8 @@ private[spark] object DriverStatusRequestField case object ACTION extends DriverStatusRequestField case object CLIENT_SPARK_VERSION extends DriverStatusRequestField case object MESSAGE extends DriverStatusRequestField - case object MASTER extends DriverStatusRequestField case object DRIVER_ID extends DriverStatusRequestField - override val requiredFields = Seq(ACTION, CLIENT_SPARK_VERSION, MASTER, DRIVER_ID) + override val requiredFields = Seq(ACTION, CLIENT_SPARK_VERSION, DRIVER_ID) override val optionalFields = Seq(MESSAGE) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala index 7a65f31c711f..f0315db4c007 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala @@ -26,14 +26,13 @@ private[spark] object DriverStatusResponseField case object ACTION extends DriverStatusResponseField case object SERVER_SPARK_VERSION extends DriverStatusResponseField case object MESSAGE extends DriverStatusResponseField - case object MASTER extends DriverStatusResponseField case object DRIVER_ID extends DriverStatusResponseField case object SUCCESS extends DriverStatusResponseField // Standalone specific fields case object DRIVER_STATE extends DriverStatusResponseField case object WORKER_ID extends DriverStatusResponseField case object WORKER_HOST_PORT extends DriverStatusResponseField - override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, MASTER, DRIVER_ID, SUCCESS) + override val requiredFields = Seq(ACTION, 
SERVER_SPARK_VERSION, DRIVER_ID, SUCCESS) override val optionalFields = Seq(MESSAGE, DRIVER_STATE, WORKER_ID, WORKER_HOST_PORT) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala index ae3c62d496b7..ee7d8e1f7bff 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala @@ -26,9 +26,8 @@ private[spark] object KillDriverRequestField case object ACTION extends KillDriverRequestField case object CLIENT_SPARK_VERSION extends KillDriverRequestField case object MESSAGE extends KillDriverRequestField - case object MASTER extends KillDriverRequestField case object DRIVER_ID extends KillDriverRequestField - override val requiredFields = Seq(ACTION, CLIENT_SPARK_VERSION, MASTER, DRIVER_ID) + override val requiredFields = Seq(ACTION, CLIENT_SPARK_VERSION, DRIVER_ID) override val optionalFields = Seq(MESSAGE) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala index b5dc4ee557cb..e6ab62980811 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala @@ -26,10 +26,9 @@ private[spark] object KillDriverResponseField case object ACTION extends KillDriverResponseField case object SERVER_SPARK_VERSION extends KillDriverResponseField case object MESSAGE extends KillDriverResponseField - case object MASTER extends KillDriverResponseField case object DRIVER_ID extends KillDriverResponseField case object SUCCESS extends KillDriverResponseField - override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, MESSAGE, MASTER, DRIVER_ID, SUCCESS) + override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, MESSAGE, DRIVER_ID, SUCCESS) override val optionalFields = Seq.empty } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index 4f5c31c080fb..278c9af749b1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -89,7 +89,6 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { val em = Option(args.executorMemory).map { m => Utils.memoryStringToMb(m).toString }.orNull val message = new SubmitDriverRequestMessage() .setField(CLIENT_SPARK_VERSION, sparkVersion) - .setField(MASTER, args.master) .setField(APP_NAME, args.name) .setField(APP_RESOURCE, args.primaryResource) .setFieldIfNotNull(MAIN_CLASS, args.mainClass) @@ -116,7 +115,6 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { import KillDriverRequestField._ new KillDriverRequestMessage() .setField(CLIENT_SPARK_VERSION, sparkVersion) - .setField(MASTER, master) .setField(DRIVER_ID, driverId) } @@ -127,7 +125,6 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { import DriverStatusRequestField._ new DriverStatusRequestMessage() .setField(CLIENT_SPARK_VERSION, sparkVersion) - .setField(MASTER, master) .setField(DRIVER_ID, driverId) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala 
b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 5a5afcc22833..4b347386397c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -64,7 +64,6 @@ private[spark] class StandaloneRestServerHandler( new SubmitDriverResponseMessage() .setField(SERVER_SPARK_VERSION, sparkVersion) .setField(MESSAGE, response.message) - .setField(MASTER, masterUrl) .setField(SUCCESS, response.success.toString) .setFieldIfNotNull(DRIVER_ID, response.driverId.orNull) } @@ -79,7 +78,6 @@ private[spark] class StandaloneRestServerHandler( new KillDriverResponseMessage() .setField(SERVER_SPARK_VERSION, sparkVersion) .setField(MESSAGE, response.message) - .setField(MASTER, masterUrl) .setField(DRIVER_ID, driverId) .setField(SUCCESS, response.success.toString) } @@ -98,7 +96,6 @@ private[spark] class StandaloneRestServerHandler( } new DriverStatusResponseMessage() .setField(SERVER_SPARK_VERSION, sparkVersion) - .setField(MASTER, masterUrl) .setField(DRIVER_ID, driverId) .setField(SUCCESS, response.found.toString) .setFieldIfNotNull(DRIVER_STATE, response.state.map(_.toString).orNull) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala index b4c29d171b73..d3ad9c14844a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala @@ -33,7 +33,6 @@ private[spark] object SubmitDriverRequestField case object ACTION extends SubmitDriverRequestField case object CLIENT_SPARK_VERSION extends SubmitDriverRequestField case object MESSAGE extends SubmitDriverRequestField - case object MASTER extends SubmitDriverRequestField case object APP_NAME extends SubmitDriverRequestField case object APP_RESOURCE extends SubmitDriverRequestField case object MAIN_CLASS extends SubmitDriverRequestField @@ -51,7 +50,7 @@ private[spark] object SubmitDriverRequestField case object APP_ARGS extends SubmitDriverRequestField case object SPARK_PROPERTIES extends SubmitDriverRequestField case object ENVIRONMENT_VARIABLES extends SubmitDriverRequestField - override val requiredFields = Seq(ACTION, CLIENT_SPARK_VERSION, MASTER, APP_NAME, APP_RESOURCE) + override val requiredFields = Seq(ACTION, CLIENT_SPARK_VERSION, APP_NAME, APP_RESOURCE) override val optionalFields = Seq(MESSAGE, MAIN_CLASS, JARS, FILES, PY_FILES, DRIVER_MEMORY, DRIVER_CORES, DRIVER_EXTRA_JAVA_OPTIONS, DRIVER_EXTRA_CLASS_PATH, DRIVER_EXTRA_LIBRARY_PATH, SUPERVISE_DRIVER, EXECUTOR_MEMORY, TOTAL_EXECUTOR_CORES, APP_ARGS, SPARK_PROPERTIES, diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala index 7b3524b10f6c..938ff1e32a75 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala @@ -26,10 +26,9 @@ private[spark] object SubmitDriverResponseField case object ACTION extends SubmitDriverResponseField case object SERVER_SPARK_VERSION extends SubmitDriverResponseField case object MESSAGE extends SubmitDriverResponseField - case object MASTER extends SubmitDriverResponseField case object SUCCESS extends SubmitDriverResponseField case object DRIVER_ID 
extends SubmitDriverResponseField - override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, MESSAGE, MASTER, SUCCESS) + override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, MESSAGE, SUCCESS) override val optionalFields = Seq(DRIVER_ID) } From 206cae46093cab19e9f066388c6f848df3ef5391 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 26 Jan 2015 16:59:11 -0800 Subject: [PATCH 12/48] Refactor and add tests for the REST protocol This commit does the following major things: (1) Refactor SparkSubmit such that SparkSubmitSuite now passes (2) Refactor the REST messages such that it's easier to test them (3) Add type-safety validation logic for REST fields (4) Move REST fields to its own file (5) Maintain ordering of fields added to REST messages (6) Add an option to disable the REST server, as we do in tests --- .../org/apache/spark/deploy/SparkSubmit.scala | 57 ++- .../apache/spark/deploy/master/Master.scala | 5 +- .../rest/DriverStatusRequestMessage.scala | 2 +- .../rest/DriverStatusResponseMessage.scala | 4 +- .../spark/deploy/rest/ErrorMessage.scala | 2 +- .../rest/KillDriverRequestMessage.scala | 2 +- .../rest/KillDriverResponseMessage.scala | 8 +- .../rest/SubmitDriverRequestMessage.scala | 51 ++- .../rest/SubmitDriverResponseMessage.scala | 4 +- .../deploy/rest/SubmitRestProtocolField.scala | 121 ++++++ .../rest/SubmitRestProtocolMessage.scala | 87 +--- .../spark/deploy/SparkSubmitSuite.scala | 53 +-- .../deploy/rest/SubmitRestProtocolSuite.scala | 411 ++++++++++++++++++ pom.xml | 1 + project/SparkBuild.scala | 1 + 15 files changed, 648 insertions(+), 161 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala create mode 100644 core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 842fec8ea952..6f98609caf2f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -91,7 +91,7 @@ object SparkSubmit { /** * Kill an existing driver using the stable REST protocol. Standalone cluster mode only. */ - private[spark] def kill(args: SparkSubmitArguments): Unit = { + private def kill(args: SparkSubmitArguments): Unit = { new StandaloneRestClient().killDriver(args.master, args.driverToKill) } @@ -99,7 +99,7 @@ object SparkSubmit { * Request the status of an existing driver using the stable REST protocol. * Standalone cluster mode only. */ - private[spark] def requestStatus(args: SparkSubmitArguments): Unit = { + private def requestStatus(args: SparkSubmitArguments): Unit = { new StandaloneRestClient().requestDriverStatus(args.master, args.driverToRequestStatusFor) } @@ -116,7 +116,36 @@ object SparkSubmit { * main method of a child class. Instead, we pass the submit parameters directly to * a REST client, which will submit the application using the stable REST protocol. */ - private[spark] def submit(args: SparkSubmitArguments): Unit = { + private def submit(args: SparkSubmitArguments): Unit = { + val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args) + val isStandaloneCluster = args.master.startsWith("spark://") && args.deployMode == "cluster" + // In standalone cluster mode, use the stable application submission REST protocol. + // Otherwise, just call the main method of the child class. 
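+    // Illustrative example: a submission such as
+    //   spark-submit --master spark://host:7077 --deploy-mode cluster --class Foo app.jar
+    // takes the REST branch below, while all other master/deploy-mode combinations
+    // fall through to runMain() as before.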
+ if (isStandaloneCluster) { + // NOTE: since we mutate the values of some configs in `prepareSubmitEnvironment`, we + // must update the corresponding fields in the original SparkSubmitArguments to reflect + // these changes. + args.sparkProperties.clear() + args.sparkProperties ++= sysProps + sysProps.get("spark.jars").foreach { args.jars = _ } + sysProps.get("spark.files").foreach { args.files = _ } + new StandaloneRestClient().submitDriver(args) + } else { + runMain(childArgs, childClasspath, sysProps, childMainClass) + } + } + + /** + * Prepare the environment for submitting an application. + * This returns a 4-tuple: + * (1) the arguments for the child process, + * (2) a list of classpath entries for the child, + * (3) a list of system properties and env vars, and + * (4) the main class for the child + * Exposed for testing. + */ + private[spark] def prepareSubmitEnvironment(args: SparkSubmitArguments) + : (Seq[String], Seq[String], Map[String, String], String) = { // Environment needed to launch the child main class val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() @@ -247,9 +276,6 @@ object SparkSubmit { sysProp = "spark.files") ) - val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER - val isStandaloneCluster = clusterManager == STANDALONE && deployMode == CLUSTER - // In client mode, launch the application main class directly // In addition, add the main application jar and any added jars (if any) to the classpath if (deployMode == CLIENT) { @@ -274,6 +300,7 @@ object SparkSubmit { // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" // For python files, the primary resource is already distributed as a regular file + val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER if (!isYarnCluster && !args.isPython) { var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { @@ -327,19 +354,7 @@ object SparkSubmit { sysProps("spark.submit.pyFiles") = formattedPyFiles } - // In standalone cluster mode, use the stable application submission REST protocol. - // Otherwise, just call the main method of the child class. - if (isStandaloneCluster) { - // NOTE: since we mutate the values of some configs in this method, we must update the - // corresponding fields in the original SparkSubmitArguments to reflect these changes. - args.sparkProperties.clear() - args.sparkProperties ++= sysProps - sysProps.get("spark.jars").foreach { args.jars = _ } - sysProps.get("spark.files").foreach { args.files = _ } - new StandaloneRestClient().submitDriver(args) - } else { - runMain(childArgs, childClasspath, sysProps, childMainClass) - } + (childArgs, childClasspath, sysProps, childMainClass) } /** @@ -349,8 +364,8 @@ object SparkSubmit { * this main class may not necessarily be the one provided by the user. 
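+ * (For example, in YARN cluster mode the child main class is org.apache.spark.deploy.yarn.Client
+ * rather than the class the user provided.)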
*/ private def runMain( - childArgs: ArrayBuffer[String], - childClasspath: ArrayBuffer[String], + childArgs: Seq[String], + childClasspath: Seq[String], sysProps: Map[String, String], childMainClass: String, verbose: Boolean = false) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 210b3802dc34..5ffdbe126689 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -123,9 +123,12 @@ private[spark] class Master( } // Alternative application submission gateway that is stable across Spark versions + private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true) private val restServerPort = conf.getInt("spark.master.rest.port", 17077) private val restServer = new StandaloneRestServer(this, host, restServerPort) - restServer.start() + if (restServerEnabled) { + restServer.start() + } override def preStart() { logInfo("Starting Spark master at " + masterUrl) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala index bdb3c9399251..f0d0c5f874d5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala @@ -23,7 +23,7 @@ package org.apache.spark.deploy.rest private[spark] abstract class DriverStatusRequestField extends SubmitRestProtocolField private[spark] object DriverStatusRequestField extends SubmitRestProtocolFieldCompanion[DriverStatusRequestField] { - case object ACTION extends DriverStatusRequestField + case object ACTION extends DriverStatusRequestField with ActionField case object CLIENT_SPARK_VERSION extends DriverStatusRequestField case object MESSAGE extends DriverStatusRequestField case object DRIVER_ID extends DriverStatusRequestField diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala index f0315db4c007..d65145248505 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala @@ -23,11 +23,11 @@ package org.apache.spark.deploy.rest private[spark] abstract class DriverStatusResponseField extends SubmitRestProtocolField private[spark] object DriverStatusResponseField extends SubmitRestProtocolFieldCompanion[DriverStatusResponseField] { - case object ACTION extends DriverStatusResponseField + case object ACTION extends DriverStatusResponseField with ActionField case object SERVER_SPARK_VERSION extends DriverStatusResponseField case object MESSAGE extends DriverStatusResponseField case object DRIVER_ID extends DriverStatusResponseField - case object SUCCESS extends DriverStatusResponseField + case object SUCCESS extends DriverStatusResponseField with BooleanField // Standalone specific fields case object DRIVER_STATE extends DriverStatusResponseField case object WORKER_ID extends DriverStatusResponseField diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala index 88d33462dd44..f1fbdd227507 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala @@ -22,7 +22,7 @@ package org.apache.spark.deploy.rest */ private[spark] abstract class ErrorField extends SubmitRestProtocolField private[spark] object ErrorField extends SubmitRestProtocolFieldCompanion[ErrorField] { - case object ACTION extends ErrorField + case object ACTION extends ErrorField with ActionField case object SERVER_SPARK_VERSION extends ErrorField case object MESSAGE extends ErrorField override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, MESSAGE) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala index ee7d8e1f7bff..232bb364e889 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala @@ -23,7 +23,7 @@ package org.apache.spark.deploy.rest private[spark] abstract class KillDriverRequestField extends SubmitRestProtocolField private[spark] object KillDriverRequestField extends SubmitRestProtocolFieldCompanion[KillDriverRequestField] { - case object ACTION extends KillDriverRequestField + case object ACTION extends KillDriverRequestField with ActionField case object CLIENT_SPARK_VERSION extends KillDriverRequestField case object MESSAGE extends KillDriverRequestField case object DRIVER_ID extends KillDriverRequestField diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala index e6ab62980811..0717131ab2ec 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala @@ -23,13 +23,13 @@ package org.apache.spark.deploy.rest private[spark] abstract class KillDriverResponseField extends SubmitRestProtocolField private[spark] object KillDriverResponseField extends SubmitRestProtocolFieldCompanion[KillDriverResponseField] { - case object ACTION extends KillDriverResponseField + case object ACTION extends KillDriverResponseField with ActionField case object SERVER_SPARK_VERSION extends KillDriverResponseField case object MESSAGE extends KillDriverResponseField case object DRIVER_ID extends KillDriverResponseField - case object SUCCESS extends KillDriverResponseField - override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, MESSAGE, DRIVER_ID, SUCCESS) - override val optionalFields = Seq.empty + case object SUCCESS extends KillDriverResponseField with BooleanField + override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, DRIVER_ID, SUCCESS) + override val optionalFields = Seq(MESSAGE) } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala index d3ad9c14844a..90d7e408fefc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala @@ -30,7 +30,7 @@ import org.apache.spark.util.JsonProtocol private[spark] abstract class SubmitDriverRequestField extends SubmitRestProtocolField private[spark] object SubmitDriverRequestField extends SubmitRestProtocolFieldCompanion[SubmitDriverRequestField] { - case object ACTION extends SubmitDriverRequestField + case object ACTION 
extends SubmitDriverRequestField with ActionField case object CLIENT_SPARK_VERSION extends SubmitDriverRequestField case object MESSAGE extends SubmitDriverRequestField case object APP_NAME extends SubmitDriverRequestField @@ -39,17 +39,32 @@ private[spark] object SubmitDriverRequestField case object JARS extends SubmitDriverRequestField case object FILES extends SubmitDriverRequestField case object PY_FILES extends SubmitDriverRequestField - case object DRIVER_MEMORY extends SubmitDriverRequestField - case object DRIVER_CORES extends SubmitDriverRequestField + case object DRIVER_MEMORY extends SubmitDriverRequestField with MemoryField + case object DRIVER_CORES extends SubmitDriverRequestField with NumericField case object DRIVER_EXTRA_JAVA_OPTIONS extends SubmitDriverRequestField case object DRIVER_EXTRA_CLASS_PATH extends SubmitDriverRequestField case object DRIVER_EXTRA_LIBRARY_PATH extends SubmitDriverRequestField - case object SUPERVISE_DRIVER extends SubmitDriverRequestField // standalone cluster mode only - case object EXECUTOR_MEMORY extends SubmitDriverRequestField - case object TOTAL_EXECUTOR_CORES extends SubmitDriverRequestField - case object APP_ARGS extends SubmitDriverRequestField - case object SPARK_PROPERTIES extends SubmitDriverRequestField - case object ENVIRONMENT_VARIABLES extends SubmitDriverRequestField + case object SUPERVISE_DRIVER extends SubmitDriverRequestField with BooleanField + case object EXECUTOR_MEMORY extends SubmitDriverRequestField with MemoryField + case object TOTAL_EXECUTOR_CORES extends SubmitDriverRequestField with NumericField + + // Special fields that should not be set directly + case object APP_ARGS extends SubmitDriverRequestField { + override def validateValue(v: String): Unit = { + validateFailed(v, "Use message.appendAppArg(arg) instead") + } + } + case object SPARK_PROPERTIES extends SubmitDriverRequestField { + override def validateValue(v: String): Unit = { + validateFailed(v, "Use message.setSparkProperty(k, v) instead") + } + } + case object ENVIRONMENT_VARIABLES extends SubmitDriverRequestField { + override def validateValue(v: String): Unit = { + validateFailed(v, "Use message.setEnvironmentVariable(k, v) instead") + } + } + override val requiredFields = Seq(ACTION, CLIENT_SPARK_VERSION, APP_NAME, APP_RESOURCE) override val optionalFields = Seq(MESSAGE, MAIN_CLASS, JARS, FILES, PY_FILES, DRIVER_MEMORY, DRIVER_CORES, DRIVER_EXTRA_JAVA_OPTIONS, DRIVER_EXTRA_CLASS_PATH, DRIVER_EXTRA_LIBRARY_PATH, @@ -89,12 +104,18 @@ private[spark] class SubmitDriverRequestMessage extends SubmitRestProtocolMessag val appArgsJson = JArray(appArgs.map(JString).toList) val sparkPropertiesJson = JsonProtocol.mapToJson(sparkProperties) val environmentVariablesJson = JsonProtocol.mapToJson(environmentVariables) - val allFields = otherFields ++ List( - (APP_ARGS.toString, appArgsJson), - (SPARK_PROPERTIES.toString, sparkPropertiesJson), - (ENVIRONMENT_VARIABLES.toString, environmentVariablesJson) - ) - JObject(allFields) + val jsonFields = new ArrayBuffer[JField] + jsonFields ++= otherFields + if (appArgs.nonEmpty) { + jsonFields += JField(APP_ARGS.toString, appArgsJson) + } + if (sparkProperties.nonEmpty) { + jsonFields += JField(SPARK_PROPERTIES.toString, sparkPropertiesJson) + } + if (environmentVariables.nonEmpty) { + jsonFields += JField(ENVIRONMENT_VARIABLES.toString, environmentVariablesJson) + } + JObject(jsonFields.toList) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala 
b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala index 938ff1e32a75..d5a2e1660eb0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala @@ -23,10 +23,10 @@ package org.apache.spark.deploy.rest private[spark] abstract class SubmitDriverResponseField extends SubmitRestProtocolField private[spark] object SubmitDriverResponseField extends SubmitRestProtocolFieldCompanion[SubmitDriverResponseField] { - case object ACTION extends SubmitDriverResponseField + case object ACTION extends SubmitDriverResponseField with ActionField case object SERVER_SPARK_VERSION extends SubmitDriverResponseField case object MESSAGE extends SubmitDriverResponseField - case object SUCCESS extends SubmitDriverResponseField + case object SUCCESS extends SubmitDriverResponseField with BooleanField case object DRIVER_ID extends SubmitDriverResponseField override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, MESSAGE, SUCCESS) override val optionalFields = Seq(DRIVER_ID) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala new file mode 100644 index 000000000000..639e00d912e7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import scala.collection.Map +import scala.util.Try + +import org.apache.spark.util.Utils + +/** + * A field used in a SubmitRestProtocolMessage. + * There are a few special fields: + * - ACTION entirely specifies the type of the message and is required in all messages + * - MESSAGE contains arbitrary messages and is common, but not required, in all messages + * - CLIENT_SPARK_VERSION is required in all messages sent from the client + * - SERVER_SPARK_VERSION is required in all messages sent from the server + */ +private[spark] abstract class SubmitRestProtocolField { + protected val fieldName = Utils.getFormattedClassName(this) + def validateValue(value: String): Unit = { } + def validateFailed(v: String, msg: String): Unit = { + throw new IllegalArgumentException(s"Detected setting of $fieldName to $v: $msg") + } +} +private[spark] object SubmitRestProtocolField { + def isActionField(field: String): Boolean = field == "ACTION" +} + +/** A field that should accept only boolean values. 
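+ * (for example, the SUCCESS field, whose value must parse via String#toBoolean)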
*/ +private[spark] trait BooleanField extends SubmitRestProtocolField { + override def validateValue(v: String): Unit = { + Try(v.toBoolean).getOrElse { validateFailed(v, s"Error parsing $v as a boolean!") } + } +} + +/** A field that should accept only numeric values. */ +private[spark] trait NumericField extends SubmitRestProtocolField { + override def validateValue(v: String): Unit = { + Try(v.toInt).getOrElse { validateFailed(v, s"Error parsing $v as an integer!") } + } +} + +/** A field that should accept only memory values. */ +private[spark] trait MemoryField extends SubmitRestProtocolField { + override def validateValue(v: String): Unit = { + Try(Utils.memoryStringToMb(v)).getOrElse { + validateFailed(v, s"Error parsing $v as a memory string!") + } + } +} + +/** + * The main action field in every message. + * This should be set only on message instantiation. + */ +private[spark] trait ActionField extends SubmitRestProtocolField { + override def validateValue(v: String): Unit = { + validateFailed(v, "The ACTION field must not be set directly after instantiation.") + } +} + +/** + * All possible values of the ACTION field in a SubmitRestProtocolMessage. + */ +private[spark] abstract class SubmitRestProtocolAction +private[spark] object SubmitRestProtocolAction { + case object SUBMIT_DRIVER_REQUEST extends SubmitRestProtocolAction + case object SUBMIT_DRIVER_RESPONSE extends SubmitRestProtocolAction + case object KILL_DRIVER_REQUEST extends SubmitRestProtocolAction + case object KILL_DRIVER_RESPONSE extends SubmitRestProtocolAction + case object DRIVER_STATUS_REQUEST extends SubmitRestProtocolAction + case object DRIVER_STATUS_RESPONSE extends SubmitRestProtocolAction + case object ERROR extends SubmitRestProtocolAction + private val allActions = + Seq(SUBMIT_DRIVER_REQUEST, SUBMIT_DRIVER_RESPONSE, KILL_DRIVER_REQUEST, + KILL_DRIVER_RESPONSE, DRIVER_STATUS_REQUEST, DRIVER_STATUS_RESPONSE, ERROR) + private val allActionsMap = allActions.map { a => (a.toString, a) }.toMap + + def fromString(action: String): SubmitRestProtocolAction = { + allActionsMap.get(action).getOrElse { + throw new IllegalArgumentException(s"Unknown action $action") + } + } +} + +/** + * Common methods used by companion objects of SubmitRestProtocolField's subclasses. + * This keeps track of all fields that belong to this object in order to reconstruct + * the fields from their names. + */ +private[spark] trait SubmitRestProtocolFieldCompanion[FieldType <: SubmitRestProtocolField] { + val requiredFields: Seq[FieldType] + val optionalFields: Seq[FieldType] + + // Listing of all fields indexed by the field's string representation + private lazy val allFieldsMap: Map[String, FieldType] = { + (requiredFields ++ optionalFields).map { f => (f.toString, f) }.toMap + } + + /** Return the appropriate SubmitRestProtocolField from its string representation. 
*/ + def fromString(field: String): FieldType = { + allFieldsMap.get(field).getOrElse { + throw new IllegalArgumentException(s"Unknown field $field") + } + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index 2969886f9094..7899668ac526 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -18,41 +18,14 @@ package org.apache.spark.deploy.rest import scala.collection.Map -import scala.collection.mutable +import scala.collection.JavaConversions._ import org.json4s.jackson.JsonMethods._ import org.json4s.JsonAST._ -import org.apache.spark.{Logging, SparkException} +import org.apache.spark.Logging import org.apache.spark.util.Utils -/** - * A field used in a SubmitRestProtocolMessage. - * There are a few special fields: - * - ACTION entirely specifies the type of the message and is required in all messages - * - MESSAGE contains arbitrary messages and is common, but not required, in all messages - * - CLIENT_SPARK_VERSION is required in all messages sent from the client - * - SERVER_SPARK_VERSION is required in all messages sent from the server - */ -private[spark] abstract class SubmitRestProtocolField -private[spark] object SubmitRestProtocolField { - def isActionField(field: String): Boolean = field == "ACTION" - def isSparkVersionField(field: String): Boolean = field.endsWith("_SPARK_VERSION") - def isMessageField(field: String): Boolean = field == "MESSAGE" -} - -/** - * All possible values of the ACTION field in a SubmitRestProtocolMessage. - */ -private[spark] object SubmitRestProtocolAction extends Enumeration { - type SubmitRestProtocolAction = Value - val SUBMIT_DRIVER_REQUEST, SUBMIT_DRIVER_RESPONSE = Value - val KILL_DRIVER_REQUEST, KILL_DRIVER_RESPONSE = Value - val DRIVER_STATUS_REQUEST, DRIVER_STATUS_RESPONSE = Value - val ERROR = Value -} -import SubmitRestProtocolAction.SubmitRestProtocolAction - /** * A general message exchanged in the stable application submission REST protocol. * @@ -63,19 +36,18 @@ import SubmitRestProtocolAction.SubmitRestProtocolAction */ private[spark] abstract class SubmitRestProtocolMessage( action: SubmitRestProtocolAction, - actionField: SubmitRestProtocolField, + actionField: ActionField, requiredFields: Seq[SubmitRestProtocolField]) { - import SubmitRestProtocolField._ - - private val fields = new mutable.HashMap[SubmitRestProtocolField, String] + // Maintain the insert order for converting to JSON later + private val fields = new java.util.LinkedHashMap[SubmitRestProtocolField, String] val className = Utils.getFormattedClassName(this) // Set the action field - fields(actionField) = action.toString + fields.put(actionField, action.toString) /** Return all fields currently set in this message. */ - def getFields: Map[SubmitRestProtocolField, String] = fields + def getFields: Map[SubmitRestProtocolField, String] = fields.toMap /** Return the value of the given field. If the field is not present, return null. */ def getField(key: SubmitRestProtocolField): String = getFieldOption(key).orNull @@ -88,14 +60,12 @@ private[spark] abstract class SubmitRestProtocolMessage( } /** Return the value of the given field as an option. 
*/ - def getFieldOption(key: SubmitRestProtocolField): Option[String] = fields.get(key) + def getFieldOption(key: SubmitRestProtocolField): Option[String] = Option(fields.get(key)) /** Assign the given value to the field, overriding any existing value. */ def setField(key: SubmitRestProtocolField, value: String): this.type = { - if (key == actionField) { - throw new SparkException("Setting the ACTION field is only allowed during instantiation.") - } - fields(key) = value + key.validateValue(value) + fields.put(key, value) this } @@ -133,19 +103,10 @@ private[spark] abstract class SubmitRestProtocolMessage( /** * Return a JObject that represents the JSON form of this message. - * This orders the fields by ACTION (first) < SERVER_SPARK_VERSION < MESSAGE < * (last) - * and ignores fields with null values. + * This ignores fields with null values. */ protected def toJsonObject: JObject = { - val sortedFields = fields.toSeq.sortBy { case (k, _) => - k.toString match { - case x if isActionField(x) => 0 - case x if isSparkVersionField(x) => 1 - case x if isMessageField(x) => 2 - case _ => 3 - } - } - val jsonFields = sortedFields + val jsonFields = fields.toSeq .filter { case (_, v) => v != null } .map { case (k, v) => JField(k.toString, JString(v)) } .toList @@ -167,7 +128,7 @@ private[spark] object SubmitRestProtocolMessage { val action = getAction(jsonObject).getOrElse { throw new IllegalArgumentException(s"ACTION not found in message:\n$json") } - SubmitRestProtocolAction.withName(action) match { + SubmitRestProtocolAction.fromString(action) match { case SUBMIT_DRIVER_REQUEST => SubmitDriverRequestMessage.fromJsonObject(jsonObject) case SUBMIT_DRIVER_RESPONSE => SubmitDriverResponseMessage.fromJsonObject(jsonObject) case KILL_DRIVER_REQUEST => KillDriverRequestMessage.fromJsonObject(jsonObject) @@ -188,28 +149,6 @@ private[spark] object SubmitRestProtocolMessage { } } -/** - * Common methods used by companion objects of SubmitRestProtocolField's subclasses. - * This keeps track of all fields that belong to this object in order to reconstruct - * the fields from their names. - */ -private[spark] trait SubmitRestProtocolFieldCompanion[FieldType <: SubmitRestProtocolField] { - val requiredFields: Seq[FieldType] - val optionalFields: Seq[FieldType] - - // Listing of all fields indexed by the field's string representation - private lazy val allFieldsMap: Map[String, FieldType] = { - (requiredFields ++ optionalFields).map { f => (f.toString, f) }.toMap - } - - /** Return the appropriate SubmitRestProtocolField from its string representation. */ - def fromString(field: String): FieldType = { - allFieldsMap.get(field).getOrElse { - throw new IllegalArgumentException(s"Unknown field $field") - } - } -} - /** * Common methods used by companion objects of SubmitRestProtocolMessage's subclasses. 
*/ diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 065b7534cece..044e0699968e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -138,7 +138,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") childArgsStr should include ("--class org.SomeClass") childArgsStr should include ("--executor-memory 5g") @@ -177,7 +177,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (4) @@ -197,33 +197,6 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties sysProps("spark.shuffle.spill") should be ("false") } - test("handles standalone cluster mode") { - val clArgs = Seq( - "--deploy-mode", "cluster", - "--master", "spark://h:p", - "--class", "org.SomeClass", - "--supervise", - "--driver-memory", "4g", - "--driver-cores", "5", - "--conf", "spark.shuffle.spill=false", - "thejar.jar", - "arg1", "arg2") - val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) - val childArgsStr = childArgs.mkString(" ") - childArgsStr should startWith ("--memory 4g --cores 5 --supervise") - childArgsStr should include regex ("launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2") - mainClass should be ("org.apache.spark.deploy.Client") - classpath should have size (0) - sysProps should have size (5) - sysProps.keys should contain ("SPARK_SUBMIT") - sysProps.keys should contain ("spark.master") - sysProps.keys should contain ("spark.app.name") - sysProps.keys should contain ("spark.jars") - sysProps.keys should contain ("spark.shuffle.spill") - sysProps("spark.shuffle.spill") should be ("false") - } - test("handles standalone client mode") { val clArgs = Seq( "--deploy-mode", "client", @@ -236,7 +209,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -258,7 +231,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs) + val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) @@ -278,7 +251,7 @@ class 
SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, sysProps, mainClass) = createLaunchEnv(appArgs) + val (_, _, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) sysProps("spark.executor.memory") should be ("5g") sysProps("spark.master") should be ("yarn-cluster") mainClass should be ("org.apache.spark.deploy.yarn.Client") @@ -291,6 +264,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--name", "testApp", "--master", "local", "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -306,6 +280,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--master", "local-cluster[2,1,512]", "--jars", jarsString, "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -324,7 +299,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--files", files, "thejar.jar") val appArgs = new SparkSubmitArguments(clArgs) - val sysProps = SparkSubmit.createLaunchEnv(appArgs)._3 + val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 appArgs.jars should be (Utils.resolveURIs(jars)) appArgs.files should be (Utils.resolveURIs(files)) sysProps("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) @@ -339,7 +314,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val sysProps2 = SparkSubmit.createLaunchEnv(appArgs2)._3 + val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3 appArgs2.files should be (Utils.resolveURIs(files)) appArgs2.archives should be (Utils.resolveURIs(archives)) sysProps2("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) @@ -352,7 +327,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val sysProps3 = SparkSubmit.createLaunchEnv(appArgs3)._3 + val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles)) sysProps3("spark.submit.pyFiles") should be ( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) @@ -377,7 +352,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val sysProps = SparkSubmit.createLaunchEnv(appArgs)._3 + val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 sysProps("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) sysProps("spark.files") should be(Utils.resolveURIs(files)) @@ -394,7 +369,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val sysProps2 = SparkSubmit.createLaunchEnv(appArgs2)._3 + val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3 sysProps2("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) sysProps2("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) @@ -409,7 +384,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val sysProps3 = SparkSubmit.createLaunchEnv(appArgs3)._3 + val sysProps3 = 
SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 sysProps3("spark.submit.pyFiles") should be( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) } @@ -425,7 +400,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties val appArgs = new SparkSubmitArguments(args, Map("SPARK_CONF_DIR" -> path)) assert(appArgs.propertiesFile != null) assert(appArgs.propertiesFile.startsWith(path)) - appArgs.executorMemory should be ("2.3g") + appArgs.executorMemory should be ("2.3g") } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala new file mode 100644 index 000000000000..18091e98c0b2 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -0,0 +1,411 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ +import org.scalatest.FunSuite + +/** + * Dummy fields and messages for testing. + */ +private abstract class DummyField extends SubmitRestProtocolField +private object DummyField extends SubmitRestProtocolFieldCompanion[DummyField] { + case object ACTION extends DummyField with ActionField + case object DUMMY_FIELD extends DummyField + case object BOOLEAN_FIELD extends DummyField with BooleanField + case object MEMORY_FIELD extends DummyField with MemoryField + case object NUMERIC_FIELD extends DummyField with NumericField + case object REQUIRED_FIELD extends DummyField + override val requiredFields = Seq(ACTION, REQUIRED_FIELD) + override val optionalFields = Seq(DUMMY_FIELD, BOOLEAN_FIELD, MEMORY_FIELD, NUMERIC_FIELD) +} +private object DUMMY_ACTION extends SubmitRestProtocolAction { + override def toString: String = "DUMMY_ACTION" +} +private class DummyMessage extends SubmitRestProtocolMessage( + DUMMY_ACTION, + DummyField.ACTION, + DummyField.requiredFields) +private object DummyMessage extends SubmitRestProtocolMessageCompanion[DummyMessage] { + protected override def newMessage() = new DummyMessage + protected override def fieldFromString(f: String) = DummyField.fromString(f) +} + + +/** + * Tests for the stable application submission REST protocol. + */ +class SubmitRestProtocolSuite extends FunSuite { + + /** Assert that the contents in the two JSON strings are equal after ignoring whitespace. 
*/ + private def assertJsonEquals(jsonString1: String, jsonString2: String): Unit = { + val trimmedJson1 = jsonString1.trim + val trimmedJson2 = jsonString2.trim + val json1 = compact(render(parse(trimmedJson1))) + val json2 = compact(render(parse(trimmedJson2))) + // Put this on a separate line to avoid printing comparison twice when test fails + val equals = json1 == json2 + assert(equals, "\"[%s]\" did not equal \"[%s]\"".format(trimmedJson1, trimmedJson2)) + } + + test("get and set fields") { + import DummyField._ + val message = new DummyMessage + // action field is already set on instantiation + assert(message.getFields.size === 1) + assert(message.getField(ACTION) === DUMMY_ACTION.toString) + // required field not set yet + intercept[IllegalArgumentException] { message.validate() } + intercept[IllegalArgumentException] { message.getFieldNotNull(DUMMY_FIELD) } + intercept[IllegalArgumentException] { message.getFieldNotNull(REQUIRED_FIELD) } + message.setField(DUMMY_FIELD, "dummy value") + message.setField(BOOLEAN_FIELD, "true") + message.setField(MEMORY_FIELD, "401k") + message.setField(NUMERIC_FIELD, "401") + message.setFieldIfNotNull(REQUIRED_FIELD, null) // no-op because value is null + assert(message.getFields.size === 5) + // required field still not set + intercept[IllegalArgumentException] { message.validate() } + intercept[IllegalArgumentException] { message.getFieldNotNull(REQUIRED_FIELD) } + message.setFieldIfNotNull(REQUIRED_FIELD, "dummy value") + // all required fields are now set + assert(message.getFields.size === 6) + assert(message.getField(DUMMY_FIELD) === "dummy value") + assert(message.getField(BOOLEAN_FIELD) === "true") + assert(message.getField(MEMORY_FIELD) === "401k") + assert(message.getField(NUMERIC_FIELD) === "401") + assert(message.getField(REQUIRED_FIELD) === "dummy value") + message.validate() + // bad field values + intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } + intercept[IllegalArgumentException] { message.setField(BOOLEAN_FIELD, "not T nor F") } + intercept[IllegalArgumentException] { message.setField(MEMORY_FIELD, "not memory") } + intercept[IllegalArgumentException] { message.setField(NUMERIC_FIELD, "not a number") } + } + + test("to and from JSON") { + import DummyField._ + val message = new DummyMessage() + .setField(DUMMY_FIELD, "dummy value") + .setField(BOOLEAN_FIELD, "true") + .setField(MEMORY_FIELD, "401k") + .setField(NUMERIC_FIELD, "401") + .setField(REQUIRED_FIELD, "dummy value") + .validate() + val expectedJson = + """ + |{ + | "ACTION" : "DUMMY_ACTION", + | "DUMMY_FIELD" : "dummy value", + | "BOOLEAN_FIELD" : "true", + | "MEMORY_FIELD" : "401k", + | "NUMERIC_FIELD" : "401", + | "REQUIRED_FIELD" : "dummy value" + |} + """.stripMargin + val actualJson = message.toJson + assertJsonEquals(actualJson, expectedJson) + // Do not use SubmitRestProtocolMessage.fromJson here + // because DUMMY_ACTION is not a known action + val jsonObject = parse(expectedJson).asInstanceOf[JObject] + val newMessage = DummyMessage.fromJsonObject(jsonObject) + assert(newMessage.getFieldNotNull(ACTION) === "DUMMY_ACTION") + assert(newMessage.getFieldNotNull(DUMMY_FIELD) === "dummy value") + assert(newMessage.getFieldNotNull(BOOLEAN_FIELD) === "true") + assert(newMessage.getFieldNotNull(MEMORY_FIELD) === "401k") + assert(newMessage.getFieldNotNull(NUMERIC_FIELD) === "401") + assert(newMessage.getFieldNotNull(REQUIRED_FIELD) === "dummy value") + assert(newMessage.getFields.size === 6) + } + + test("SubmitDriverRequestMessage") { + import 
SubmitDriverRequestField._ + val message = new SubmitDriverRequestMessage + intercept[IllegalArgumentException] { message.validate() } + message.setField(CLIENT_SPARK_VERSION, "1.2.3") + message.setField(MESSAGE, "Submitting them drivers.") + message.setField(APP_NAME, "SparkPie") + message.setField(APP_RESOURCE, "honey-walnut-cherry.jar") + // all required fields are now set + message.validate() + message.setField(MAIN_CLASS, "org.apache.spark.examples.SparkPie") + message.setField(JARS, "mayonnaise.jar,ketchup.jar") + message.setField(FILES, "fireball.png") + message.setField(PY_FILES, "do-not-eat-my.py") + message.setField(DRIVER_MEMORY, "512m") + message.setField(DRIVER_CORES, "180") + message.setField(DRIVER_EXTRA_JAVA_OPTIONS, " -Dslices=5 -Dcolor=mostly_red") + message.setField(DRIVER_EXTRA_CLASS_PATH, "food-coloring.jar") + message.setField(DRIVER_EXTRA_LIBRARY_PATH, "pickle.jar") + message.setField(SUPERVISE_DRIVER, "false") + message.setField(EXECUTOR_MEMORY, "256m") + message.setField(TOTAL_EXECUTOR_CORES, "10000") + // bad field values + intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } + intercept[IllegalArgumentException] { message.setField(DRIVER_MEMORY, "more than expected") } + intercept[IllegalArgumentException] { message.setField(DRIVER_CORES, "one hundred feet") } + intercept[IllegalArgumentException] { message.setField(SUPERVISE_DRIVER, "nope, never") } + intercept[IllegalArgumentException] { message.setField(EXECUTOR_MEMORY, "less than expected") } + intercept[IllegalArgumentException] { message.setField(TOTAL_EXECUTOR_CORES, "two men") } + intercept[IllegalArgumentException] { message.setField(APP_ARGS, "anything") } + intercept[IllegalArgumentException] { message.setField(SPARK_PROPERTIES, "anything") } + intercept[IllegalArgumentException] { message.setField(ENVIRONMENT_VARIABLES, "anything") } + // special fields + message.appendAppArg("two slices") + message.appendAppArg("a hint of cinnamon") + message.setSparkProperty("spark.live.long", "true") + message.setSparkProperty("spark.shuffle.enabled", "false") + message.setEnvironmentVariable("PATH", "/dev/null") + message.setEnvironmentVariable("PYTHONPATH", "/dev/null") + assert(message.getAppArgs === Seq("two slices", "a hint of cinnamon")) + assert(message.getSparkProperties.size === 2) + assert(message.getSparkProperties("spark.live.long") === "true") + assert(message.getSparkProperties("spark.shuffle.enabled") === "false") + assert(message.getEnvironmentVariables.size === 2) + assert(message.getEnvironmentVariables("PATH") === "/dev/null") + assert(message.getEnvironmentVariables("PYTHONPATH") === "/dev/null") + // test JSON + val expectedJson = submitDriverRequestJson + assertJsonEquals(message.toJson, expectedJson) + val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) + .asInstanceOf[SubmitDriverRequestMessage] + assert(newMessage.getFields === message.getFields) + assert(newMessage.getAppArgs === message.getAppArgs) + assert(newMessage.getSparkProperties === message.getSparkProperties) + assert(newMessage.getEnvironmentVariables === message.getEnvironmentVariables) + } + + test("SubmitDriverResponseMessage") { + import SubmitDriverResponseField._ + val message = new SubmitDriverResponseMessage + intercept[IllegalArgumentException] { message.validate() } + message.setField(SERVER_SPARK_VERSION, "1.2.3") + message.setField(MESSAGE, "Dem driver is now submitted.") + message.setField(DRIVER_ID, "driver_123") + message.setField(SUCCESS, "true") + // all required fields are 
now set + message.validate() + // bad field values + intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } + intercept[IllegalArgumentException] { message.setField(SUCCESS, "maybe not") } + // test JSON + val expectedJson = submitDriverResponseJson + val actualJson = message.toJson + assertJsonEquals(actualJson, expectedJson) + val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) + assert(newMessage.isInstanceOf[SubmitDriverResponseMessage]) + assert(newMessage.getFields === message.getFields) + } + + test("KillDriverRequestMessage") { + import KillDriverRequestField._ + val message = new KillDriverRequestMessage + intercept[IllegalArgumentException] { message.validate() } + intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } + message.setField(CLIENT_SPARK_VERSION, "1.2.3") + message.setField(DRIVER_ID, "driver_123") + // all required fields are now set + message.validate() + // test JSON + val expectedJson = killDriverRequestJson + val actualJson = message.toJson + assertJsonEquals(actualJson, expectedJson) + val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) + assert(newMessage.isInstanceOf[KillDriverRequestMessage]) + assert(newMessage.getFields === message.getFields) + } + + test("KillDriverResponseMessage") { + import KillDriverResponseField._ + val message = new KillDriverResponseMessage + intercept[IllegalArgumentException] { message.validate() } + message.setField(SERVER_SPARK_VERSION, "1.2.3") + message.setField(DRIVER_ID, "driver_123") + message.setField(SUCCESS, "true") + // all required fields are now set + message.validate() + message.setField(MESSAGE, "Killing dem reckless drivers.") + // bad field values + intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } + intercept[IllegalArgumentException] { message.setField(SUCCESS, "maybe?") } + // test JSON + val expectedJson = killDriverResponseJson + val actualJson = message.toJson + assertJsonEquals(actualJson, expectedJson) + val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) + assert(newMessage.isInstanceOf[KillDriverResponseMessage]) + assert(newMessage.getFields === message.getFields) + } + + test("DriverStatusRequestMessage") { + import DriverStatusRequestField._ + val message = new DriverStatusRequestMessage + intercept[IllegalArgumentException] { message.validate() } + intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } + message.setField(CLIENT_SPARK_VERSION, "1.2.3") + message.setField(DRIVER_ID, "driver_123") + // all required fields are now set + message.validate() + // test JSON + val expectedJson = driverStatusRequestJson + val actualJson = message.toJson + assertJsonEquals(actualJson, expectedJson) + val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) + assert(newMessage.isInstanceOf[DriverStatusRequestMessage]) + assert(newMessage.getFields === message.getFields) + } + + test("DriverStatusResponseMessage") { + import DriverStatusResponseField._ + val message = new DriverStatusResponseMessage + intercept[IllegalArgumentException] { message.validate() } + message.setField(SERVER_SPARK_VERSION, "1.2.3") + message.setField(DRIVER_ID, "driver_123") + message.setField(SUCCESS, "true") + // all required fields are now set + message.validate() + message.setField(MESSAGE, "Your driver is having some trouble...") + message.setField(DRIVER_STATE, "RUNNING") + message.setField(WORKER_ID, "worker_123") + message.setField(WORKER_HOST_PORT, "1.2.3.4:7780") + // bad field values 
+ intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } + intercept[IllegalArgumentException] { message.setField(SUCCESS, "maybe") } + // test JSON + val expectedJson = driverStatusResponseJson + val actualJson = message.toJson + assertJsonEquals(actualJson, expectedJson) + val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) + assert(newMessage.isInstanceOf[DriverStatusResponseMessage]) + assert(newMessage.getFields === message.getFields) + } + + test("ErrorMessage") { + import ErrorField._ + val message = new ErrorMessage + intercept[IllegalArgumentException] { message.validate() } + intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } + message.setField(SERVER_SPARK_VERSION, "1.2.3") + message.setField(MESSAGE, "Your wife threw an exception!") + // all required fields are now set + message.validate() + // test JSON + val expectedJson = errorJson + val actualJson = message.toJson + assertJsonEquals(actualJson, expectedJson) + val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) + assert(newMessage.isInstanceOf[ErrorMessage]) + assert(newMessage.getFields === message.getFields) + } + + private val submitDriverRequestJson = + """ + |{ + | "ACTION" : "SUBMIT_DRIVER_REQUEST", + | "CLIENT_SPARK_VERSION" : "1.2.3", + | "MESSAGE" : "Submitting them drivers.", + | "APP_NAME" : "SparkPie", + | "APP_RESOURCE" : "honey-walnut-cherry.jar", + | "MAIN_CLASS" : "org.apache.spark.examples.SparkPie", + | "JARS" : "mayonnaise.jar,ketchup.jar", + | "FILES" : "fireball.png", + | "PY_FILES" : "do-not-eat-my.py", + | "DRIVER_MEMORY" : "512m", + | "DRIVER_CORES" : "180", + | "DRIVER_EXTRA_JAVA_OPTIONS" : " -Dslices=5 -Dcolor=mostly_red", + | "DRIVER_EXTRA_CLASS_PATH" : "food-coloring.jar", + | "DRIVER_EXTRA_LIBRARY_PATH" : "pickle.jar", + | "SUPERVISE_DRIVER" : "false", + | "EXECUTOR_MEMORY" : "256m", + | "TOTAL_EXECUTOR_CORES" : "10000", + | "APP_ARGS" : [ "two slices", "a hint of cinnamon" ], + | "SPARK_PROPERTIES" : { + | "spark.live.long" : "true", + | "spark.shuffle.enabled" : "false" + | }, + | "ENVIRONMENT_VARIABLES" : { + | "PATH" : "/dev/null", + | "PYTHONPATH" : "/dev/null" + | } + |} + """.stripMargin + + private val submitDriverResponseJson = + """ + |{ + | "ACTION" : "SUBMIT_DRIVER_RESPONSE", + | "SERVER_SPARK_VERSION" : "1.2.3", + | "MESSAGE" : "Dem driver is now submitted.", + | "DRIVER_ID" : "driver_123", + | "SUCCESS" : "true" + |} + """.stripMargin + + private val killDriverRequestJson = + """ + |{ + | "ACTION" : "KILL_DRIVER_REQUEST", + | "CLIENT_SPARK_VERSION" : "1.2.3", + | "DRIVER_ID" : "driver_123" + |} + """.stripMargin + + private val killDriverResponseJson = + """ + |{ + | "ACTION" : "KILL_DRIVER_RESPONSE", + | "SERVER_SPARK_VERSION" : "1.2.3", + | "DRIVER_ID" : "driver_123", + | "SUCCESS" : "true", + | "MESSAGE" : "Killing dem reckless drivers." 
+ |} + """.stripMargin + + private val driverStatusRequestJson = + """ + |{ + | "ACTION" : "DRIVER_STATUS_REQUEST", + | "CLIENT_SPARK_VERSION" : "1.2.3", + | "DRIVER_ID" : "driver_123" + |} + """.stripMargin + + private val driverStatusResponseJson = + """ + |{ + | "ACTION" : "DRIVER_STATUS_RESPONSE", + | "SERVER_SPARK_VERSION" : "1.2.3", + | "DRIVER_ID" : "driver_123", + | "SUCCESS" : "true", + | "MESSAGE" : "Your driver is having some trouble...", + | "DRIVER_STATE" : "RUNNING", + | "WORKER_ID" : "worker_123", + | "WORKER_HOST_PORT" : "1.2.3.4:7780" + |} + """.stripMargin + + private val errorJson = + """ + |{ + | "ACTION" : "ERROR", + | "SERVER_SPARK_VERSION" : "1.2.3", + | "MESSAGE" : "Your wife threw an exception!" + |} + """.stripMargin +} diff --git a/pom.xml b/pom.xml index f4466e56c2a5..41d418039c31 100644 --- a/pom.xml +++ b/pom.xml @@ -1127,6 +1127,7 @@ <spark.testing>1</spark.testing> <spark.ui.enabled>false</spark.ui.enabled> <spark.ui.showConsoleProgress>false</spark.ui.showConsoleProgress> + <spark.master.rest.enabled>false</spark.master.rest.enabled> <spark.executor.extraClassPath>${test_classpath}</spark.executor.extraClassPath> <spark.driver.allowMultipleContexts>true</spark.driver.allowMultipleContexts> diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 46a54c681840..05de3bf18fda 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -393,6 +393,7 @@ object TestSettings { javaOptions in Test += "-Dspark.port.maxRetries=100", javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", + javaOptions in Test += "-Dspark.master.rest.enabled=false", javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") From 77774ba94e4c5c66af7d5053970ed30ec22ffef3 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 26 Jan 2015 18:11:46 -0800 Subject: [PATCH 13/48] Minor fixes --- .../main/scala/org/apache/spark/deploy/SparkSubmit.scala | 7 +++---- .../org/apache/spark/deploy/SparkSubmitArguments.scala | 8 ++++++-- .../apache/spark/deploy/rest/StandaloneRestServer.scala | 2 -- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 6f98609caf2f..e4db1ed31a91 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -118,10 +118,9 @@ object SparkSubmit { */ private def submit(args: SparkSubmitArguments): Unit = { val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args) - val isStandaloneCluster = args.master.startsWith("spark://") && args.deployMode == "cluster" // In standalone cluster mode, use the stable application submission REST protocol. // Otherwise, just call the main method of the child class. - if (isStandaloneCluster) { + if (args.isStandaloneCluster) { // NOTE: since we mutate the values of some configs in `prepareSubmitEnvironment`, we // must update the corresponding fields in the original SparkSubmitArguments to reflect // these changes.
@@ -146,7 +145,7 @@ object SparkSubmit { */ private[spark] def prepareSubmitEnvironment(args: SparkSubmitArguments) : (Seq[String], Seq[String], Map[String, String], String) = { - // Environment needed to launch the child main class + // Return values val childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() val sysProps = new HashMap[String, String]() @@ -158,7 +157,7 @@ object SparkSubmit { case m if m.startsWith("spark") => STANDALONE case m if m.startsWith("mesos") => MESOS case m if m.startsWith("local") => LOCAL - case _ => printErrorAndExit("Master must start with yarn, spark, mesos, local, or rest"); -1 + case _ => printErrorAndExit("Master must start with yarn, spark, mesos or local"); -1 } // Set the deploy mode; default is client mode diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index bce20c2b92e7..ed550fb4a9f6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -231,7 +231,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } private def validateKillArguments(): Unit = { - if (!master.startsWith("spark://") || deployMode != "cluster") { + if (!isStandaloneCluster) { SparkSubmit.printErrorAndExit("Killing drivers is only supported in standalone cluster mode") } if (driverToKill == null) { @@ -240,7 +240,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } private def validateStatusRequestArguments(): Unit = { - if (!master.startsWith("spark://") || deployMode != "cluster") { + if (!isStandaloneCluster) { SparkSubmit.printErrorAndExit( "Requesting driver statuses is only supported in standalone cluster mode") } @@ -249,6 +249,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St } } + def isStandaloneCluster: Boolean = { + master.startsWith("spark://") && deployMode == "cluster" + } + override def toString = { s"""Parsed arguments: | master $master diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 4b347386397c..eb6065ff16c4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -135,8 +135,6 @@ private[spark] class StandaloneRestServerHandler( // Translate all fields to the relevant Spark properties val conf = new SparkConf(false) .setAll(sparkProperties) - // Use the actual master URL instead of the one that refers to this REST server - // Otherwise, once the driver is launched it will contact with the wrong server .set("spark.master", masterUrl) .set("spark.app.name", appName) jars.foreach { j => conf.set("spark.jars", j) } From d8d3717330c5608aa0ba07580078f2769c13f00f Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 27 Jan 2015 13:06:39 -0800 Subject: [PATCH 14/48] Use a daemon thread pool for REST server The motivation is to fix failing tests SparkSubmitSuite and DriverSuite. 
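The fix below works because the JVM exits only once every non-daemon thread has terminated, and Jetty's default worker threads are non-daemon, so a lingering REST server keeps the test JVM alive after the suites finish. A minimal standalone sketch (not part of this patch) of the underlying behavior:

object DaemonThreadDemo {
  // A JVM exits when all non-daemon threads are done. With setDaemon(false)
  // below, main() returns but the process never exits, which is how a
  // non-daemon server pool makes a test suite appear to hang.
  def main(args: Array[String]): Unit = {
    val t = new Thread(new Runnable {
      override def run(): Unit = Thread.sleep(Long.MaxValue)
    })
    t.setDaemon(true) // flip to false to reproduce the hang
    t.start()
  } // process exits here because the only remaining thread is a daemon
}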
--- .../scala/org/apache/spark/deploy/rest/SubmitRestServer.scala | 4 ++++ .../test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala | 2 -- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala index 1cf02c0efd3b..addee24aace4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala @@ -26,6 +26,7 @@ import scala.io.Source import com.google.common.base.Charsets import org.eclipse.jetty.server.{Request, Server} import org.eclipse.jetty.server.handler.AbstractHandler +import org.eclipse.jetty.util.thread.QueuedThreadPool import org.apache.spark.{SPARK_VERSION => sparkVersion, Logging, SparkConf} import org.apache.spark.util.Utils @@ -52,6 +53,9 @@ private[spark] abstract class SubmitRestServer(host: String, requestedPort: Int, private def doStart(startPort: Int): (Server, Int) = { val server = new Server(new InetSocketAddress(host, requestedPort)) + val threadPool = new QueuedThreadPool + threadPool.setDaemon(true) + server.setThreadPool(threadPool) server.setHandler(handler) server.start() val boundPort = server.getConnectors()(0).getLocalPort diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 044e0699968e..807a50254882 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -264,7 +264,6 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--name", "testApp", "--master", "local", "--conf", "spark.ui.enabled=false", - "--conf", "spark.master.rest.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -280,7 +279,6 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--master", "local-cluster[2,1,512]", "--jars", jarsString, "--conf", "spark.ui.enabled=false", - "--conf", "spark.master.rest.enabled=false", unusedJar.toString) runSparkSubmit(args) } From 837475b343e9aeb6c266ca2f3f730bb08808cab0 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 27 Jan 2015 15:56:47 -0800 Subject: [PATCH 15/48] Show the REST port on the Master UI --- .../org/apache/spark/ui/static/webui.css | 6 ++++++ .../apache/spark/deploy/DeployMessage.scala | 15 +++++++++++---- .../org/apache/spark/deploy/SparkSubmit.scala | 14 ++++++++------ .../apache/spark/deploy/master/Master.scala | 19 ++++++++++++------- .../spark/deploy/master/ui/MasterPage.scala | 9 +++++++++ .../spark/deploy/rest/SubmitRestServer.scala | 4 +++- 6 files changed, 49 insertions(+), 18 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index f23ba9dba167..7d3abee8c956 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -103,6 +103,12 @@ span.expand-details { float: right; } +span.stable-uri { + font-size: 10pt; + font-style: italic; + color: gray; +} + pre { font-size: 0.8em; } diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 243d8edb72ed..d95830d515ca 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -148,15 +148,22 @@ private[deploy] object DeployMessages { // Master to MasterWebUI - case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo], - activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo], - activeDrivers: Array[DriverInfo], completedDrivers: Array[DriverInfo], - status: MasterState) { + case class MasterStateResponse( + host: String, + port: Int, + stablePort: Option[Int], + workers: Array[WorkerInfo], + activeApps: Array[ApplicationInfo], + completedApps: Array[ApplicationInfo], + activeDrivers: Array[DriverInfo], + completedDrivers: Array[DriverInfo], + status: MasterState) { Utils.checkHost(host, "Required hostname") assert (port > 0) def uri = "spark://" + host + ":" + port + def stableUri: Option[String] = stablePort.map { p => "spark://" + host + ":" + p } } // WorkerWebUI to Worker diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 617785a67c34..b5a046d25a13 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -112,15 +112,17 @@ object SparkSubmit { * Second, we use this launch environment to invoke the main method of the child * main class. * - * Note that standalone cluster mode is an exception in that we do not invoke the - * main method of a child class. Instead, we pass the submit parameters directly to - * a REST client, which will submit the application using the stable REST protocol. + * As of Spark 1.3, a stable REST-based application submission gateway is introduced. + * If this is enabled, then we will run standalone cluster mode by passing the submit + * parameters directly to a REST client, which will submit the application using the + * REST protocol instead. */ private def submit(args: SparkSubmitArguments): Unit = { val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args) - // In standalone cluster mode, use the stable application submission REST protocol. - // Otherwise, just call the main method of the child class. - if (args.isStandaloneCluster) { + val restKey = "spark.submit.rest.enabled" + val restEnabled = args.sparkProperties.get(restKey).getOrElse("false").toBoolean + if (args.isStandaloneCluster && restEnabled) { + printStream.println("Running standalone cluster mode using the stable REST protocol.") // NOTE: since we mutate the values of some configs in `prepareSubmitEnvironment`, we // must update the corresponding fields in the original SparkSubmitArguments to reflect // these changes. 
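The hunk above gates the new code path on two conditions: the job must target standalone cluster mode, and the user must opt in explicitly. A small illustrative sketch (the helper below is hypothetical, not part of the patch; the property name spark.submit.rest.enabled, its "false" default, and the standalone-cluster test mirror the diffs above):

// Hypothetical helper, not Spark code: true only when the user has
// explicitly opted into the REST gateway in standalone cluster mode.
def useRestGateway(master: String, deployMode: String, sparkProperties: Map[String, String]): Boolean = {
  val isStandaloneCluster = master.startsWith("spark://") && deployMode == "cluster"
  val restEnabled = sparkProperties.getOrElse("spark.submit.rest.enabled", "false").toBoolean
  isStandaloneCluster && restEnabled
}

For example, useRestGateway("spark://h:p", "cluster", Map("spark.submit.rest.enabled" -> "true")) yields true, while the same call in client mode, or without the flag, falls back to the legacy Client path.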
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 7956a041e722..f667a5d28edb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -124,11 +124,14 @@ private[spark] class Master( // Alternative application submission gateway that is stable across Spark versions private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true) - private val restServerPort = conf.getInt("spark.master.rest.port", 17077) - private val restServer = new StandaloneRestServer(this, host, restServerPort) - if (restServerEnabled) { - restServer.start() - } + private val restServer = + if (restServerEnabled) { + val port = conf.getInt("spark.master.rest.port", 17077) + Some(new StandaloneRestServer(this, host, port)) + } else { + None + } + private val restServerBoundPort = restServer.map(_.start()) override def preStart() { logInfo("Starting Spark master at " + masterUrl) @@ -183,7 +186,7 @@ private[spark] class Master( recoveryCompletionTask.cancel() } webUi.stop() - restServer.stop() + restServer.foreach(_.stop()) masterMetricsSystem.stop() applicationMetricsSystem.stop() persistenceEngine.close() @@ -431,7 +434,9 @@ private[spark] class Master( } case RequestMasterState => { - sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray, + sender ! MasterStateResponse( + host, port, restServerBoundPort, + workers.toArray, apps.toArray, completedApps.toArray, drivers.toArray, completedDrivers.toArray, state) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 7ca3b08a2872..55a1346a79bd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -73,6 +73,15 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
<li><strong>URL:</strong> {state.uri}</li> + { + state.stableUri + .map { uri => + <li> + <strong>Stable URL:</strong> {uri} + <span class="stable-uri">(for standalone cluster mode in Spark 1.3+)</span> + </li> + } + .getOrElse { Seq.empty } + } <li><strong>Workers:</strong> {state.workers.size}</li> <li><strong>Cores:</strong> {state.workers.map(_.cores).sum} Total, {state.workers.map(_.coresUsed).sum} Used</li>
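The hunk above relies on a common Scala XML idiom: an Option is mapped to markup and collapses to an empty node sequence when absent, so masters without a REST server render no extra row. A self-contained sketch of the idiom (hypothetical helper, not Spark code):

import scala.xml.NodeSeq

// When the Option is None, nothing is rendered; when it is Some(uri),
// a single list item appears.
def renderStableUri(stableUri: Option[String]): NodeSeq = {
  stableUri
    .map { uri => <li><strong>Stable URL:</strong> {uri}</li> }
    .getOrElse(NodeSeq.Empty)
}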
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala index addee24aace4..9bc7220eb19d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala @@ -41,10 +41,12 @@ private[spark] abstract class SubmitRestServer(host: String, requestedPort: Int, protected val handler: SubmitRestServerHandler private var _server: Option[Server] = None - def start(): Unit = { + /** Start the server and return the bound port. */ + def start(): Int = { val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, conf) _server = Some(server) logInfo(s"Started REST server for submitting applications on port $boundPort") + boundPort } def stop(): Unit = { From e42c131001cb8dca7cbbee033c4d60f3a4cc5d8c Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 28 Jan 2015 00:15:04 -0800 Subject: [PATCH 16/48] Add end-to-end tests for standalone REST protocol This is actually non-trivial because we must run standalone mode instead of relying on the existing local-cluster mode. This means we must manually start our own Master and Workers, and provide a real jar when submitting the test application, which involves manually packaging our own jar. Further, since the driver output is difficult to obtain programmatically in cluster mode, we need to write the results to a special file and verify them later. --- .../spark/deploy/LocalSparkCluster.scala | 2 +- .../org/apache/spark/deploy/SparkSubmit.scala | 45 ++-- .../spark/deploy/SparkSubmitArguments.scala | 5 + .../apache/spark/deploy/master/Master.scala | 29 ++- .../spark/deploy/master/MasterMessages.scala | 4 +- .../spark/deploy/JsonProtocolSuite.scala | 3 +- .../spark/deploy/SparkSubmitSuite.scala | 28 +++ .../rest/StandaloneRestProtocolSuite.scala | 234 ++++++++++++++++++ 8 files changed, 321 insertions(+), 29 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 9a7a113c9571..17f729c0e075 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -45,7 +45,7 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I /* Start the Master */ val conf = new SparkConf(false) - val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0, conf) + val (masterSystem, masterPort, _, _) = Master.startSystemAndActor(localHostname, 0, 0, conf) masterActorSystems += masterSystem val masterUrl = "spark://" + localHostname + ":" + masterPort val masters = Array(masterUrl) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index b5a046d25a13..42ebf338b0bf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -117,19 +117,10 @@ * parameters directly to a REST client, which will submit the application using the * REST protocol instead.
*/ - private def submit(args: SparkSubmitArguments): Unit = { + private[spark] def submit(args: SparkSubmitArguments): Unit = { val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args) - val restKey = "spark.submit.rest.enabled" - val restEnabled = args.sparkProperties.get(restKey).getOrElse("false").toBoolean - if (args.isStandaloneCluster && restEnabled) { + if (args.isStandaloneCluster && args.isRestEnabled) { printStream.println("Running standalone cluster mode using the stable REST protocol.") - // NOTE: since we mutate the values of some configs in `prepareSubmitEnvironment`, we - // must update the corresponding fields in the original SparkSubmitArguments to reflect - // these changes. - args.sparkProperties.clear() - args.sparkProperties ++= sysProps - sysProps.get("spark.jars").foreach { args.jars = _ } - sysProps.get("spark.files").foreach { args.files = _ } new StandaloneRestClient().submitDriver(args) } else { runMain(childArgs, childClasspath, sysProps, childMainClass) @@ -159,7 +150,7 @@ object SparkSubmit { case m if m.startsWith("spark") => STANDALONE case m if m.startsWith("mesos") => MESOS case m if m.startsWith("local") => LOCAL - case _ => printErrorAndExit("Master must start with yarn, spark, mesos or local"); -1 + case _ => printErrorAndExit("Master must start with yarn, spark, mesos, or local"); -1 } // Set the deploy mode; default is client mode @@ -249,6 +240,8 @@ object SparkSubmit { // Standalone cluster only OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"), + OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"), + OptionAssigner(args.driverCores, STANDALONE, CLUSTER, clOption = "--cores"), // Yarn client only OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"), @@ -299,6 +292,20 @@ object SparkSubmit { } } + // In standalone-cluster mode, use Client as a wrapper around the user class + // Note that we won't actually launch this class if we're using the stable REST protocol + if (args.isStandaloneCluster && !args.isRestEnabled) { + childMainClass = "org.apache.spark.deploy.Client" + if (args.supervise) { + childArgs += "--supervise" + } + childArgs += "launch" + childArgs += (args.master, args.primaryResource, args.mainClass) + if (args.childArgs != null) { + childArgs ++= args.childArgs + } + } + // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" // For python files, the primary resource is already distributed as a regular file @@ -356,14 +363,24 @@ object SparkSubmit { sysProps("spark.submit.pyFiles") = formattedPyFiles } + // NOTE: If we are using the REST gateway, we will use the original arguments directly. + // Since we mutate the values of some configs in this method, we must update the + // corresponding fields in the original SparkSubmitArguments to reflect these changes. + if (args.isStandaloneCluster && args.isRestEnabled) { + args.sparkProperties.clear() + args.sparkProperties ++= sysProps + sysProps.get("spark.jars").foreach { args.jars = _ } + sysProps.get("spark.files").foreach { args.files = _ } + } + (childArgs, childClasspath, sysProps, childMainClass) } /** * Run the main method of the child class using the provided launch environment. * - * Depending on the deploy mode, cluster manager, and the type of the application, - * this main class may not necessarily be the one provided by the user. 
+ * Note that this main class will not be the one provided by the user if we're + * running cluster deploy mode or python applications. */ private def runMain( childArgs: Seq[String], diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index abaf12c09ad4..0032759d7577 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -256,6 +256,11 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St master.startsWith("spark://") && deployMode == "cluster" } + /** Return whether the stable application submission REST gateway is enabled. */ + def isRestEnabled: Boolean = { + sparkProperties.get("spark.submit.rest.enabled").getOrElse("false").toBoolean + } + override def toString = { s"""Parsed arguments: | master $master diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index f667a5d28edb..9b5286255985 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -53,12 +53,12 @@ private[spark] class Master( host: String, port: Int, webUiPort: Int, - val securityMgr: SecurityManager) + val securityMgr: SecurityManager, + val conf: SparkConf) extends Actor with ActorLogReceive with Logging with LeaderElectable { import context.dispatcher // to use Akka's scheduler.schedule() - val conf = new SparkConf val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs @@ -444,8 +444,8 @@ private[spark] class Master( timeOutDeadWorkers() } - case RequestWebUIPort => { - sender ! WebUIPortResponse(webUi.boundPort) + case BoundPortsRequest => { + sender ! 
BoundPortsResponse(port, webUi.boundPort, restServerBoundPort) } } @@ -866,7 +866,7 @@ private[spark] object Master extends Logging { SignalLogger.register(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) - val (actorSystem, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) + val (actorSystem, _, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) actorSystem.awaitTermination() } @@ -890,19 +890,26 @@ private[spark] object Master extends Logging { Address("akka.tcp", systemName, host, port) } + /** + * Start the Master and return a four tuple of: + * (1) The Master actor system + * (2) The bound port + * (3) The web UI bound port + * (4) The REST server bound port, if any + */ def startSystemAndActor( host: String, port: Int, webUiPort: Int, - conf: SparkConf): (ActorSystem, Int, Int) = { + conf: SparkConf): (ActorSystem, Int, Int, Option[Int]) = { val securityMgr = new SecurityManager(conf) val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, securityManager = securityMgr) - val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort, - securityMgr), actorName) + val actor = actorSystem.actorOf( + Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName) val timeout = AkkaUtils.askTimeout(conf) - val respFuture = actor.ask(RequestWebUIPort)(timeout) - val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse] - (actorSystem, boundPort, resp.webUIBoundPort) + val portsRequest = actor.ask(BoundPortsRequest)(timeout) + val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse] + (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.stablePort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index db72d8ae9bda..ca9f93edca25 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -36,7 +36,7 @@ private[master] object MasterMessages { case object CompleteRecovery - case object RequestWebUIPort + case object BoundPortsRequest - case class WebUIPortResponse(webUIBoundPort: Int) + case class BoundPortsResponse(actorPort: Int, webUIPort: Int, stablePort: Option[Int]) } diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index aa65f7e8915e..f39e7a08a54a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -68,7 +68,8 @@ class JsonProtocolSuite extends FunSuite { val completedApps = Array[ApplicationInfo]() val activeDrivers = Array(createDriverInfo()) val completedDrivers = Array(createDriverInfo()) - val stateResponse = new MasterStateResponse("host", 8080, workers, activeApps, completedApps, + val stateResponse = new MasterStateResponse( + "host", 8080, None, workers, activeApps, completedApps, activeDrivers, completedDrivers, RecoveryState.ALIVE) val output = JsonProtocol.writeMasterState(stateResponse) assertValidJson(output) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 807a50254882..9964d114e4f6 100644 --- 
a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -197,6 +197,33 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
     sysProps("spark.shuffle.spill") should be ("false")
   }

+  test("handles standalone cluster mode") {
+    val clArgs = Seq(
+      "--deploy-mode", "cluster",
+      "--master", "spark://h:p",
+      "--class", "org.SomeClass",
+      "--supervise",
+      "--driver-memory", "4g",
+      "--driver-cores", "5",
+      "--conf", "spark.shuffle.spill=false",
+      "thejar.jar",
+      "arg1", "arg2")
+    val appArgs = new SparkSubmitArguments(clArgs)
+    val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs)
+    val childArgsStr = childArgs.mkString(" ")
+    childArgsStr should startWith ("--memory 4g --cores 5 --supervise")
+    childArgsStr should include regex ("launch spark://h:p .*thejar.jar org.SomeClass arg1 arg2")
+    mainClass should be ("org.apache.spark.deploy.Client")
+    classpath should have size (0)
+    sysProps should have size (5)
+    sysProps.keys should contain ("SPARK_SUBMIT")
+    sysProps.keys should contain ("spark.master")
+    sysProps.keys should contain ("spark.app.name")
+    sysProps.keys should contain ("spark.jars")
+    sysProps.keys should contain ("spark.shuffle.spill")
+    sysProps("spark.shuffle.spill") should be ("false")
+  }
+
   test("handles standalone client mode") {
     val clArgs = Seq(
       "--deploy-mode", "client",
@@ -279,6 +306,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties
       "--master", "local-cluster[2,1,512]",
       "--jars", jarsString,
       "--conf", "spark.ui.enabled=false",
+      "--conf", "spark.master.rest.enabled=false",
       unusedJar.toString)
     runSparkSubmit(args)
   }
diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala
new file mode 100644
index 000000000000..ac2e3880ca19
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.rest
+
+import java.io.{File, FileInputStream, FileOutputStream, PrintWriter}
+import java.util.jar.{JarEntry, JarOutputStream}
+import java.util.zip.ZipEntry
+
+import scala.collection.mutable.ArrayBuffer
+import scala.io.Source
+
+import akka.actor.ActorSystem
+import com.google.common.io.ByteStreams
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
+import org.scalatest.exceptions.TestFailedException
+
+import org.apache.spark._
+import org.apache.spark.util.Utils
+import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments}
+import org.apache.spark.deploy.master.{DriverState, Master}
+import org.apache.spark.deploy.worker.Worker
+
+/**
+ * End-to-end tests for the stable application submission protocol in standalone mode.
+ */
+class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach {
+  private val systemsToStop = new ArrayBuffer[ActorSystem]
+  private val masterRestUrl = startLocalCluster()
+  private val client = new StandaloneRestClient
+  private val mainJar = StandaloneRestProtocolSuite.createJar()
+  private val mainClass = StandaloneRestApp.getClass.getName.stripSuffix("$")
+
+  override def afterAll() {
+    systemsToStop.foreach(_.shutdown())
+  }
+
+  test("simple submit until completion") {
+    val resultsFile = File.createTempFile("test-submit", ".txt")
+    val numbers = Seq(1, 2, 3)
+    val size = 500
+    val driverId = submitApp(resultsFile, numbers, size)
+    waitUntilFinished(driverId)
+    validateResult(resultsFile, numbers, size)
+  }
+
+  test("kill nonexistent driver") {
+    val killResponse = client.killDriver(masterRestUrl, "driver-that-does-not-exist")
+    val killSuccess = killResponse.getFieldNotNull(KillDriverResponseField.SUCCESS)
+    assert(killSuccess === "false")
+  }
+
+  test("kill running driver") {
+    val resultsFile = File.createTempFile("test-kill", ".txt")
+    val numbers = Seq(1, 2, 3)
+    val size = 500
+    val driverId = submitApp(resultsFile, numbers, size)
+    val killResponse = client.killDriver(masterRestUrl, driverId)
+    val killSuccess = killResponse.getFieldNotNull(KillDriverResponseField.SUCCESS)
+    waitUntilFinished(driverId)
+    val statusResponse = client.requestDriverStatus(masterRestUrl, driverId)
+    val statusSuccess = statusResponse.getFieldNotNull(DriverStatusResponseField.SUCCESS)
+    val driverState = statusResponse.getFieldNotNull(DriverStatusResponseField.DRIVER_STATE)
+    assert(killSuccess === "true")
+    assert(statusSuccess === "true")
+    assert(driverState === DriverState.KILLED.toString)
+    intercept[TestFailedException] { validateResult(resultsFile, numbers, size) }
+  }
+
+  test("request status for nonexistent driver") {
+    val statusResponse = client.requestDriverStatus(masterRestUrl, "driver-that-does-not-exist")
+    val statusSuccess = statusResponse.getFieldNotNull(DriverStatusResponseField.SUCCESS)
+    assert(statusSuccess === "false")
+  }
+
+  /**
+   * Start a local cluster containing one Master and a few Workers.
+   * Do not use org.apache.spark.deploy.LocalSparkCluster here because we want the REST URL.
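+   * (At this point in the series, LocalSparkCluster constructs its own SparkConf
+   * internally, so the REST server cannot be enabled through it; a later patch in
+   * this series changes that.)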
+   */
+  private def startLocalCluster(): String = {
+    val conf = new SparkConf(false)
+      .set("spark.master.rest.enabled", "true")
+      .set("spark.master.rest.port", "0")
+    val (numWorkers, coresPerWorker, memPerWorker) = (2, 1, 512)
+    val localHostName = Utils.localHostName()
+    val (masterSystem, masterPort, _, _masterRestPort) =
+      Master.startSystemAndActor(localHostName, 0, 0, conf)
+    val masterRestPort = _masterRestPort.getOrElse { fail("REST server not started on Master!") }
+    val masterUrl = "spark://" + localHostName + ":" + masterPort
+    val masterRestUrl = "spark://" + localHostName + ":" + masterRestPort
+    (1 to numWorkers).foreach { n =>
+      val (workerSystem, _) = Worker.startSystemAndActor(
+        localHostName, 0, 0, coresPerWorker, memPerWorker, Array(masterUrl), null, Some(n))
+      systemsToStop.append(workerSystem)
+    }
+    systemsToStop.append(masterSystem)
+    masterRestUrl
+  }
+
+  /**
+   * Submit an application through the stable gateway and return the corresponding driver ID.
+   */
+  private def submitApp(resultsFile: File, numbers: Seq[Int], size: Int): String = {
+    val appArgs = Seq(resultsFile.getAbsolutePath) ++ numbers.map(_.toString) ++ Seq(size.toString)
+    val commandLineArgs = Array(
+      "--deploy-mode", "cluster",
+      "--master", masterRestUrl,
+      "--name", mainClass,
+      "--class", mainClass,
+      "--conf", "spark.submit.rest.enabled=true",
+      mainJar) ++ appArgs
+    val args = new SparkSubmitArguments(commandLineArgs)
+    SparkSubmit.prepareSubmitEnvironment(args)
+    val submitResponse = client.submitDriver(args)
+    submitResponse.getFieldNotNull(SubmitDriverResponseField.DRIVER_ID)
+  }
+
+  /**
+   * Wait until the given driver has finished running,
+   * up to the specified maximum number of seconds.
+   */
+  private def waitUntilFinished(driverId: String, maxSeconds: Int = 10): Unit = {
+    var finished = false
+    val expireTime = System.currentTimeMillis + maxSeconds * 1000
+    while (!finished) {
+      val statusResponse = client.requestDriverStatus(masterRestUrl, driverId)
+      val driverState = statusResponse.getFieldNotNull(DriverStatusResponseField.DRIVER_STATE)
+      finished =
+        driverState != DriverState.SUBMITTED.toString &&
+        driverState != DriverState.RUNNING.toString
+      if (System.currentTimeMillis > expireTime) {
+        fail(s"Driver $driverId did not finish within $maxSeconds seconds.")
+      }
+      Thread.sleep(1000)
+    }
+  }
+
+  /** Validate whether the application produced the correct output. */
+  private def validateResult(resultsFile: File, numbers: Seq[Int], size: Int): Unit = {
+    val lines = Source.fromFile(resultsFile.getAbsolutePath).getLines().toSeq
+    val unexpectedContent =
+      if (lines.nonEmpty) {
+        "[\n" + lines.map { l => " " + l }.mkString("\n") + "\n]"
+      } else {
+        "[EMPTY]"
+      }
+    assert(lines.size === 2, s"Unexpected content in file: $unexpectedContent")
+    assert(lines(0).toInt === numbers.sum, s"Sum of ${numbers.mkString(",")} is incorrect")
+    assert(lines(1).toInt === (size / 2) + 1, "Result of Spark job is incorrect")
+  }
+}
+
+private object StandaloneRestProtocolSuite {
+  private val pathPrefix = "org/apache/spark/deploy/rest"
+
+  /**
+   * Create a jar that contains all the class files needed for running the StandaloneRestApp.
+   * Return the absolute path to that jar.
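+   *
+   * A minimal usage sketch (mirroring how this suite uses it):
+   * {{{
+   *   val mainJar = StandaloneRestProtocolSuite.createJar()
+   *   // pass mainJar as the primary resource when submitting through the gateway
+   * }}}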
+ */ + def createJar(): String = { + val jarFile = File.createTempFile("test-standalone-rest-protocol", ".jar") + val jarFileStream = new FileOutputStream(jarFile) + val jarStream = new JarOutputStream(jarFileStream, new java.util.jar.Manifest) + jarStream.putNextEntry(new ZipEntry(pathPrefix)) + getClassFiles.foreach { cf => + jarStream.putNextEntry(new JarEntry(pathPrefix + "/" + cf.getName)) + val in = new FileInputStream(cf) + ByteStreams.copy(in, jarStream) + in.close() + } + jarStream.close() + jarFileStream.close() + jarFile.getAbsolutePath + } + + /** + * Return a list of class files compiled for StandaloneRestApp. + * This includes all the anonymous classes used in StandaloneRestApp#main. + */ + private def getClassFiles: Seq[File] = { + val clazz = StandaloneRestApp.getClass + val className = Utils.getFormattedClassName(StandaloneRestApp) + val basePath = clazz.getProtectionDomain.getCodeSource.getLocation.toURI.getPath + val baseDir = new File(basePath + "/" + pathPrefix) + baseDir.listFiles().filter(_.getName.contains(className)) + } +} + +/** + * Sample application to be submitted to the cluster using the stable gateway. + * All relevant classes will be packaged into a jar dynamically and submitted to the cluster. + */ +object StandaloneRestApp { + // Usage: [path to results file] [num1] [num2] [num3] [rddSize] + // The first line of the results file should be (num1 + num2 + num3) + // The second line should be (rddSize / 2) + 1 + def main(args: Array[String]) { + assert(args.size == 5) + val resultFile = new File(args(0)) + val writer = new PrintWriter(resultFile) + try { + val firstLine = args(1).toInt + args(2).toInt + args(3).toInt + val rddSize = args(4).toInt + val conf = new SparkConf() + val sc = new SparkContext(conf) + val secondLine = sc.parallelize(1 to rddSize) + .map { i => (i / 2, i) } + .reduceByKey(_ + _) + .count() + writer.println(firstLine) + writer.println(secondLine) + } catch { + case e: Exception => + writer.println(e) + e.getStackTrace.foreach { l => writer.println(" " + l) } + } finally { + writer.close() + } + } +} From d7a1f9fbb586035a53ea0a22f6703d11ed3818db Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 28 Jan 2015 09:05:32 -0800 Subject: [PATCH 17/48] Fix local cluster tests --- .../main/scala/org/apache/spark/SparkContext.scala | 2 +- .../org/apache/spark/deploy/LocalSparkCluster.scala | 12 +++++++++--- .../org/apache/spark/deploy/SparkSubmitSuite.scala | 1 - .../serializer/KryoSerializerDistributedSuite.scala | 6 +++--- pom.xml | 1 - project/SparkBuild.scala | 1 - 6 files changed, 13 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4c4ee04cc515..1cb1e2bc4b2b 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1945,7 +1945,7 @@ object SparkContext extends Logging { val scheduler = new TaskSchedulerImpl(sc) val localCluster = new LocalSparkCluster( - numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) + numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt, sc.conf) val masterUrls = localCluster.start() val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls) scheduler.initialize(backend) diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 17f729c0e075..0401b15446a7 100644 --- 
a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -33,7 +33,11 @@ import org.apache.spark.util.Utils * fault recovery without spinning up a lot of processes. */ private[spark] -class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) +class LocalSparkCluster( + numWorkers: Int, + coresPerWorker: Int, + memoryPerWorker: Int, + conf: SparkConf) extends Logging { private val localHostname = Utils.localHostName() @@ -43,9 +47,11 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I def start(): Array[String] = { logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") + // Disable REST server on Master in this mode unless otherwise specified + val _conf = conf.clone().setIfMissing("spark.master.rest.enabled", "false") + /* Start the Master */ - val conf = new SparkConf(false) - val (masterSystem, masterPort, _, _) = Master.startSystemAndActor(localHostname, 0, 0, conf) + val (masterSystem, masterPort, _, _) = Master.startSystemAndActor(localHostname, 0, 0, _conf) masterActorSystems += masterSystem val masterUrl = "spark://" + localHostname + ":" + masterPort val masters = Array(masterUrl) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 9964d114e4f6..82047ef22ee6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -306,7 +306,6 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--master", "local-cluster[2,1,512]", "--jars", jarsString, "--conf", "spark.ui.enabled=false", - "--conf", "spark.master.rest.enabled=false", unusedJar.toString) runSparkSubmit(args) } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index 855f1b627608..054a4c64897a 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -29,9 +29,9 @@ class KryoSerializerDistributedSuite extends FunSuite { test("kryo objects are serialised consistently in different processes") { val conf = new SparkConf(false) - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.kryo.registrator", classOf[AppJarRegistrator].getName) - conf.set("spark.task.maxFailures", "1") + .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + .set("spark.kryo.registrator", classOf[AppJarRegistrator].getName) + .set("spark.task.maxFailures", "1") val jar = TestUtils.createJarWithClasses(List(AppJarRegistrator.customClassName)) conf.setJars(List(jar.getPath)) diff --git a/pom.xml b/pom.xml index 7f0f384b38e4..05cb3797fc55 100644 --- a/pom.xml +++ b/pom.xml @@ -1127,7 +1127,6 @@ 1 false false - false ${test_classpath} true diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 85e6f3b74900..ded4b5443a90 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -380,7 +380,6 @@ object TestSettings { javaOptions in Test += "-Dspark.port.maxRetries=100", javaOptions in Test += "-Dspark.ui.enabled=false", javaOptions in Test += "-Dspark.ui.showConsoleProgress=false", - javaOptions in Test += 
"-Dspark.master.rest.enabled=false", javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true", javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true", javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark") From df90e8b32ce017294cc0a47bcb78e118943662f9 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 29 Jan 2015 11:01:41 -0800 Subject: [PATCH 18/48] Use Jackson for JSON de/serialization This involves a major refactor of all message representations. The main motivation for this change is to simplify the logic to enforce type safety, such that we no longer depend on the behavior of all the scala class magic we used to rely on. This commit also introduces a differentiation between request and response messages to provide further type safety. This would have introduced much additional complexity without the refactor. --- .../deploy/rest/DriverStatusRequest.scala | 31 + .../rest/DriverStatusRequestMessage.scala | 47 -- .../deploy/rest/DriverStatusResponse.scala | 45 ++ .../rest/DriverStatusResponseMessage.scala | 52 -- .../spark/deploy/rest/ErrorMessage.scala | 44 -- .../spark/deploy/rest/ErrorResponse.scala | 26 + .../spark/deploy/rest/KillDriverRequest.scala | 31 + .../rest/KillDriverRequestMessage.scala | 47 -- .../deploy/rest/KillDriverResponse.scala | 36 ++ .../rest/KillDriverResponseMessage.scala | 48 -- .../deploy/rest/StandaloneRestClient.scala | 79 ++- .../deploy/rest/StandaloneRestServer.scala | 95 ++- .../deploy/rest/SubmitDriverRequest.scala | 131 ++++ .../rest/SubmitDriverRequestMessage.scala | 155 ----- .../deploy/rest/SubmitDriverResponse.scala | 35 ++ .../rest/SubmitDriverResponseMessage.scala | 48 -- .../spark/deploy/rest/SubmitRestClient.scala | 22 +- .../deploy/rest/SubmitRestProtocolField.scala | 89 +-- .../rest/SubmitRestProtocolMessage.scala | 294 ++++----- .../spark/deploy/rest/SubmitRestServer.scala | 26 +- .../org/apache/spark/util/JsonProtocol.scala | 9 + .../rest/StandaloneRestProtocolSuite.scala | 14 +- .../deploy/rest/SubmitRestProtocolSuite.scala | 577 +++++++++--------- 23 files changed, 910 insertions(+), 1071 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala delete mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala delete mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala delete mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala delete mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala delete mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala delete mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala delete mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala diff --git 
a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala new file mode 100644 index 000000000000..f5d4d95cebf1 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +class DriverStatusRequest extends SubmitRestProtocolRequest { + protected override val action = SubmitRestProtocolAction.DRIVER_STATUS_REQUEST + private val driverId = new SubmitRestProtocolField[String] + + def getDriverId: String = driverId.toString + def setDriverId(s: String): this.type = setField(driverId, s) + + override def validate(): Unit = { + super.validate() + assertFieldIsSet(driverId, "driver_id") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala deleted file mode 100644 index f0d0c5f874d5..000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequestMessage.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -/** - * A field used in a DriverStatusRequestMessage. 
- */ -private[spark] abstract class DriverStatusRequestField extends SubmitRestProtocolField -private[spark] object DriverStatusRequestField - extends SubmitRestProtocolFieldCompanion[DriverStatusRequestField] { - case object ACTION extends DriverStatusRequestField with ActionField - case object CLIENT_SPARK_VERSION extends DriverStatusRequestField - case object MESSAGE extends DriverStatusRequestField - case object DRIVER_ID extends DriverStatusRequestField - override val requiredFields = Seq(ACTION, CLIENT_SPARK_VERSION, DRIVER_ID) - override val optionalFields = Seq(MESSAGE) -} - -/** - * A request sent to the cluster manager to query the status of a driver - * in the stable application submission REST protocol. - */ -private[spark] class DriverStatusRequestMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.DRIVER_STATUS_REQUEST, - DriverStatusRequestField.ACTION, - DriverStatusRequestField.requiredFields) - -private[spark] object DriverStatusRequestMessage - extends SubmitRestProtocolMessageCompanion[DriverStatusRequestMessage] { - protected override def newMessage() = new DriverStatusRequestMessage - protected override def fieldFromString(f: String) = DriverStatusRequestField.fromString(f) -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala new file mode 100644 index 000000000000..1e8090c33681 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.rest + +class DriverStatusResponse extends SubmitRestProtocolResponse { + protected override val action = SubmitRestProtocolAction.DRIVER_STATUS_RESPONSE + private val driverId = new SubmitRestProtocolField[String] + private val success = new SubmitRestProtocolField[Boolean] + private val driverState = new SubmitRestProtocolField[String] + private val workerId = new SubmitRestProtocolField[String] + private val workerHostPort = new SubmitRestProtocolField[String] + + def getDriverId: String = driverId.toString + def getSuccess: String = success.toString + def getDriverState: String = driverState.toString + def getWorkerId: String = workerId.toString + def getWorkerHostPort: String = workerHostPort.toString + + def setDriverId(s: String): this.type = setField(driverId, s) + def setSuccess(s: String): this.type = setBooleanField(success, s) + def setDriverState(s: String): this.type = setField(driverState, s) + def setWorkerId(s: String): this.type = setField(workerId, s) + def setWorkerHostPort(s: String): this.type = setField(workerHostPort, s) + + override def validate(): Unit = { + super.validate() + assertFieldIsSet(driverId, "driver_id") + assertFieldIsSet(success, "success") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala deleted file mode 100644 index d65145248505..000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponseMessage.scala +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -/** - * A field used in a DriverStatusResponseMessage. 
- */ -private[spark] abstract class DriverStatusResponseField extends SubmitRestProtocolField -private[spark] object DriverStatusResponseField - extends SubmitRestProtocolFieldCompanion[DriverStatusResponseField] { - case object ACTION extends DriverStatusResponseField with ActionField - case object SERVER_SPARK_VERSION extends DriverStatusResponseField - case object MESSAGE extends DriverStatusResponseField - case object DRIVER_ID extends DriverStatusResponseField - case object SUCCESS extends DriverStatusResponseField with BooleanField - // Standalone specific fields - case object DRIVER_STATE extends DriverStatusResponseField - case object WORKER_ID extends DriverStatusResponseField - case object WORKER_HOST_PORT extends DriverStatusResponseField - override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, DRIVER_ID, SUCCESS) - override val optionalFields = Seq(MESSAGE, DRIVER_STATE, WORKER_ID, WORKER_HOST_PORT) -} - -/** - * A message sent from the cluster manager in response to a DriverStatusRequestMessage - * in the stable application submission REST protocol. - */ -private[spark] class DriverStatusResponseMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.DRIVER_STATUS_RESPONSE, - DriverStatusResponseField.ACTION, - DriverStatusResponseField.requiredFields) - -private[spark] object DriverStatusResponseMessage - extends SubmitRestProtocolMessageCompanion[DriverStatusResponseMessage] { - protected override def newMessage() = new DriverStatusResponseMessage - protected override def fieldFromString(f: String) = DriverStatusResponseField.fromString(f) -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala deleted file mode 100644 index f1fbdd227507..000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorMessage.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -/** - * A field used in an ErrorMessage. - */ -private[spark] abstract class ErrorField extends SubmitRestProtocolField -private[spark] object ErrorField extends SubmitRestProtocolFieldCompanion[ErrorField] { - case object ACTION extends ErrorField with ActionField - case object SERVER_SPARK_VERSION extends ErrorField - case object MESSAGE extends ErrorField - override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, MESSAGE) - override val optionalFields = Seq.empty -} - -/** - * An error message sent from the cluster manager - * in the stable application submission REST protocol. 
- */ -private[spark] class ErrorMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.ERROR, - ErrorField.ACTION, - ErrorField.requiredFields) - -private[spark] object ErrorMessage extends SubmitRestProtocolMessageCompanion[ErrorMessage] { - protected override def newMessage() = new ErrorMessage - protected override def fieldFromString(f: String) = ErrorField.fromString(f) -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala new file mode 100644 index 000000000000..8c30d3185088 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +class ErrorResponse extends SubmitRestProtocolResponse { + protected override val action = SubmitRestProtocolAction.ERROR + override def validate(): Unit = { + super.validate() + assertFieldIsSet(message, "message") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala new file mode 100644 index 000000000000..c44c94d95a1f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.rest + +class KillDriverRequest extends SubmitRestProtocolRequest { + protected override val action = SubmitRestProtocolAction.KILL_DRIVER_REQUEST + private val driverId = new SubmitRestProtocolField[String] + + def getDriverId: String = driverId.toString + def setDriverId(s: String): this.type = setField(driverId, s) + + override def validate(): Unit = { + super.validate() + assertFieldIsSet(driverId, "driver_id") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala deleted file mode 100644 index 232bb364e889..000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequestMessage.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -/** - * A field used in a KillDriverRequestMessage. - */ -private[spark] abstract class KillDriverRequestField extends SubmitRestProtocolField -private[spark] object KillDriverRequestField - extends SubmitRestProtocolFieldCompanion[KillDriverRequestField] { - case object ACTION extends KillDriverRequestField with ActionField - case object CLIENT_SPARK_VERSION extends KillDriverRequestField - case object MESSAGE extends KillDriverRequestField - case object DRIVER_ID extends KillDriverRequestField - override val requiredFields = Seq(ACTION, CLIENT_SPARK_VERSION, DRIVER_ID) - override val optionalFields = Seq(MESSAGE) -} - -/** - * A request sent to the cluster manager to kill a driver - * in the stable application submission REST protocol. - */ -private[spark] class KillDriverRequestMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.KILL_DRIVER_REQUEST, - KillDriverRequestField.ACTION, - KillDriverRequestField.requiredFields) - -private[spark] object KillDriverRequestMessage - extends SubmitRestProtocolMessageCompanion[KillDriverRequestMessage] { - protected override def newMessage() = new KillDriverRequestMessage - protected override def fieldFromString(f: String) = KillDriverRequestField.fromString(f) -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala new file mode 100644 index 000000000000..e75a52bc9bf0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +class KillDriverResponse extends SubmitRestProtocolResponse { + protected override val action = SubmitRestProtocolAction.KILL_DRIVER_RESPONSE + private val driverId = new SubmitRestProtocolField[String] + private val success = new SubmitRestProtocolField[Boolean] + + def getDriverId: String = driverId.toString + def getSuccess: String = success.toString + + def setDriverId(s: String): this.type = setField(driverId, s) + def setSuccess(s: String): this.type = setBooleanField(success, s) + + override def validate(): Unit = { + super.validate() + assertFieldIsSet(driverId, "driver_id") + assertFieldIsSet(success, "success") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala deleted file mode 100644 index 0717131ab2ec..000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponseMessage.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -/** - * A field used in a KillDriverResponseMessage. - */ -private[spark] abstract class KillDriverResponseField extends SubmitRestProtocolField -private[spark] object KillDriverResponseField - extends SubmitRestProtocolFieldCompanion[KillDriverResponseField] { - case object ACTION extends KillDriverResponseField with ActionField - case object SERVER_SPARK_VERSION extends KillDriverResponseField - case object MESSAGE extends KillDriverResponseField - case object DRIVER_ID extends KillDriverResponseField - case object SUCCESS extends KillDriverResponseField with BooleanField - override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, DRIVER_ID, SUCCESS) - override val optionalFields = Seq(MESSAGE) -} - -/** - * A message sent from the cluster manager in response to a KillDriverRequestMessage - * in the stable application submission REST protocol. 
- */ -private[spark] class KillDriverResponseMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.KILL_DRIVER_RESPONSE, - KillDriverResponseField.ACTION, - KillDriverResponseField.requiredFields) - -private[spark] object KillDriverResponseMessage - extends SubmitRestProtocolMessageCompanion[KillDriverResponseMessage] { - protected override def newMessage() = new KillDriverResponseMessage - protected override def fieldFromString(f: String) = KillDriverResponseField.fromString(f) -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index 278c9af749b1..b564006fd745 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -37,16 +37,15 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { * just submitted and reports it to the user. Otherwise, if the submission was unsuccessful, * this reports failure and logs an error message provided by the REST server. */ - override def submitDriver(args: SparkSubmitArguments): SubmitDriverResponseMessage = { - import SubmitDriverResponseField._ - val submitResponse = super.submitDriver(args).asInstanceOf[SubmitDriverResponseMessage] - val submitSuccess = submitResponse.getFieldNotNull(SUCCESS).toBoolean + override def submitDriver(args: SparkSubmitArguments): SubmitDriverResponse = { + val submitResponse = super.submitDriver(args).asInstanceOf[SubmitDriverResponse] + val submitSuccess = submitResponse.getSuccess.toBoolean if (submitSuccess) { - val driverId = submitResponse.getFieldNotNull(DRIVER_ID) + val driverId = submitResponse.getDriverId logInfo(s"Driver successfully submitted as $driverId. Polling driver state...") pollSubmittedDriverStatus(args.master, driverId) } else { - val submitMessage = submitResponse.getFieldNotNull(MESSAGE) + val submitMessage = submitResponse.getMessage logError(s"Application submission failed: $submitMessage") } submitResponse @@ -57,16 +56,15 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { * This retries up to a fixed number of times until giving up. */ private def pollSubmittedDriverStatus(master: String, driverId: String): Unit = { - import DriverStatusResponseField._ (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ => val statusResponse = requestDriverStatus(master, driverId) - .asInstanceOf[DriverStatusResponseMessage] - val statusSuccess = statusResponse.getFieldNotNull(SUCCESS).toBoolean + .asInstanceOf[DriverStatusResponse] + val statusSuccess = statusResponse.getSuccess.toBoolean if (statusSuccess) { - val driverState = statusResponse.getFieldNotNull(DRIVER_STATE) - val workerId = statusResponse.getFieldOption(WORKER_ID) - val workerHostPort = statusResponse.getFieldOption(WORKER_HOST_PORT) - val exception = statusResponse.getFieldOption(MESSAGE) + val driverState = statusResponse.getDriverState + val workerId = Option(statusResponse.getWorkerId) + val workerHostPort = Option(statusResponse.getWorkerHostPort) + val exception = Option(statusResponse.getMessage) logInfo(s"State of driver $driverId is now $driverState.") // Log worker node, if present (workerId, workerHostPort) match { @@ -83,26 +81,23 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { /** Construct a submit driver request message. 
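 * A sketch of the chained-setter style this method produces (values illustrative):
 * {{{
 *   new SubmitDriverRequest()
 *     .setSparkVersion("1.3.0")
 *     .setAppName("MyApp")
 *     .setAppResource("file:/path/to/thejar.jar")
 *     .setMainClass("org.MyMainClass")
 * }}}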
*/ override protected def constructSubmitRequest( - args: SparkSubmitArguments): SubmitDriverRequestMessage = { - import SubmitDriverRequestField._ - val dm = Option(args.driverMemory).map { m => Utils.memoryStringToMb(m).toString }.orNull - val em = Option(args.executorMemory).map { m => Utils.memoryStringToMb(m).toString }.orNull - val message = new SubmitDriverRequestMessage() - .setField(CLIENT_SPARK_VERSION, sparkVersion) - .setField(APP_NAME, args.name) - .setField(APP_RESOURCE, args.primaryResource) - .setFieldIfNotNull(MAIN_CLASS, args.mainClass) - .setFieldIfNotNull(JARS, args.jars) - .setFieldIfNotNull(FILES, args.files) - .setFieldIfNotNull(DRIVER_MEMORY, dm) - .setFieldIfNotNull(DRIVER_CORES, args.driverCores) - .setFieldIfNotNull(DRIVER_EXTRA_JAVA_OPTIONS, args.driverExtraJavaOptions) - .setFieldIfNotNull(DRIVER_EXTRA_CLASS_PATH, args.driverExtraClassPath) - .setFieldIfNotNull(DRIVER_EXTRA_LIBRARY_PATH, args.driverExtraLibraryPath) - .setFieldIfNotNull(SUPERVISE_DRIVER, args.supervise.toString) - .setFieldIfNotNull(EXECUTOR_MEMORY, em) - .setFieldIfNotNull(TOTAL_EXECUTOR_CORES, args.totalExecutorCores) - args.childArgs.foreach(message.appendAppArg) + args: SparkSubmitArguments): SubmitDriverRequest = { + val message = new SubmitDriverRequest() + .setSparkVersion(sparkVersion) + .setAppName(args.name) + .setAppResource(args.primaryResource) + .setMainClass(args.mainClass) + .setJars(args.jars) + .setFiles(args.files) + .setDriverMemory(args.driverMemory) + .setDriverCores(args.driverCores) + .setDriverExtraJavaOptions(args.driverExtraJavaOptions) + .setDriverExtraClassPath(args.driverExtraClassPath) + .setDriverExtraLibraryPath(args.driverExtraLibraryPath) + .setSuperviseDriver(args.supervise.toString) + .setExecutorMemory(args.executorMemory) + .setTotalExecutorCores(args.totalExecutorCores) + args.childArgs.foreach(message.addAppArg) args.sparkProperties.foreach { case (k, v) => message.setSparkProperty(k, v) } // TODO: send special environment variables? message @@ -111,21 +106,19 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { /** Construct a kill driver request message. */ override protected def constructKillRequest( master: String, - driverId: String): KillDriverRequestMessage = { - import KillDriverRequestField._ - new KillDriverRequestMessage() - .setField(CLIENT_SPARK_VERSION, sparkVersion) - .setField(DRIVER_ID, driverId) + driverId: String): KillDriverRequest = { + new KillDriverRequest() + .setSparkVersion(sparkVersion) + .setDriverId(driverId) } /** Construct a driver status request message. */ override protected def constructStatusRequest( master: String, - driverId: String): DriverStatusRequestMessage = { - import DriverStatusRequestField._ - new DriverStatusRequestMessage() - .setField(CLIENT_SPARK_VERSION, sparkVersion) - .setField(DRIVER_ID, driverId) + driverId: String): DriverStatusRequest = { + new DriverStatusRequest() + .setSparkVersion(sparkVersion) + .setDriverId(driverId) } /** Throw an exception if this is not standalone mode. 
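 * A sketch of the intended check (the body is not shown in this diff; the
 * "spark://" prefix test is an assumption based on how SparkSubmit
 * classifies master URLs):
 * {{{
 *   if (!master.startsWith("spark://")) {
 *     throw new IllegalArgumentException("This REST client only supports standalone mode") // illustrative message
 *   }
 * }}}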
*/ diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index eb6065ff16c4..3fcfe189c6a1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkConf import org.apache.spark.util.{AkkaUtils, Utils} import org.apache.spark.deploy.{Command, DriverDescription} import org.apache.spark.deploy.ClientArguments._ -import org.apache.spark.deploy.DeployMessages._ +import org.apache.spark.deploy.DeployMessages import org.apache.spark.deploy.master.Master /** @@ -56,52 +56,49 @@ private[spark] class StandaloneRestServerHandler( /** Handle a request to submit a driver. */ override protected def handleSubmit( - request: SubmitDriverRequestMessage): SubmitDriverResponseMessage = { - import SubmitDriverResponseField._ + request: SubmitDriverRequest): SubmitDriverResponse = { val driverDescription = buildDriverDescription(request) - val response = AkkaUtils.askWithReply[SubmitDriverResponse]( - RequestSubmitDriver(driverDescription), masterActor, askTimeout) - new SubmitDriverResponseMessage() - .setField(SERVER_SPARK_VERSION, sparkVersion) - .setField(MESSAGE, response.message) - .setField(SUCCESS, response.success.toString) - .setFieldIfNotNull(DRIVER_ID, response.driverId.orNull) + val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + new SubmitDriverResponse() + .setSparkVersion(sparkVersion) + .setMessage(response.message) + .setSuccess(response.success.toString) + .setDriverId(response.driverId.orNull) } /** Handle a request to kill a driver. */ override protected def handleKill( - request: KillDriverRequestMessage): KillDriverResponseMessage = { - import KillDriverResponseField._ - val driverId = request.getFieldNotNull(KillDriverRequestField.DRIVER_ID) - val response = AkkaUtils.askWithReply[KillDriverResponse]( - RequestKillDriver(driverId), masterActor, askTimeout) - new KillDriverResponseMessage() - .setField(SERVER_SPARK_VERSION, sparkVersion) - .setField(MESSAGE, response.message) - .setField(DRIVER_ID, driverId) - .setField(SUCCESS, response.success.toString) + request: KillDriverRequest): KillDriverResponse = { + val driverId = request.getDriverId + val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( + DeployMessages.RequestKillDriver(driverId), masterActor, askTimeout) + new KillDriverResponse() + .setSparkVersion(sparkVersion) + .setMessage(response.message) + .setDriverId(driverId) + .setSuccess(response.success.toString) } /** Handle a request for a driver's status. 
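 * A sketch of a successful response from this handler (IDs and state illustrative):
 * {{{
 *   new DriverStatusResponse()
 *     .setSparkVersion(sparkVersion)
 *     .setDriverId("driver-20150129000000-0000")
 *     .setSuccess("true")
 *     .setDriverState(DriverState.RUNNING.toString)
 * }}}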
*/ override protected def handleStatus( - request: DriverStatusRequestMessage): DriverStatusResponseMessage = { - import DriverStatusResponseField._ - val driverId = request.getField(DriverStatusRequestField.DRIVER_ID) - val response = AkkaUtils.askWithReply[DriverStatusResponse]( - RequestDriverStatus(driverId), masterActor, askTimeout) + request: DriverStatusRequest): DriverStatusResponse = { + val driverId = request.getDriverId + val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( + DeployMessages.RequestDriverStatus(driverId), masterActor, askTimeout) // Format exception nicely, if it exists val message = response.exception.map { e => val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") s"Exception from the cluster:\n$e\n$stackTraceString" } - new DriverStatusResponseMessage() - .setField(SERVER_SPARK_VERSION, sparkVersion) - .setField(DRIVER_ID, driverId) - .setField(SUCCESS, response.found.toString) - .setFieldIfNotNull(DRIVER_STATE, response.state.map(_.toString).orNull) - .setFieldIfNotNull(WORKER_ID, response.workerId.orNull) - .setFieldIfNotNull(WORKER_HOST_PORT, response.workerHostPort.orNull) - .setFieldIfNotNull(MESSAGE, message.orNull) + new DriverStatusResponse() + .setSparkVersion(sparkVersion) + .setDriverId(driverId) + .setSuccess(response.found.toString) + .setDriverState(response.state.map(_.toString).orNull) + .setWorkerId(response.workerId.orNull) + .setWorkerHostPort(response.workerHostPort.orNull) + .setMessage(message.orNull) } /** @@ -109,25 +106,23 @@ private[spark] class StandaloneRestServerHandler( * This does not currently consider fields used by python applications since * python is not supported in standalone cluster mode yet. */ - private def buildDriverDescription(request: SubmitDriverRequestMessage): DriverDescription = { - import SubmitDriverRequestField._ - + private def buildDriverDescription(request: SubmitDriverRequest): DriverDescription = { // Required fields, including the main class because python is not yet supported - val appName = request.getFieldNotNull(APP_NAME) - val appResource = request.getFieldNotNull(APP_RESOURCE) - val mainClass = request.getFieldNotNull(MAIN_CLASS) + val appName = request.getAppName + val appResource = request.getAppResource + val mainClass = request.getMainClass // Optional fields - val jars = request.getFieldOption(JARS) - val files = request.getFieldOption(FILES) - val driverMemory = request.getFieldOption(DRIVER_MEMORY) - val driverCores = request.getFieldOption(DRIVER_CORES) - val driverExtraJavaOptions = request.getFieldOption(DRIVER_EXTRA_JAVA_OPTIONS) - val driverExtraClassPath = request.getFieldOption(DRIVER_EXTRA_CLASS_PATH) - val driverExtraLibraryPath = request.getFieldOption(DRIVER_EXTRA_LIBRARY_PATH) - val superviseDriver = request.getFieldOption(SUPERVISE_DRIVER) - val executorMemory = request.getFieldOption(EXECUTOR_MEMORY) - val totalExecutorCores = request.getFieldOption(TOTAL_EXECUTOR_CORES) + val jars = Option(request.getJars) + val files = Option(request.getFiles) + val driverMemory = Option(request.getDriverMemory) + val driverCores = Option(request.getDriverCores) + val driverExtraJavaOptions = Option(request.getDriverExtraJavaOptions) + val driverExtraClassPath = Option(request.getDriverExtraClassPath) + val driverExtraLibraryPath = Option(request.getDriverExtraLibraryPath) + val superviseDriver = Option(request.getSuperviseDriver) + val executorMemory = Option(request.getExecutorMemory) + val totalExecutorCores = Option(request.getTotalExecutorCores) 
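+    // (The getters above are expected to return null for fields that were
+    // never set, so wrapping them in Option(...) turns missing optional
+    // values into None.)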
val appArgs = request.getAppArgs val sparkProperties = request.getSparkProperties val environmentVariables = request.getEnvironmentVariables @@ -155,7 +150,7 @@ private[spark] class StandaloneRestServerHandler( "org.apache.spark.deploy.worker.DriverWrapper", Seq("{{WORKER_URL}}", mainClass) ++ appArgs, // args to the DriverWrapper environmentVariables, extraClassPath, extraLibraryPath, javaOpts) - val actualDriverMemory = driverMemory.map(_.toInt).getOrElse(DEFAULT_MEMORY) + val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY) val actualDriverCores = driverCores.map(_.toInt).getOrElse(DEFAULT_CORES) val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE) new DriverDescription( diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala new file mode 100644 index 000000000000..9bde3345d03f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.rest + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import com.fasterxml.jackson.annotation.{JsonIgnore, JsonProperty} +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.util.JsonProtocol + +class SubmitDriverRequest extends SubmitRestProtocolRequest { + protected override val action = SubmitRestProtocolAction.SUBMIT_DRIVER_REQUEST + private val appName = new SubmitRestProtocolField[String] + private val appResource = new SubmitRestProtocolField[String] + private val mainClass = new SubmitRestProtocolField[String] + private val jars = new SubmitRestProtocolField[String] + private val files = new SubmitRestProtocolField[String] + private val pyFiles = new SubmitRestProtocolField[String] + private val driverMemory = new SubmitRestProtocolField[String] + private val driverCores = new SubmitRestProtocolField[Int] + private val driverExtraJavaOptions = new SubmitRestProtocolField[String] + private val driverExtraClassPath = new SubmitRestProtocolField[String] + private val driverExtraLibraryPath = new SubmitRestProtocolField[String] + private val superviseDriver = new SubmitRestProtocolField[Boolean] + private val executorMemory = new SubmitRestProtocolField[String] + private val totalExecutorCores = new SubmitRestProtocolField[Int] + + // Special fields + private val appArgs = new ArrayBuffer[String] + private val sparkProperties = new mutable.HashMap[String, String] + private val envVars = new mutable.HashMap[String, String] + + def getAppName: String = appName.toString + def getAppResource: String = appResource.toString + def getMainClass: String = mainClass.toString + def getJars: String = jars.toString + def getFiles: String = files.toString + def getPyFiles: String = pyFiles.toString + def getDriverMemory: String = driverMemory.toString + def getDriverCores: String = driverCores.toString + def getDriverExtraJavaOptions: String = driverExtraJavaOptions.toString + def getDriverExtraClassPath: String = driverExtraClassPath.toString + def getDriverExtraLibraryPath: String = driverExtraLibraryPath.toString + def getSuperviseDriver: String = superviseDriver.toString + def getExecutorMemory: String = executorMemory.toString + def getTotalExecutorCores: String = totalExecutorCores.toString + + // Special getters required for JSON de/serialization + @JsonProperty("appArgs") + private def getAppArgsJson: String = arrayToJson(getAppArgs) + @JsonProperty("sparkProperties") + private def getSparkPropertiesJson: String = mapToJson(getSparkProperties) + @JsonProperty("environmentVariables") + private def getEnvironmentVariablesJson: String = mapToJson(getEnvironmentVariables) + + def setAppName(s: String): this.type = setField(appName, s) + def setAppResource(s: String): this.type = setField(appResource, s) + def setMainClass(s: String): this.type = setField(mainClass, s) + def setJars(s: String): this.type = setField(jars, s) + def setFiles(s: String): this.type = setField(files, s) + def setPyFiles(s: String): this.type = setField(pyFiles, s) + def setDriverMemory(s: String): this.type = setField(driverMemory, s) + def setDriverCores(s: String): this.type = setNumericField(driverCores, s) + def setDriverExtraJavaOptions(s: String): this.type = setField(driverExtraJavaOptions, s) + def setDriverExtraClassPath(s: String): this.type = setField(driverExtraClassPath, s) + def setDriverExtraLibraryPath(s: String): this.type = setField(driverExtraLibraryPath, s) + def setSuperviseDriver(s: String): this.type = 
setBooleanField(superviseDriver, s) + def setExecutorMemory(s: String): this.type = setField(executorMemory, s) + def setTotalExecutorCores(s: String): this.type = setNumericField(totalExecutorCores, s) + + // Special setters required for JSON de/serialization + @JsonProperty("appArgs") + private def setAppArgsJson(s: String): Unit = { + appArgs.clear() + appArgs ++= JsonProtocol.arrayFromJson(parse(s)) + } + @JsonProperty("sparkProperties") + private def setSparkPropertiesJson(s: String): Unit = { + sparkProperties.clear() + sparkProperties ++= JsonProtocol.mapFromJson(parse(s)) + } + @JsonProperty("environmentVariables") + private def setEnvironmentVariablesJson(s: String): Unit = { + envVars.clear() + envVars ++= JsonProtocol.mapFromJson(parse(s)) + } + + @JsonIgnore + def getAppArgs: Array[String] = appArgs.toArray + @JsonIgnore + def getSparkProperties: Map[String, String] = sparkProperties.toMap + @JsonIgnore + def getEnvironmentVariables: Map[String, String] = envVars.toMap + @JsonIgnore + def addAppArg(s: String): this.type = { appArgs += s; this } + @JsonIgnore + def setSparkProperty(k: String, v: String): this.type = { sparkProperties(k) = v; this } + @JsonIgnore + def setEnvironmentVariable(k: String, v: String): this.type = { envVars(k) = v; this } + + private def arrayToJson(arr: Array[String]): String = { + if (arr.nonEmpty) { compact(render(JsonProtocol.arrayToJson(arr))) } else { null } + } + + private def mapToJson(map: Map[String, String]): String = { + if (map.nonEmpty) { compact(render(JsonProtocol.mapToJson(map))) } else { null } + } + + override def validate(): Unit = { + super.validate() + assertFieldIsSet(appName, "app_name") + assertFieldIsSet(appResource, "app_resource") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala deleted file mode 100644 index 90d7e408fefc..000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequestMessage.scala +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import org.json4s.JsonAST._ - -import org.apache.spark.util.JsonProtocol - -/** - * A field used in a SubmitDriverRequestMessage. 
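// Illustration only, not part of this patch: a minimal sketch of building the
// new SubmitDriverRequest above; all values below are placeholders.
val sketchRequest = new SubmitDriverRequest()
  .setSparkVersion("1.2.3")                         // surfaces as client_spark_version
  .setAppName("MyApp")
  .setAppResource("path/to/app.jar")
  .setMainClass("com.example.Main")
  .addAppArg("--verbose")
  .setSparkProperty("spark.executor.memory", "2g")
sketchRequest.validate()               // AssertionError if app_name or app_resource is missing
val requestJson = sketchRequest.toJson // keys are underscored on the wire, e.g. "app_name"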
- */ -private[spark] abstract class SubmitDriverRequestField extends SubmitRestProtocolField -private[spark] object SubmitDriverRequestField - extends SubmitRestProtocolFieldCompanion[SubmitDriverRequestField] { - case object ACTION extends SubmitDriverRequestField with ActionField - case object CLIENT_SPARK_VERSION extends SubmitDriverRequestField - case object MESSAGE extends SubmitDriverRequestField - case object APP_NAME extends SubmitDriverRequestField - case object APP_RESOURCE extends SubmitDriverRequestField - case object MAIN_CLASS extends SubmitDriverRequestField - case object JARS extends SubmitDriverRequestField - case object FILES extends SubmitDriverRequestField - case object PY_FILES extends SubmitDriverRequestField - case object DRIVER_MEMORY extends SubmitDriverRequestField with MemoryField - case object DRIVER_CORES extends SubmitDriverRequestField with NumericField - case object DRIVER_EXTRA_JAVA_OPTIONS extends SubmitDriverRequestField - case object DRIVER_EXTRA_CLASS_PATH extends SubmitDriverRequestField - case object DRIVER_EXTRA_LIBRARY_PATH extends SubmitDriverRequestField - case object SUPERVISE_DRIVER extends SubmitDriverRequestField with BooleanField - case object EXECUTOR_MEMORY extends SubmitDriverRequestField with MemoryField - case object TOTAL_EXECUTOR_CORES extends SubmitDriverRequestField with NumericField - - // Special fields that should not be set directly - case object APP_ARGS extends SubmitDriverRequestField { - override def validateValue(v: String): Unit = { - validateFailed(v, "Use message.appendAppArg(arg) instead") - } - } - case object SPARK_PROPERTIES extends SubmitDriverRequestField { - override def validateValue(v: String): Unit = { - validateFailed(v, "Use message.setSparkProperty(k, v) instead") - } - } - case object ENVIRONMENT_VARIABLES extends SubmitDriverRequestField { - override def validateValue(v: String): Unit = { - validateFailed(v, "Use message.setEnvironmentVariable(k, v) instead") - } - } - - override val requiredFields = Seq(ACTION, CLIENT_SPARK_VERSION, APP_NAME, APP_RESOURCE) - override val optionalFields = Seq(MESSAGE, MAIN_CLASS, JARS, FILES, PY_FILES, DRIVER_MEMORY, - DRIVER_CORES, DRIVER_EXTRA_JAVA_OPTIONS, DRIVER_EXTRA_CLASS_PATH, DRIVER_EXTRA_LIBRARY_PATH, - SUPERVISE_DRIVER, EXECUTOR_MEMORY, TOTAL_EXECUTOR_CORES, APP_ARGS, SPARK_PROPERTIES, - ENVIRONMENT_VARIABLES) -} - -/** - * A request sent to the cluster manager to submit a driver - * in the stable application submission REST protocol. 
- */ -private[spark] class SubmitDriverRequestMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.SUBMIT_DRIVER_REQUEST, - SubmitDriverRequestField.ACTION, - SubmitDriverRequestField.requiredFields) { - - import SubmitDriverRequestField._ - - private val appArgs = new ArrayBuffer[String] - private val sparkProperties = new mutable.HashMap[String, String] - private val environmentVariables = new mutable.HashMap[String, String] - - // Setters for special fields - def appendAppArg(arg: String): Unit = { appArgs += arg } - def setSparkProperty(k: String, v: String): Unit = { sparkProperties(k) = v } - def setEnvironmentVariable(k: String, v: String): Unit = { environmentVariables(k) = v } - - // Getters for special fields - def getAppArgs: Seq[String] = appArgs.clone() - def getSparkProperties: Map[String, String] = sparkProperties.toMap - def getEnvironmentVariables: Map[String, String] = environmentVariables.toMap - - // Include app args, spark properties, and environment variables in the JSON object - // The order imposed here is as follows: * < APP_ARGS < SPARK_PROPERTIES < ENVIRONMENT_VARIABLES - override def toJsonObject: JObject = { - val otherFields = super.toJsonObject.obj - val appArgsJson = JArray(appArgs.map(JString).toList) - val sparkPropertiesJson = JsonProtocol.mapToJson(sparkProperties) - val environmentVariablesJson = JsonProtocol.mapToJson(environmentVariables) - val jsonFields = new ArrayBuffer[JField] - jsonFields ++= otherFields - if (appArgs.nonEmpty) { - jsonFields += JField(APP_ARGS.toString, appArgsJson) - } - if (sparkProperties.nonEmpty) { - jsonFields += JField(SPARK_PROPERTIES.toString, sparkPropertiesJson) - } - if (environmentVariables.nonEmpty) { - jsonFields += JField(ENVIRONMENT_VARIABLES.toString, environmentVariablesJson) - } - JObject(jsonFields.toList) - } -} - -private[spark] object SubmitDriverRequestMessage - extends SubmitRestProtocolMessageCompanion[SubmitDriverRequestMessage] { - - import SubmitDriverRequestField._ - - protected override def newMessage() = new SubmitDriverRequestMessage - protected override def fieldFromString(f: String) = SubmitDriverRequestField.fromString(f) - - /** - * Process the given field and value appropriately based on the type of the field. - * This handles certain nested values in addition to flat values. - */ - override def handleField( - message: SubmitDriverRequestMessage, - field: SubmitRestProtocolField, - value: JValue): Unit = { - (field, value) match { - case (APP_ARGS, JArray(args)) => - args.map(_.asInstanceOf[JString].s).foreach { arg => - message.appendAppArg(arg) - } - case (SPARK_PROPERTIES, props: JObject) => - JsonProtocol.mapFromJson(props).foreach { case (k, v) => - message.setSparkProperty(k, v) - } - case (ENVIRONMENT_VARIABLES, envVars: JObject) => - JsonProtocol.mapFromJson(envVars).foreach { case (envKey, envValue) => - message.setEnvironmentVariable(envKey, envValue) - } - // All other fields are assumed to have flat values - case _ => super.handleField(message, field, value) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala new file mode 100644 index 000000000000..8a1676767cec --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +class SubmitDriverResponse extends SubmitRestProtocolResponse { + protected override val action = SubmitRestProtocolAction.SUBMIT_DRIVER_RESPONSE + private val success = new SubmitRestProtocolField[Boolean] + private val driverId = new SubmitRestProtocolField[String] + + def getSuccess: String = success.toString + def getDriverId: String = driverId.toString + + def setSuccess(s: String): this.type = setBooleanField(success, s) + def setDriverId(s: String): this.type = setField(driverId, s) + + override def validate(): Unit = { + super.validate() + assertFieldIsSet(success, "success") + } +} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala deleted file mode 100644 index d5a2e1660eb0..000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponseMessage.scala +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -/** - * A field used in a SubmitDriverResponseMessage. - */ -private[spark] abstract class SubmitDriverResponseField extends SubmitRestProtocolField -private[spark] object SubmitDriverResponseField - extends SubmitRestProtocolFieldCompanion[SubmitDriverResponseField] { - case object ACTION extends SubmitDriverResponseField with ActionField - case object SERVER_SPARK_VERSION extends SubmitDriverResponseField - case object MESSAGE extends SubmitDriverResponseField - case object SUCCESS extends SubmitDriverResponseField with BooleanField - case object DRIVER_ID extends SubmitDriverResponseField - override val requiredFields = Seq(ACTION, SERVER_SPARK_VERSION, MESSAGE, SUCCESS) - override val optionalFields = Seq(DRIVER_ID) -} - -/** - * A message sent from the cluster manager in response to a SubmitDriverRequestMessage - * in the stable application submission REST protocol. 
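// Illustration only, not part of this patch: round-tripping the new
// SubmitDriverResponse above through JSON; "driver_123" is a placeholder id.
val sketchResponse = new SubmitDriverResponse()
  .setSparkVersion("1.2.3")  // surfaces as server_spark_version on responses
  .setSuccess("true")
  .setDriverId("driver_123")
val reparsed = SubmitRestProtocolMessage.fromJson(
  sketchResponse.toJson, classOf[SubmitDriverResponse])
assert(reparsed.getSuccess == "true" && reparsed.getDriverId == "driver_123")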
- */ -private[spark] class SubmitDriverResponseMessage extends SubmitRestProtocolMessage( - SubmitRestProtocolAction.SUBMIT_DRIVER_RESPONSE, - SubmitDriverResponseField.ACTION, - SubmitDriverResponseField.requiredFields) - -private[spark] object SubmitDriverResponseMessage - extends SubmitRestProtocolMessageCompanion[SubmitDriverResponseMessage] { - protected override def newMessage() = new SubmitDriverResponseMessage - protected override def fieldFromString(f: String) = SubmitDriverResponseField.fromString(f) -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala index 513c17deee89..eb258290bdc7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala @@ -34,39 +34,39 @@ import org.apache.spark.deploy.SparkSubmitArguments private[spark] abstract class SubmitRestClient extends Logging { /** Request that the REST server submit a driver specified by the provided arguments. */ - def submitDriver(args: SparkSubmitArguments): SubmitRestProtocolMessage = { + def submitDriver(args: SparkSubmitArguments): SubmitDriverResponse = { validateSubmitArguments(args) val url = getHttpUrl(args.master) val request = constructSubmitRequest(args) logInfo(s"Submitting a request to launch a driver in ${args.master}.") - sendHttp(url, request) + sendHttp(url, request).asInstanceOf[SubmitDriverResponse] } /** Request that the REST server kill the specified driver. */ - def killDriver(master: String, driverId: String): SubmitRestProtocolMessage = { + def killDriver(master: String, driverId: String): KillDriverResponse = { validateMaster(master) val url = getHttpUrl(master) val request = constructKillRequest(master, driverId) logInfo(s"Submitting a request to kill driver $driverId in $master.") - sendHttp(url, request) + sendHttp(url, request).asInstanceOf[KillDriverResponse] } /** Request the status of the specified driver from the REST server. */ - def requestDriverStatus(master: String, driverId: String): SubmitRestProtocolMessage = { + def requestDriverStatus(master: String, driverId: String): DriverStatusResponse = { validateMaster(master) val url = getHttpUrl(master) val request = constructStatusRequest(master, driverId) logInfo(s"Submitting a request for the status of driver $driverId in $master.") - sendHttp(url, request) + sendHttp(url, request).asInstanceOf[DriverStatusResponse] } /** Return the HTTP URL of the REST server that corresponds to the given master URL. 
*/ protected def getHttpUrl(master: String): URL // Construct the appropriate type of message based on the request type - protected def constructSubmitRequest(args: SparkSubmitArguments): SubmitDriverRequestMessage - protected def constructKillRequest(master: String, driverId: String): KillDriverRequestMessage - protected def constructStatusRequest(master: String, driverId: String): DriverStatusRequestMessage + protected def constructSubmitRequest(args: SparkSubmitArguments): SubmitDriverRequest + protected def constructKillRequest(master: String, driverId: String): KillDriverRequest + protected def constructStatusRequest(master: String, driverId: String): DriverStatusRequest // If the provided arguments are not as expected, throw an exception protected def validateMaster(master: String): Unit @@ -81,7 +81,7 @@ private[spark] abstract class SubmitRestClient extends Logging { * This assumes both the request and the response use the JSON format. * Return the response received from the REST server. */ - private def sendHttp(url: URL, request: SubmitRestProtocolMessage): SubmitRestProtocolMessage = { + private def sendHttp(url: URL, request: SubmitRestProtocolRequest): SubmitRestProtocolResponse = { try { val conn = url.openConnection().asInstanceOf[HttpURLConnection] conn.setRequestMethod("POST") @@ -96,7 +96,7 @@ private[spark] abstract class SubmitRestClient extends Logging { out.close() val responseJson = Source.fromInputStream(conn.getInputStream).mkString logDebug(s"Response from the REST server:\n$responseJson") - SubmitRestProtocolMessage.fromJson(responseJson) + SubmitRestProtocolResponse.fromJson(responseJson) } catch { case e: FileNotFoundException => throw new SparkException(s"Unable to connect to REST server $url", e) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala index 639e00d912e7..4c0c45b450fa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala @@ -17,68 +17,11 @@ package org.apache.spark.deploy.rest -import scala.collection.Map -import scala.util.Try - -import org.apache.spark.util.Utils - -/** - * A field used in a SubmitRestProtocolMessage. - * There are a few special fields: - * - ACTION entirely specifies the type of the message and is required in all messages - * - MESSAGE contains arbitrary messages and is common, but not required, in all messages - * - CLIENT_SPARK_VERSION is required in all messages sent from the client - * - SERVER_SPARK_VERSION is required in all messages sent from the server - */ -private[spark] abstract class SubmitRestProtocolField { - protected val fieldName = Utils.getFormattedClassName(this) - def validateValue(value: String): Unit = { } - def validateFailed(v: String, msg: String): Unit = { - throw new IllegalArgumentException(s"Detected setting of $fieldName to $v: $msg") - } -} -private[spark] object SubmitRestProtocolField { - def isActionField(field: String): Boolean = field == "ACTION" -} - -/** A field that should accept only boolean values. */ -private[spark] trait BooleanField extends SubmitRestProtocolField { - override def validateValue(v: String): Unit = { - Try(v.toBoolean).getOrElse { validateFailed(v, s"Error parsing $v as a boolean!") } - } -} - -/** A field that should accept only numeric values. 
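// Illustration only, not part of this patch: with the typed responses returned
// by the client methods above, a caller can consume results directly. The
// master URL and driver id are placeholders, and `client` is any concrete
// SubmitRestClient implementation.
def printDriverState(client: SubmitRestClient): Unit = {
  val status = client.requestDriverStatus("spark://host:7077", "driver_123")
  if (status.getSuccess == "true") {
    println(s"Driver state: ${status.getDriverState}")
  }
}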
*/ -private[spark] trait NumericField extends SubmitRestProtocolField { - override def validateValue(v: String): Unit = { - Try(v.toInt).getOrElse { validateFailed(v, s"Error parsing $v as an integer!") } - } -} - -/** A field that should accept only memory values. */ -private[spark] trait MemoryField extends SubmitRestProtocolField { - override def validateValue(v: String): Unit = { - Try(Utils.memoryStringToMb(v)).getOrElse { - validateFailed(v, s"Error parsing $v as a memory string!") - } - } -} - -/** - * The main action field in every message. - * This should be set only on message instantiation. - */ -private[spark] trait ActionField extends SubmitRestProtocolField { - override def validateValue(v: String): Unit = { - validateFailed(v, "The ACTION field must not be set directly after instantiation.") - } -} - /** * All possible values of the ACTION field in a SubmitRestProtocolMessage. */ -private[spark] abstract class SubmitRestProtocolAction -private[spark] object SubmitRestProtocolAction { +abstract class SubmitRestProtocolAction +object SubmitRestProtocolAction { case object SUBMIT_DRIVER_REQUEST extends SubmitRestProtocolAction case object SUBMIT_DRIVER_RESPONSE extends SubmitRestProtocolAction case object KILL_DRIVER_REQUEST extends SubmitRestProtocolAction @@ -98,24 +41,12 @@ private[spark] object SubmitRestProtocolAction { } } -/** - * Common methods used by companion objects of SubmitRestProtocolField's subclasses. - * This keeps track of all fields that belong to this object in order to reconstruct - * the fields from their names. - */ -private[spark] trait SubmitRestProtocolFieldCompanion[FieldType <: SubmitRestProtocolField] { - val requiredFields: Seq[FieldType] - val optionalFields: Seq[FieldType] - - // Listing of all fields indexed by the field's string representation - private lazy val allFieldsMap: Map[String, FieldType] = { - (requiredFields ++ optionalFields).map { f => (f.toString, f) }.toMap - } - - /** Return the appropriate SubmitRestProtocolField from its string representation. 
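// Illustration only, not part of this patch: actions round-trip through their
// string names. The body of fromString is elided by the hunk above; it is
// assumed here to invert toString, as used by the JSON dispatch further below.
assert(SubmitRestProtocolAction.fromString("SUBMIT_DRIVER_REQUEST") ==
  SubmitRestProtocolAction.SUBMIT_DRIVER_REQUEST)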
*/ - def fromString(field: String): FieldType = { - allFieldsMap.get(field).getOrElse { - throw new IllegalArgumentException(s"Unknown field $field") - } - } +class SubmitRestProtocolField[T] { + protected var value: Option[T] = None + def isSet: Boolean = value.isDefined + def getValue: T = value.getOrElse { throw new IllegalAccessException("Value not set!") } + def getValueOption: Option[T] = value + def setValue(v: T): Unit = { value = Some(v) } + def clearValue(): Unit = { value = None } + override def toString: String = value.map(_.toString).orNull } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index 7899668ac526..0b2085b5e3bf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -17,185 +17,185 @@ package org.apache.spark.deploy.rest -import scala.collection.Map -import scala.collection.JavaConversions._ - -import org.json4s.jackson.JsonMethods._ +import com.fasterxml.jackson.annotation._ +import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility +import com.fasterxml.jackson.annotation.JsonInclude.Include +import com.fasterxml.jackson.databind.ObjectMapper import org.json4s.JsonAST._ +import org.json4s.jackson.JsonMethods._ -import org.apache.spark.Logging import org.apache.spark.util.Utils +import org.apache.spark.deploy.rest.SubmitRestProtocolAction._ + +@JsonInclude(Include.NON_NULL) +@JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY) +@JsonPropertyOrder(alphabetic = true) +abstract class SubmitRestProtocolMessage { + import SubmitRestProtocolMessage._ + + private val messageType = Utils.getFormattedClassName(this) + protected val action: SubmitRestProtocolAction + protected val sparkVersion = new SubmitRestProtocolField[String] + protected val message = new SubmitRestProtocolField[String] + + // Required for JSON de/serialization and not explicitly used + private def getAction: String = action.toString + private def setAction(s: String): this.type = this + + // Spark version implementation depends on whether this is a request or a response + @JsonIgnore + def getSparkVersion: String + @JsonIgnore + def setSparkVersion(s: String): this.type + + def getMessage: String = message.toString + def setMessage(s: String): this.type = setField(message, s) + + def toJson: String = { + validate() + val mapper = new ObjectMapper + val json = mapper.writeValueAsString(this) + postProcessJson(json) + } -/** - * A general message exchanged in the stable application submission REST protocol. - * - * The message is represented by a set of fields in the form of key value pairs. - * Each message must contain an ACTION field, which fully specifies the type of the message. - * For compatibility with older versions of Spark, existing fields must not be removed or - * modified, though new fields can be added as necessary. 
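// Illustration only, not part of this patch: semantics of the typed field
// container defined above. Unset fields read back as null strings.
val sketchField = new SubmitRestProtocolField[Int]
assert(!sketchField.isSet && sketchField.toString == null)
sketchField.setValue(42)
assert(sketchField.getValue == 42 && sketchField.getValueOption == Some(42))
sketchField.clearValue()  // a subsequent getValue would throw IllegalAccessException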
- */ -private[spark] abstract class SubmitRestProtocolMessage( - action: SubmitRestProtocolAction, - actionField: ActionField, - requiredFields: Seq[SubmitRestProtocolField]) { - - // Maintain the insert order for converting to JSON later - private val fields = new java.util.LinkedHashMap[SubmitRestProtocolField, String] - val className = Utils.getFormattedClassName(this) - - // Set the action field - fields.put(actionField, action.toString) - - /** Return all fields currently set in this message. */ - def getFields: Map[SubmitRestProtocolField, String] = fields.toMap - - /** Return the value of the given field. If the field is not present, return null. */ - def getField(key: SubmitRestProtocolField): String = getFieldOption(key).orNull + def validate(): Unit = { + assert(action != null, s"The action field is missing in $messageType!") + } - /** Return the value of the given field. If the field is not present, throw an exception. */ - def getFieldNotNull(key: SubmitRestProtocolField): String = { - getFieldOption(key).getOrElse { - throw new IllegalArgumentException(s"Field $key is not set in message $className") - } + protected def assertFieldIsSet(field: SubmitRestProtocolField[_], name: String): Unit = { + assert(field.isSet, s"The $name field is missing in $messageType!") } - /** Return the value of the given field as an option. */ - def getFieldOption(key: SubmitRestProtocolField): Option[String] = Option(fields.get(key)) + protected def setField(field: SubmitRestProtocolField[String], value: String): this.type = { + if (value == null) { field.clearValue() } else { field.setValue(value) } + this + } - /** Assign the given value to the field, overriding any existing value. */ - def setField(key: SubmitRestProtocolField, value: String): this.type = { - key.validateValue(value) - fields.put(key, value) + protected def setBooleanField( + field: SubmitRestProtocolField[Boolean], + value: String): this.type = { + if (value == null) { field.clearValue() } else { field.setValue(value.toBoolean) } this } - /** Assign the given value to the field only if the value is not null. */ - def setFieldIfNotNull(key: SubmitRestProtocolField, value: String): this.type = { - if (value != null) { - setField(key, value) - } + protected def setNumericField( + field: SubmitRestProtocolField[Int], + value: String): this.type = { + if (value == null) { field.clearValue() } else { field.setValue(value.toInt) } this } - /** - * Validate that all required fields are set and the value of the ACTION field is as expected. - * If any of these conditions are not met, throw an exception. - */ - def validate(): this.type = { - if (!fields.contains(actionField)) { - throw new IllegalArgumentException(s"The action field is missing from message $className.") - } - if (fields(actionField) != action.toString) { - throw new IllegalArgumentException( - s"Expected action $action in message $className, but actual was ${fields(actionField)}.") - } - val missingFields = requiredFields.filterNot(fields.contains) - if (missingFields.nonEmpty) { - val missingFieldsString = missingFields.mkString(", ") - throw new IllegalArgumentException( - s"The following fields are missing from message $className: $missingFieldsString.") - } + protected def setMemoryField( + field: SubmitRestProtocolField[String], + value: String): this.type = { + // Validate the memory string eagerly; a null value means "clear the field" + if (value != null) { Utils.memoryStringToMb(value) } + setField(field, value) + this } - /** Return the JSON representation of this message.
*/ - def toJson: String = pretty(render(toJsonObject)) + private def postProcessJson(json: String): String = { + val fields = parse(json).asInstanceOf[JObject].obj + val newFields = fields.map { case (k, v) => (camelCaseToUnderscores(k), v) } + pretty(render(JObject(newFields))) + } +} + +abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage { + def getClientSparkVersion: String = sparkVersion.toString + def setClientSparkVersion(s: String): this.type = setField(sparkVersion, s) + override def getSparkVersion: String = getClientSparkVersion + override def setSparkVersion(s: String) = setClientSparkVersion(s) + override def validate(): Unit = { + super.validate() + assertFieldIsSet(sparkVersion, "client_spark_version") + } +} - /** - * Return a JObject that represents the JSON form of this message. - * This ignores fields with null values. - */ - protected def toJsonObject: JObject = { - val jsonFields = fields.toSeq - .filter { case (_, v) => v != null } - .map { case (k, v) => JField(k.toString, JString(v)) } - .toList - JObject(jsonFields) +abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { + def getServerSparkVersion: String = sparkVersion.toString + def setServerSparkVersion(s: String): this.type = setField(sparkVersion, s) + override def getSparkVersion: String = getServerSparkVersion + override def setSparkVersion(s: String) = setServerSparkVersion(s) + override def validate(): Unit = { + super.validate() + assertFieldIsSet(sparkVersion, "server_spark_version") } } -private[spark] object SubmitRestProtocolMessage { - import SubmitRestProtocolField._ - import SubmitRestProtocolAction._ +object SubmitRestProtocolMessage { + private val mapper = new ObjectMapper - /** - * Construct a SubmitRestProtocolMessage from its JSON representation. - * This uses the ACTION field to determine the type of the message to reconstruct. - * If such a field does not exist, throw an exception. 
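// Note, illustration only: the shared sparkVersion field above serializes
// under a direction-specific key. Requests emit "client_spark_version" and
// responses emit "server_spark_version" (see the JSON fixtures in the test
// suite below), while the generic getSparkVersion accessor is @JsonIgnore'd.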
- */ def fromJson(json: String): SubmitRestProtocolMessage = { - val jsonObject = parse(json).asInstanceOf[JObject] - val action = getAction(jsonObject).getOrElse { - throw new IllegalArgumentException(s"ACTION not found in message:\n$json") + val fields = parse(json).asInstanceOf[JObject].obj + val action = fields + .find { case (f, _) => f == "action" } + .map { case (_, v) => v.asInstanceOf[JString].s } + .getOrElse { + throw new IllegalArgumentException(s"Could not find action field in message:\n$json") + } + val clazz = SubmitRestProtocolAction.fromString(action) match { + case SUBMIT_DRIVER_REQUEST => classOf[SubmitDriverRequest] + case SUBMIT_DRIVER_RESPONSE => classOf[SubmitDriverResponse] + case KILL_DRIVER_REQUEST => classOf[KillDriverRequest] + case KILL_DRIVER_RESPONSE => classOf[KillDriverResponse] + case DRIVER_STATUS_REQUEST => classOf[DriverStatusRequest] + case DRIVER_STATUS_RESPONSE => classOf[DriverStatusResponse] + case ERROR => classOf[ErrorResponse] } - SubmitRestProtocolAction.fromString(action) match { - case SUBMIT_DRIVER_REQUEST => SubmitDriverRequestMessage.fromJsonObject(jsonObject) - case SUBMIT_DRIVER_RESPONSE => SubmitDriverResponseMessage.fromJsonObject(jsonObject) - case KILL_DRIVER_REQUEST => KillDriverRequestMessage.fromJsonObject(jsonObject) - case KILL_DRIVER_RESPONSE => KillDriverResponseMessage.fromJsonObject(jsonObject) - case DRIVER_STATUS_REQUEST => DriverStatusRequestMessage.fromJsonObject(jsonObject) - case DRIVER_STATUS_RESPONSE => DriverStatusResponseMessage.fromJsonObject(jsonObject) - case ERROR => ErrorMessage.fromJsonObject(jsonObject) + fromJson(json, clazz) + } + + def fromJson[T <: SubmitRestProtocolMessage](json: String, clazz: Class[T]): T = { + val fields = parse(json).asInstanceOf[JObject].obj + val processedFields = fields.map { case (k, v) => (underscoresToCamelCase(k), v) } + val processedJson = compact(render(JObject(processedFields))) + mapper.readValue(processedJson, clazz) + } + + private def camelCaseToUnderscores(s: String): String = { + val newString = new StringBuilder + s.foreach { c => + if (c.isUpper) { + newString.append("_" + c.toLower) + } else { + newString.append(c) + } } + newString.toString() } - /** - * Extract the value of the ACTION field in the JSON object. - */ - private def getAction(jsonObject: JObject): Option[String] = { - jsonObject.obj - .collect { case JField(k, JString(v)) if isActionField(k) => v } - .headOption + private def underscoresToCamelCase(s: String): String = { + val newString = new StringBuilder + var capitalizeNext = false + s.foreach { c => + if (c == '_') { + capitalizeNext = true + } else { + val nextChar = if (capitalizeNext) c.toUpper else c + newString.append(nextChar) + capitalizeNext = false + } + } + newString.toString() } } -/** - * Common methods used by companion objects of SubmitRestProtocolMessage's subclasses. - */ -private[spark] trait SubmitRestProtocolMessageCompanion[MessageType <: SubmitRestProtocolMessage] - extends Logging { - - import SubmitRestProtocolField._ - - /** Construct a new message of the relevant type. */ - protected def newMessage(): MessageType - - /** Return a field of the relevant type from the field's string representation. */ - protected def fieldFromString(field: String): SubmitRestProtocolField - - /** - * Populate the given field and value in the provided message. - * The default behavior only handles fields that have flat values and ignores other fields. 
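// Illustration only, not part of this patch: a standalone equivalent of the
// private key conversion above, showing how Jackson bean property names map
// onto the underscored wire format.
def toUnderscores(s: String): String =
  s.flatMap(c => if (c.isUpper) "_" + c.toLower else c.toString)
assert(toUnderscores("clientSparkVersion") == "client_spark_version")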
- * If the subclass uses fields with nested values, it should override this method appropriately. - */ - protected def handleField( - message: MessageType, - field: SubmitRestProtocolField, - value: JValue): Unit = { - value match { - case JString(s) => message.setField(field, s) - case _ => logWarning( - s"Unexpected value for field $field in message ${message.className}:\n$value") +object SubmitRestProtocolRequest { + def fromJson(s: String): SubmitRestProtocolRequest = { + SubmitRestProtocolMessage.fromJson(s) match { + case req: SubmitRestProtocolRequest => req + case res: SubmitRestProtocolResponse => + throw new IllegalArgumentException(s"Message was not a request:\n$s") } } +} - /** Construct a SubmitRestProtocolMessage from the given JSON object. */ - def fromJsonObject(jsonObject: JObject): MessageType = { - val message = newMessage() - val fields = jsonObject.obj - .map { case JField(k, v) => (k, v) } - // The ACTION field is already handled on instantiation - .filter { case (k, _) => !isActionField(k) } - .flatMap { case (k, v) => - try { - Some((fieldFromString(k), v)) - } catch { - case e: IllegalArgumentException => - logWarning(s"Unexpected field $k in message ${Utils.getFormattedClassName(this)}") - None - } - } - fields.foreach { case (k, v) => handleField(message, k, v) } - message +object SubmitRestProtocolResponse { + def fromJson(s: String): SubmitRestProtocolResponse = { + SubmitRestProtocolMessage.fromJson(s) match { + case req: SubmitRestProtocolRequest => + throw new IllegalArgumentException(s"Message was not a response:\n$s") + case res: SubmitRestProtocolResponse => res + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala index 9bc7220eb19d..89a2b83d2cde 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala @@ -70,9 +70,9 @@ private[spark] abstract class SubmitRestServer(host: String, requestedPort: Int, * This represents the main handler used in the SubmitRestServer. */ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler with Logging { - protected def handleSubmit(request: SubmitDriverRequestMessage): SubmitDriverResponseMessage - protected def handleKill(request: KillDriverRequestMessage): KillDriverResponseMessage - protected def handleStatus(request: DriverStatusRequestMessage): DriverStatusResponseMessage + protected def handleSubmit(request: SubmitDriverRequest): SubmitDriverResponse + protected def handleKill(request: KillDriverRequest): KillDriverResponse + protected def handleStatus(request: DriverStatusRequest): DriverStatusResponse /** * Handle a request submitted by the SubmitRestClient. @@ -85,7 +85,7 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi response: HttpServletResponse): Unit = { try { val requestMessageJson = Source.fromInputStream(request.getInputStream).mkString - val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) + val requestMessage = SubmitRestProtocolRequest.fromJson(requestMessageJson) val responseMessage = constructResponseMessage(requestMessage) response.setContentType("application/json") response.setCharacterEncoding("utf-8") @@ -105,7 +105,7 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi * If an IllegalArgumentException is thrown in the process, construct an error message instead. 
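// Illustration only, not part of this patch: server-side parsing in one step
// using the companion objects above. The action field selects the concrete
// class, then Jackson binds the remaining underscored keys; the payload below
// is a placeholder.
val incomingJson =
  """{"action" : "KILL_DRIVER_REQUEST", "client_spark_version" : "1.2.3", "driver_id" : "driver_123"}"""
val incoming = SubmitRestProtocolRequest.fromJson(incomingJson)
assert(incoming.isInstanceOf[KillDriverRequest])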
*/ private def constructResponseMessage( - request: SubmitRestProtocolMessage): SubmitRestProtocolMessage = { + request: SubmitRestProtocolRequest): SubmitRestProtocolResponse = { // Validate the request message to ensure that it is correctly constructed. If the request // is sent via the SubmitRestClient, it should have already been validated remotely. In case // this is not true, do it again here to guard against potential NPEs. If validation fails, @@ -114,9 +114,9 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi try { request.validate() request match { - case submit: SubmitDriverRequestMessage => handleSubmit(submit) - case kill: KillDriverRequestMessage => handleKill(kill) - case status: DriverStatusRequestMessage => handleStatus(status) + case submit: SubmitDriverRequest => handleSubmit(submit) + case kill: KillDriverRequest => handleKill(kill) + case status: DriverStatusRequest => handleStatus(status) case unexpected => handleError( s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") } @@ -130,13 +130,13 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi } catch { case e: IllegalArgumentException => handleError(s"Internal server error: ${e.getMessage}") } + response } /** Construct an error message to signal the fact that an exception has been thrown. */ - private def handleError(message: String): ErrorMessage = { - import ErrorField._ - new ErrorMessage() - .setField(SERVER_SPARK_VERSION, sparkVersion) - .setField(MESSAGE, message) + private def handleError(message: String): ErrorResponse = { + new ErrorResponse() + .setSparkVersion(sparkVersion) + .setMessage(message) } } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index f896b5072e4f..0a1fba3bad75 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -389,6 +389,10 @@ private[spark] object JsonProtocol { * Util JSON serialization methods | * ------------------------------- */ + def arrayToJson(a: Array[String]): JValue = { + JArray(a.toList.map(JString)) + } + def mapToJson(m: Map[String, String]): JValue = { val jsonFields = m.map { case (k, v) => JField(k, JString(v)) } JObject(jsonFields.toList) @@ -795,6 +799,11 @@ private[spark] object JsonProtocol { * Util JSON deserialization methods | * --------------------------------- */ + def arrayFromJson(json: JValue): Array[String] = { + val values = json.asInstanceOf[JArray].arr + values.toArray.map(_.asInstanceOf[JString].s) + } + def mapFromJson(json: JValue): Map[String, String] = { val jsonFields = json.asInstanceOf[JObject].obj jsonFields.map { case JField(k, JString(v)) => (k, v) }.toMap diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala index d7dc8234d57a..11e49077d893 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala @@ -60,7 +60,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B test("kill empty driver") { val killResponse = client.killDriver(masterRestUrl, "driver-that-does-not-exist") - val killSuccess = killResponse.getFieldNotNull(KillDriverResponseField.SUCCESS) + val killSuccess = 
killResponse.getSuccess assert(killSuccess === "false") } @@ -70,11 +70,11 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B val size = 500 val driverId = submitApplication(resultsFile, numbers, size) val killResponse = client.killDriver(masterRestUrl, driverId) - val killSuccess = killResponse.getFieldNotNull(KillDriverResponseField.SUCCESS) + val killSuccess = killResponse.getSuccess waitUntilFinished(driverId) val statusResponse = client.requestDriverStatus(masterRestUrl, driverId) - val statusSuccess = statusResponse.getFieldNotNull(DriverStatusResponseField.SUCCESS) - val driverState = statusResponse.getFieldNotNull(DriverStatusResponseField.DRIVER_STATE) + val statusSuccess = statusResponse.getSuccess + val driverState = statusResponse.getDriverState assert(killSuccess === "true") assert(statusSuccess === "true") assert(driverState === DriverState.KILLED.toString) @@ -83,7 +83,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B test("request status for empty driver") { val statusResponse = client.requestDriverStatus(masterRestUrl, "driver-that-does-not-exist") - val statusSuccess = statusResponse.getFieldNotNull(DriverStatusResponseField.SUCCESS) + val statusSuccess = statusResponse.getSuccess assert(statusSuccess === "false") } @@ -125,7 +125,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B val args = new SparkSubmitArguments(commandLineArgs) SparkSubmit.prepareSubmitEnvironment(args) val submitResponse = client.submitDriver(args) - submitResponse.getFieldNotNull(SubmitDriverResponseField.DRIVER_ID) + submitResponse.getDriverId } /** Wait until the given driver has finished running up to the specified timeout. */ @@ -134,7 +134,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B val expireTime = System.currentTimeMillis + maxSeconds * 1000 while (!finished) { val statusResponse = client.requestDriverStatus(masterRestUrl, driverId) - val driverState = statusResponse.getFieldNotNull(DriverStatusResponseField.DRIVER_STATE) + val driverState = statusResponse.getDriverState finished = driverState != DriverState.SUBMITTED.toString && driverState != DriverState.RUNNING.toString diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 18091e98c0b2..a7468a02dfe8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -17,36 +17,37 @@ package org.apache.spark.deploy.rest -import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ import org.scalatest.FunSuite -/** - * Dummy fields and messages for testing. 
- */ -private abstract class DummyField extends SubmitRestProtocolField -private object DummyField extends SubmitRestProtocolFieldCompanion[DummyField] { - case object ACTION extends DummyField with ActionField - case object DUMMY_FIELD extends DummyField - case object BOOLEAN_FIELD extends DummyField with BooleanField - case object MEMORY_FIELD extends DummyField with MemoryField - case object NUMERIC_FIELD extends DummyField with NumericField - case object REQUIRED_FIELD extends DummyField - override val requiredFields = Seq(ACTION, REQUIRED_FIELD) - override val optionalFields = Seq(DUMMY_FIELD, BOOLEAN_FIELD, MEMORY_FIELD, NUMERIC_FIELD) -} -private object DUMMY_ACTION extends SubmitRestProtocolAction { - override def toString: String = "DUMMY_ACTION" -} -private class DummyMessage extends SubmitRestProtocolMessage( - DUMMY_ACTION, - DummyField.ACTION, - DummyField.requiredFields) -private object DummyMessage extends SubmitRestProtocolMessageCompanion[DummyMessage] { - protected override def newMessage() = new DummyMessage - protected override def fieldFromString(f: String) = DummyField.fromString(f) +case object DUMMY_REQUEST extends SubmitRestProtocolAction +case object DUMMY_RESPONSE extends SubmitRestProtocolAction + +class DummyRequest extends SubmitRestProtocolRequest { + protected override val action = DUMMY_REQUEST + private val active = new SubmitRestProtocolField[Boolean] + private val age = new SubmitRestProtocolField[Int] + private val name = new SubmitRestProtocolField[String] + + def getActive: String = active.toString + def getAge: String = age.toString + def getName: String = name.toString + + def setActive(s: String): this.type = setBooleanField(active, s) + def setAge(s: String): this.type = setNumericField(age, s) + def setName(s: String): this.type = setField(name, s) + + override def validate(): Unit = { + super.validate() + assertFieldIsSet(name, "name") + assertFieldIsSet(age, "age") + assert(age.getValue > 5, "Not old enough!") + } } +class DummyResponse extends SubmitRestProtocolResponse { + protected override val action = DUMMY_RESPONSE +} /** * Tests for the stable application submission REST protocol. 
@@ -65,110 +66,123 @@ class SubmitRestProtocolSuite extends FunSuite { } test("get and set fields") { - import DummyField._ - val message = new DummyMessage - // action field is already set on instantiation - assert(message.getFields.size === 1) - assert(message.getField(ACTION) === DUMMY_ACTION.toString) - // required field not set yet - intercept[IllegalArgumentException] { message.validate() } - intercept[IllegalArgumentException] { message.getFieldNotNull(DUMMY_FIELD) } - intercept[IllegalArgumentException] { message.getFieldNotNull(REQUIRED_FIELD) } - message.setField(DUMMY_FIELD, "dummy value") - message.setField(BOOLEAN_FIELD, "true") - message.setField(MEMORY_FIELD, "401k") - message.setField(NUMERIC_FIELD, "401") - message.setFieldIfNotNull(REQUIRED_FIELD, null) // no-op because value is null - assert(message.getFields.size === 5) - // required field still not set - intercept[IllegalArgumentException] { message.validate() } - intercept[IllegalArgumentException] { message.getFieldNotNull(REQUIRED_FIELD) } - message.setFieldIfNotNull(REQUIRED_FIELD, "dummy value") - // all required fields are now set - assert(message.getFields.size === 6) - assert(message.getField(DUMMY_FIELD) === "dummy value") - assert(message.getField(BOOLEAN_FIELD) === "true") - assert(message.getField(MEMORY_FIELD) === "401k") - assert(message.getField(NUMERIC_FIELD) === "401") - assert(message.getField(REQUIRED_FIELD) === "dummy value") - message.validate() - // bad field values - intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } - intercept[IllegalArgumentException] { message.setField(BOOLEAN_FIELD, "not T nor F") } - intercept[IllegalArgumentException] { message.setField(MEMORY_FIELD, "not memory") } - intercept[IllegalArgumentException] { message.setField(NUMERIC_FIELD, "not a number") } + val request = new DummyRequest + assert(request.getSparkVersion === null) + assert(request.getMessage === null) + assert(request.getActive === null) + assert(request.getAge === null) + assert(request.getName === null) + request.setSparkVersion("1.2.3") + request.setActive("true") + request.setAge("10") + request.setName("dolphin") + assert(request.getSparkVersion === "1.2.3") + assert(request.getMessage === null) + assert(request.getActive === "true") + assert(request.getAge === "10") + assert(request.getName === "dolphin") + // overwrite + request.setName("shark") + request.setActive("false") + assert(request.getName === "shark") + assert(request.getActive === "false") + } + + test("get and set fields with null values") { + val request = new DummyRequest + request.setSparkVersion(null) + request.setActive(null) + request.setAge(null) + request.setName(null) + request.setMessage(null) + assert(request.getSparkVersion === null) + assert(request.getMessage === null) + assert(request.getActive === null) + assert(request.getAge === null) + assert(request.getName === null) + } + + test("set fields with illegal argument") { + val request = new DummyRequest + intercept[IllegalArgumentException] { request.setActive("not-a-boolean") } + intercept[IllegalArgumentException] { request.setActive("150") } + intercept[IllegalArgumentException] { request.setAge("not-a-number") } + intercept[IllegalArgumentException] { request.setAge("true") } } - test("to and from JSON") { - import DummyField._ - val message = new DummyMessage() - .setField(DUMMY_FIELD, "dummy value") - .setField(BOOLEAN_FIELD, "true") - .setField(MEMORY_FIELD, "401k") - .setField(NUMERIC_FIELD, "401") - .setField(REQUIRED_FIELD, "dummy value") 
- .validate() - val expectedJson = - """ - |{ - | "ACTION" : "DUMMY_ACTION", - | "DUMMY_FIELD" : "dummy value", - | "BOOLEAN_FIELD" : "true", - | "MEMORY_FIELD" : "401k", - | "NUMERIC_FIELD" : "401", - | "REQUIRED_FIELD" : "dummy value" - |} - """.stripMargin - val actualJson = message.toJson - assertJsonEquals(actualJson, expectedJson) - // Do not use SubmitRestProtocolMessage.fromJson here - // because DUMMY_ACTION is not a known action - val jsonObject = parse(expectedJson).asInstanceOf[JObject] - val newMessage = DummyMessage.fromJsonObject(jsonObject) - assert(newMessage.getFieldNotNull(ACTION) === "DUMMY_ACTION") - assert(newMessage.getFieldNotNull(DUMMY_FIELD) === "dummy value") - assert(newMessage.getFieldNotNull(BOOLEAN_FIELD) === "true") - assert(newMessage.getFieldNotNull(MEMORY_FIELD) === "401k") - assert(newMessage.getFieldNotNull(NUMERIC_FIELD) === "401") - assert(newMessage.getFieldNotNull(REQUIRED_FIELD) === "dummy value") - assert(newMessage.getFields.size === 6) + test("validate") { + val request = new DummyRequest + intercept[AssertionError] { request.validate() } // missing everything + request.setSparkVersion("1.4.8") + intercept[AssertionError] { request.validate() } // missing name and age + request.setName("something") + intercept[AssertionError] { request.validate() } // missing only age + request.setAge("2") + intercept[AssertionError] { request.validate() } // age too low + request.setAge("10") + request.validate() // everything is set + request.setSparkVersion(null) + intercept[AssertionError] { request.validate() } // missing only Spark version + request.setSparkVersion("1.2.3") + request.setName(null) + intercept[AssertionError] { request.validate() } // missing only name + request.setMessage("not-setting-name") + intercept[AssertionError] { request.validate() } // still missing name } - test("SubmitDriverRequestMessage") { - import SubmitDriverRequestField._ - val message = new SubmitDriverRequestMessage - intercept[IllegalArgumentException] { message.validate() } - message.setField(CLIENT_SPARK_VERSION, "1.2.3") - message.setField(MESSAGE, "Submitting them drivers.") - message.setField(APP_NAME, "SparkPie") - message.setField(APP_RESOURCE, "honey-walnut-cherry.jar") - // all required fields are now set + test("request to and from JSON") { + val request = new DummyRequest() + .setSparkVersion("1.2.3") + .setActive("true") + .setAge("25") + .setName("jung") + val json = request.toJson + assertJsonEquals(json, dummyRequestJson) + val newRequest = SubmitRestProtocolMessage.fromJson(json, classOf[DummyRequest]) + assert(newRequest.getSparkVersion === "1.2.3") + assert(newRequest.getClientSparkVersion === "1.2.3") + assert(newRequest.getActive === "true") + assert(newRequest.getAge === "25") + assert(newRequest.getName === "jung") + assert(newRequest.getMessage === null) + } + + test("response to and from JSON") { + val response = new DummyResponse().setSparkVersion("3.3.4") + val json = response.toJson + assertJsonEquals(json, dummyResponseJson) + val newResponse = SubmitRestProtocolMessage.fromJson(json, classOf[DummyResponse]) + assert(newResponse.getSparkVersion === "3.3.4") + assert(newResponse.getServerSparkVersion === "3.3.4") + assert(newResponse.getMessage === null) + } + + test("SubmitDriverRequest") { + val message = new SubmitDriverRequest + intercept[AssertionError] { message.validate() } + intercept[IllegalArgumentException] { message.setDriverCores("one hundred feet") } + intercept[IllegalArgumentException] { message.setSuperviseDriver("nope, 
never") } + intercept[IllegalArgumentException] { message.setTotalExecutorCores("two men") } + message.setSparkVersion("1.2.3") + message.setAppName("SparkPie") + message.setAppResource("honey-walnut-cherry.jar") message.validate() - message.setField(MAIN_CLASS, "org.apache.spark.examples.SparkPie") - message.setField(JARS, "mayonnaise.jar,ketchup.jar") - message.setField(FILES, "fireball.png") - message.setField(PY_FILES, "do-not-eat-my.py") - message.setField(DRIVER_MEMORY, "512m") - message.setField(DRIVER_CORES, "180") - message.setField(DRIVER_EXTRA_JAVA_OPTIONS, " -Dslices=5 -Dcolor=mostly_red") - message.setField(DRIVER_EXTRA_CLASS_PATH, "food-coloring.jar") - message.setField(DRIVER_EXTRA_LIBRARY_PATH, "pickle.jar") - message.setField(SUPERVISE_DRIVER, "false") - message.setField(EXECUTOR_MEMORY, "256m") - message.setField(TOTAL_EXECUTOR_CORES, "10000") - // bad field values - intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } - intercept[IllegalArgumentException] { message.setField(DRIVER_MEMORY, "more than expected") } - intercept[IllegalArgumentException] { message.setField(DRIVER_CORES, "one hundred feet") } - intercept[IllegalArgumentException] { message.setField(SUPERVISE_DRIVER, "nope, never") } - intercept[IllegalArgumentException] { message.setField(EXECUTOR_MEMORY, "less than expected") } - intercept[IllegalArgumentException] { message.setField(TOTAL_EXECUTOR_CORES, "two men") } - intercept[IllegalArgumentException] { message.setField(APP_ARGS, "anything") } - intercept[IllegalArgumentException] { message.setField(SPARK_PROPERTIES, "anything") } - intercept[IllegalArgumentException] { message.setField(ENVIRONMENT_VARIABLES, "anything") } + // optional fields + message.setMainClass("org.apache.spark.examples.SparkPie") + message.setJars("mayonnaise.jar,ketchup.jar") + message.setFiles("fireball.png") + message.setPyFiles("do-not-eat-my.py") + message.setDriverMemory("512m") + message.setDriverCores("180") + message.setDriverExtraJavaOptions(" -Dslices=5 -Dcolor=mostly_red") + message.setDriverExtraClassPath("food-coloring.jar") + message.setDriverExtraLibraryPath("pickle.jar") + message.setSuperviseDriver("false") + message.setExecutorMemory("256m") + message.setTotalExecutorCores("10000") // special fields - message.appendAppArg("two slices") - message.appendAppArg("a hint of cinnamon") + message.addAppArg("two slices") + message.addAppArg("a hint of cinnamon") message.setSparkProperty("spark.live.long", "true") message.setSparkProperty("spark.shuffle.enabled", "false") message.setEnvironmentVariable("PATH", "/dev/null") @@ -181,231 +195,234 @@ class SubmitRestProtocolSuite extends FunSuite { assert(message.getEnvironmentVariables("PATH") === "/dev/null") assert(message.getEnvironmentVariables("PYTHONPATH") === "/dev/null") // test JSON - val expectedJson = submitDriverRequestJson - assertJsonEquals(message.toJson, expectedJson) - val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) - .asInstanceOf[SubmitDriverRequestMessage] - assert(newMessage.getFields === message.getFields) + val json = message.toJson + assertJsonEquals(json, submitDriverRequestJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[SubmitDriverRequest]) + assert(newMessage.getSparkVersion === "1.2.3") + assert(newMessage.getClientSparkVersion === "1.2.3") + assert(newMessage.getAppName === "SparkPie") + assert(newMessage.getAppResource === "honey-walnut-cherry.jar") + assert(newMessage.getMainClass === "org.apache.spark.examples.SparkPie") + 
assert(newMessage.getJars === "mayonnaise.jar,ketchup.jar") + assert(newMessage.getFiles === "fireball.png") + assert(newMessage.getPyFiles === "do-not-eat-my.py") + assert(newMessage.getDriverMemory === "512m") + assert(newMessage.getDriverCores === "180") + assert(newMessage.getDriverExtraJavaOptions === " -Dslices=5 -Dcolor=mostly_red") + assert(newMessage.getDriverExtraClassPath === "food-coloring.jar") + assert(newMessage.getDriverExtraLibraryPath === "pickle.jar") + assert(newMessage.getSuperviseDriver === "false") + assert(newMessage.getExecutorMemory === "256m") + assert(newMessage.getTotalExecutorCores === "10000") assert(newMessage.getAppArgs === message.getAppArgs) assert(newMessage.getSparkProperties === message.getSparkProperties) assert(newMessage.getEnvironmentVariables === message.getEnvironmentVariables) } - test("SubmitDriverResponseMessage") { - import SubmitDriverResponseField._ - val message = new SubmitDriverResponseMessage - intercept[IllegalArgumentException] { message.validate() } - message.setField(SERVER_SPARK_VERSION, "1.2.3") - message.setField(MESSAGE, "Dem driver is now submitted.") - message.setField(DRIVER_ID, "driver_123") - message.setField(SUCCESS, "true") - // all required fields are now set + test("SubmitDriverResponse") { + val message = new SubmitDriverResponse + intercept[AssertionError] { message.validate() } + intercept[IllegalArgumentException] { message.setSuccess("maybe not") } + message.setSparkVersion("1.2.3") + message.setDriverId("driver_123") + message.setSuccess("true") message.validate() - // bad field values - intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } - intercept[IllegalArgumentException] { message.setField(SUCCESS, "maybe not") } // test JSON - val expectedJson = submitDriverResponseJson - val actualJson = message.toJson - assertJsonEquals(actualJson, expectedJson) - val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) - assert(newMessage.isInstanceOf[SubmitDriverResponseMessage]) - assert(newMessage.getFields === message.getFields) + val json = message.toJson + assertJsonEquals(json, submitDriverResponseJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[SubmitDriverResponse]) + assert(newMessage.getSparkVersion === "1.2.3") + assert(newMessage.getServerSparkVersion === "1.2.3") + assert(newMessage.getDriverId === "driver_123") + assert(newMessage.getSuccess === "true") } - test("KillDriverRequestMessage") { - import KillDriverRequestField._ - val message = new KillDriverRequestMessage - intercept[IllegalArgumentException] { message.validate() } - intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } - message.setField(CLIENT_SPARK_VERSION, "1.2.3") - message.setField(DRIVER_ID, "driver_123") - // all required fields are now set + test("KillDriverRequest") { + val message = new KillDriverRequest + intercept[AssertionError] { message.validate() } + message.setSparkVersion("1.2.3") + message.setDriverId("driver_123") message.validate() // test JSON - val expectedJson = killDriverRequestJson - val actualJson = message.toJson - assertJsonEquals(actualJson, expectedJson) - val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) - assert(newMessage.isInstanceOf[KillDriverRequestMessage]) - assert(newMessage.getFields === message.getFields) + val json = message.toJson + assertJsonEquals(json, killDriverRequestJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[KillDriverRequest]) + assert(newMessage.getSparkVersion === 
"1.2.3") + assert(newMessage.getClientSparkVersion === "1.2.3") + assert(newMessage.getDriverId === "driver_123") } - test("KillDriverResponseMessage") { - import KillDriverResponseField._ - val message = new KillDriverResponseMessage - intercept[IllegalArgumentException] { message.validate() } - message.setField(SERVER_SPARK_VERSION, "1.2.3") - message.setField(DRIVER_ID, "driver_123") - message.setField(SUCCESS, "true") - // all required fields are now set + test("KillDriverResponse") { + val message = new KillDriverResponse + intercept[AssertionError] { message.validate() } + intercept[IllegalArgumentException] { message.setSuccess("maybe not") } + message.setSparkVersion("1.2.3") + message.setDriverId("driver_123") + message.setSuccess("true") message.validate() - message.setField(MESSAGE, "Killing dem reckless drivers.") - // bad field values - intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } - intercept[IllegalArgumentException] { message.setField(SUCCESS, "maybe?") } // test JSON - val expectedJson = killDriverResponseJson - val actualJson = message.toJson - assertJsonEquals(actualJson, expectedJson) - val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) - assert(newMessage.isInstanceOf[KillDriverResponseMessage]) - assert(newMessage.getFields === message.getFields) + val json = message.toJson + assertJsonEquals(json, killDriverResponseJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[KillDriverResponse]) + assert(newMessage.getSparkVersion === "1.2.3") + assert(newMessage.getServerSparkVersion === "1.2.3") + assert(newMessage.getDriverId === "driver_123") + assert(newMessage.getSuccess === "true") } - test("DriverStatusRequestMessage") { - import DriverStatusRequestField._ - val message = new DriverStatusRequestMessage - intercept[IllegalArgumentException] { message.validate() } - intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } - message.setField(CLIENT_SPARK_VERSION, "1.2.3") - message.setField(DRIVER_ID, "driver_123") - // all required fields are now set + test("DriverStatusRequest") { + val message = new DriverStatusRequest + intercept[AssertionError] { message.validate() } + message.setSparkVersion("1.2.3") + message.setDriverId("driver_123") message.validate() // test JSON - val expectedJson = driverStatusRequestJson - val actualJson = message.toJson - assertJsonEquals(actualJson, expectedJson) - val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) - assert(newMessage.isInstanceOf[DriverStatusRequestMessage]) - assert(newMessage.getFields === message.getFields) + val json = message.toJson + assertJsonEquals(json, driverStatusRequestJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[DriverStatusRequest]) + assert(newMessage.getSparkVersion === "1.2.3") + assert(newMessage.getClientSparkVersion === "1.2.3") + assert(newMessage.getDriverId === "driver_123") } - test("DriverStatusResponseMessage") { - import DriverStatusResponseField._ - val message = new DriverStatusResponseMessage - intercept[IllegalArgumentException] { message.validate() } - message.setField(SERVER_SPARK_VERSION, "1.2.3") - message.setField(DRIVER_ID, "driver_123") - message.setField(SUCCESS, "true") - // all required fields are now set + test("DriverStatusResponse") { + val message = new DriverStatusResponse + intercept[AssertionError] { message.validate() } + intercept[IllegalArgumentException] { message.setSuccess("maybe") } + message.setSparkVersion("1.2.3") + 
message.setDriverId("driver_123") + message.setSuccess("true") message.validate() - message.setField(MESSAGE, "Your driver is having some trouble...") - message.setField(DRIVER_STATE, "RUNNING") - message.setField(WORKER_ID, "worker_123") - message.setField(WORKER_HOST_PORT, "1.2.3.4:7780") - // bad field values - intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } - intercept[IllegalArgumentException] { message.setField(SUCCESS, "maybe") } + // optional fields + message.setDriverState("RUNNING") + message.setWorkerId("worker_123") + message.setWorkerHostPort("1.2.3.4:7780") // test JSON - val expectedJson = driverStatusResponseJson - val actualJson = message.toJson - assertJsonEquals(actualJson, expectedJson) - val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) - assert(newMessage.isInstanceOf[DriverStatusResponseMessage]) - assert(newMessage.getFields === message.getFields) + val json = message.toJson + assertJsonEquals(json, driverStatusResponseJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[DriverStatusResponse]) + assert(newMessage.getSparkVersion === "1.2.3") + assert(newMessage.getServerSparkVersion === "1.2.3") + assert(newMessage.getDriverId === "driver_123") + assert(newMessage.getSuccess === "true") } - test("ErrorMessage") { - import ErrorField._ - val message = new ErrorMessage - intercept[IllegalArgumentException] { message.validate() } - intercept[IllegalArgumentException] { message.setField(ACTION, "anything") } - message.setField(SERVER_SPARK_VERSION, "1.2.3") - message.setField(MESSAGE, "Your wife threw an exception!") - // all required fields are now set + test("ErrorResponse") { + val message = new ErrorResponse + intercept[AssertionError] { message.validate() } + message.setSparkVersion("1.2.3") + message.setMessage("Field not found in submit request: X") message.validate() // test JSON - val expectedJson = errorJson - val actualJson = message.toJson - assertJsonEquals(actualJson, expectedJson) - val newMessage = SubmitRestProtocolMessage.fromJson(expectedJson) - assert(newMessage.isInstanceOf[ErrorMessage]) - assert(newMessage.getFields === message.getFields) + val json = message.toJson + assertJsonEquals(json, errorJson) + val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[ErrorResponse]) + assert(newMessage.getSparkVersion === "1.2.3") + assert(newMessage.getServerSparkVersion === "1.2.3") + assert(newMessage.getMessage === "Field not found in submit request: X") } + private val dummyRequestJson = + """ + |{ + | "action" : "DUMMY_REQUEST", + | "active" : "true", + | "age" : "25", + | "client_spark_version" : "1.2.3", + | "name" : "jung" + |} + """.stripMargin + + private val dummyResponseJson = + """ + |{ + | "action" : "DUMMY_RESPONSE", + | "server_spark_version" : "3.3.4" + |} + """.stripMargin + private val submitDriverRequestJson = """ |{ - | "ACTION" : "SUBMIT_DRIVER_REQUEST", - | "CLIENT_SPARK_VERSION" : "1.2.3", - | "MESSAGE" : "Submitting them drivers.", - | "APP_NAME" : "SparkPie", - | "APP_RESOURCE" : "honey-walnut-cherry.jar", - | "MAIN_CLASS" : "org.apache.spark.examples.SparkPie", - | "JARS" : "mayonnaise.jar,ketchup.jar", - | "FILES" : "fireball.png", - | "PY_FILES" : "do-not-eat-my.py", - | "DRIVER_MEMORY" : "512m", - | "DRIVER_CORES" : "180", - | "DRIVER_EXTRA_JAVA_OPTIONS" : " -Dslices=5 -Dcolor=mostly_red", - | "DRIVER_EXTRA_CLASS_PATH" : "food-coloring.jar", - | "DRIVER_EXTRA_LIBRARY_PATH" : "pickle.jar", - | "SUPERVISE_DRIVER" : "false", - | "EXECUTOR_MEMORY" : 
"256m", - | "TOTAL_EXECUTOR_CORES" : "10000", - | "APP_ARGS" : [ "two slices", "a hint of cinnamon" ], - | "SPARK_PROPERTIES" : { - | "spark.live.long" : "true", - | "spark.shuffle.enabled" : "false" - | }, - | "ENVIRONMENT_VARIABLES" : { - | "PATH" : "/dev/null", - | "PYTHONPATH" : "/dev/null" - | } + | "action" : "SUBMIT_DRIVER_REQUEST", + | "app_args" : "[\"two slices\",\"a hint of cinnamon\"]", + | "app_name" : "SparkPie", + | "app_resource" : "honey-walnut-cherry.jar", + | "client_spark_version" : "1.2.3", + | "driver_cores" : "180", + | "driver_extra_class_path" : "food-coloring.jar", + | "driver_extra_java_options" : " -Dslices=5 -Dcolor=mostly_red", + | "driver_extra_library_path" : "pickle.jar", + | "driver_memory" : "512m", + | "environment_variables" : "{\"PATH\":\"/dev/null\",\"PYTHONPATH\":\"/dev/null\"}", + | "executor_memory" : "256m", + | "files" : "fireball.png", + | "jars" : "mayonnaise.jar,ketchup.jar", + | "main_class" : "org.apache.spark.examples.SparkPie", + | "py_files" : "do-not-eat-my.py", + | "spark_properties" : "{\"spark.live.long\":\"true\",\"spark.shuffle.enabled\":\"false\"}", + | "supervise_driver" : "false", + | "total_executor_cores" : "10000" |} """.stripMargin private val submitDriverResponseJson = """ |{ - | "ACTION" : "SUBMIT_DRIVER_RESPONSE", - | "SERVER_SPARK_VERSION" : "1.2.3", - | "MESSAGE" : "Dem driver is now submitted.", - | "DRIVER_ID" : "driver_123", - | "SUCCESS" : "true" + | "action" : "SUBMIT_DRIVER_RESPONSE", + | "driver_id" : "driver_123", + | "server_spark_version" : "1.2.3", + | "success" : "true" |} """.stripMargin private val killDriverRequestJson = """ |{ - | "ACTION" : "KILL_DRIVER_REQUEST", - | "CLIENT_SPARK_VERSION" : "1.2.3", - | "DRIVER_ID" : "driver_123" + | "action" : "KILL_DRIVER_REQUEST", + | "client_spark_version" : "1.2.3", + | "driver_id" : "driver_123" |} """.stripMargin private val killDriverResponseJson = """ |{ - | "ACTION" : "KILL_DRIVER_RESPONSE", - | "SERVER_SPARK_VERSION" : "1.2.3", - | "DRIVER_ID" : "driver_123", - | "SUCCESS" : "true", - | "MESSAGE" : "Killing dem reckless drivers." + | "action" : "KILL_DRIVER_RESPONSE", + | "driver_id" : "driver_123", + | "server_spark_version" : "1.2.3", + | "success" : "true" |} """.stripMargin private val driverStatusRequestJson = """ |{ - | "ACTION" : "DRIVER_STATUS_REQUEST", - | "CLIENT_SPARK_VERSION" : "1.2.3", - | "DRIVER_ID" : "driver_123" + | "action" : "DRIVER_STATUS_REQUEST", + | "client_spark_version" : "1.2.3", + | "driver_id" : "driver_123" |} """.stripMargin private val driverStatusResponseJson = """ |{ - | "ACTION" : "DRIVER_STATUS_RESPONSE", - | "SERVER_SPARK_VERSION" : "1.2.3", - | "DRIVER_ID" : "driver_123", - | "SUCCESS" : "true", - | "MESSAGE" : "Your driver is having some trouble...", - | "DRIVER_STATE" : "RUNNING", - | "WORKER_ID" : "worker_123", - | "WORKER_HOST_PORT" : "1.2.3.4:7780" + | "action" : "DRIVER_STATUS_RESPONSE", + | "driver_id" : "driver_123", + | "driver_state" : "RUNNING", + | "server_spark_version" : "1.2.3", + | "success" : "true", + | "worker_host_port" : "1.2.3.4:7780", + | "worker_id" : "worker_123" |} """.stripMargin private val errorJson = """ |{ - | "ACTION" : "ERROR", - | "SERVER_SPARK_VERSION" : "1.2.3", - | "MESSAGE" : "Your wife threw an exception!" 
+ | "action" : "ERROR", + | "message" : "Field not found in submit request: X", + | "server_spark_version" : "1.2.3" |} """.stripMargin } From 8d43486cf1a8e257409f1131bebcebccea66caa3 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 29 Jan 2015 11:49:04 -0800 Subject: [PATCH 19/48] Replace SubmitRestProtocolAction with class name This makes it easier for users to define their own messages. Rather than forcing the users to introduce an action class that extends our inflexible enum, they can now implement custom logic in their REST servers depending on the action of the request. --- .../deploy/rest/DriverStatusRequest.scala | 3 -- .../deploy/rest/DriverStatusResponse.scala | 1 - .../spark/deploy/rest/ErrorResponse.scala | 1 - .../spark/deploy/rest/KillDriverRequest.scala | 3 -- .../deploy/rest/KillDriverResponse.scala | 1 - .../deploy/rest/SubmitDriverRequest.scala | 9 ++--- .../deploy/rest/SubmitDriverResponse.scala | 1 - .../deploy/rest/SubmitRestProtocolField.scala | 24 ------------ .../rest/SubmitRestProtocolMessage.scala | 37 +++++++++++-------- .../deploy/rest/SubmitRestProtocolSuite.scala | 26 +++++-------- 10 files changed, 35 insertions(+), 71 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala index f5d4d95cebf1..bb73c61e68ba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala @@ -18,12 +18,9 @@ package org.apache.spark.deploy.rest class DriverStatusRequest extends SubmitRestProtocolRequest { - protected override val action = SubmitRestProtocolAction.DRIVER_STATUS_REQUEST private val driverId = new SubmitRestProtocolField[String] - def getDriverId: String = driverId.toString def setDriverId(s: String): this.type = setField(driverId, s) - override def validate(): Unit = { super.validate() assertFieldIsSet(driverId, "driver_id") diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala index 1e8090c33681..6da41d09b3f2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.rest class DriverStatusResponse extends SubmitRestProtocolResponse { - protected override val action = SubmitRestProtocolAction.DRIVER_STATUS_RESPONSE private val driverId = new SubmitRestProtocolField[String] private val success = new SubmitRestProtocolField[Boolean] private val driverState = new SubmitRestProtocolField[String] diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala index 8c30d3185088..0e08831e7b6a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.rest class ErrorResponse extends SubmitRestProtocolResponse { - protected override val action = SubmitRestProtocolAction.ERROR override def validate(): Unit = { super.validate() assertFieldIsSet(message, "message") diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala index 
c44c94d95a1f..31b127876c43 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala @@ -18,12 +18,9 @@ package org.apache.spark.deploy.rest class KillDriverRequest extends SubmitRestProtocolRequest { - protected override val action = SubmitRestProtocolAction.KILL_DRIVER_REQUEST private val driverId = new SubmitRestProtocolField[String] - def getDriverId: String = driverId.toString def setDriverId(s: String): this.type = setField(driverId, s) - override def validate(): Unit = { super.validate() assertFieldIsSet(driverId, "driver_id") diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala index e75a52bc9bf0..107b447c3d1c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.rest class KillDriverResponse extends SubmitRestProtocolResponse { - protected override val action = SubmitRestProtocolAction.KILL_DRIVER_RESPONSE private val driverId = new SubmitRestProtocolField[String] private val success = new SubmitRestProtocolField[Boolean] diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala index 9bde3345d03f..87132511587a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala @@ -26,7 +26,6 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.util.JsonProtocol class SubmitDriverRequest extends SubmitRestProtocolRequest { - protected override val action = SubmitRestProtocolAction.SUBMIT_DRIVER_REQUEST private val appName = new SubmitRestProtocolField[String] private val appResource = new SubmitRestProtocolField[String] private val mainClass = new SubmitRestProtocolField[String] @@ -62,7 +61,7 @@ class SubmitDriverRequest extends SubmitRestProtocolRequest { def getExecutorMemory: String = executorMemory.toString def getTotalExecutorCores: String = totalExecutorCores.toString - // Special getters required for JSON de/serialization + // Special getters required for JSON serialization @JsonProperty("appArgs") private def getAppArgsJson: String = arrayToJson(getAppArgs) @JsonProperty("sparkProperties") @@ -85,7 +84,7 @@ class SubmitDriverRequest extends SubmitRestProtocolRequest { def setExecutorMemory(s: String): this.type = setField(executorMemory, s) def setTotalExecutorCores(s: String): this.type = setNumericField(totalExecutorCores, s) - // Special setters required for JSON de/serialization + // Special setters required for JSON deserialization @JsonProperty("appArgs") private def setAppArgsJson(s: String): Unit = { appArgs.clear() @@ -116,11 +115,11 @@ class SubmitDriverRequest extends SubmitRestProtocolRequest { def setEnvironmentVariable(k: String, v: String): this.type = { envVars(k) = v; this } private def arrayToJson(arr: Array[String]): String = { - if (arr.nonEmpty) { compact(render(JsonProtocol.arrayToJson(arr))) } else { null } + if (arr.nonEmpty) { compact(render(JsonProtocol.arrayToJson(arr))) } else null } private def mapToJson(map: Map[String, String]): String = { - if (map.nonEmpty) { compact(render(JsonProtocol.mapToJson(map))) } else { null } + if (map.nonEmpty) { 
compact(render(JsonProtocol.mapToJson(map))) } else null } override def validate(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala index 8a1676767cec..b1825af8ce56 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.rest class SubmitDriverResponse extends SubmitRestProtocolResponse { - protected override val action = SubmitRestProtocolAction.SUBMIT_DRIVER_RESPONSE private val success = new SubmitRestProtocolField[Boolean] private val driverId = new SubmitRestProtocolField[String] diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala index 4c0c45b450fa..33e4fe4d5c2b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala @@ -17,30 +17,6 @@ package org.apache.spark.deploy.rest -/** - * All possible values of the ACTION field in a SubmitRestProtocolMessage. - */ -abstract class SubmitRestProtocolAction -object SubmitRestProtocolAction { - case object SUBMIT_DRIVER_REQUEST extends SubmitRestProtocolAction - case object SUBMIT_DRIVER_RESPONSE extends SubmitRestProtocolAction - case object KILL_DRIVER_REQUEST extends SubmitRestProtocolAction - case object KILL_DRIVER_RESPONSE extends SubmitRestProtocolAction - case object DRIVER_STATUS_REQUEST extends SubmitRestProtocolAction - case object DRIVER_STATUS_RESPONSE extends SubmitRestProtocolAction - case object ERROR extends SubmitRestProtocolAction - private val allActions = - Seq(SUBMIT_DRIVER_REQUEST, SUBMIT_DRIVER_RESPONSE, KILL_DRIVER_REQUEST, - KILL_DRIVER_RESPONSE, DRIVER_STATUS_REQUEST, DRIVER_STATUS_RESPONSE, ERROR) - private val allActionsMap = allActions.map { a => (a.toString, a) }.toMap - - def fromString(action: String): SubmitRestProtocolAction = { - allActionsMap.get(action).getOrElse { - throw new IllegalArgumentException(s"Unknown action $action") - } - } -} - class SubmitRestProtocolField[T] { protected var value: Option[T] = None def isSet: Boolean = value.isDefined diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index 0b2085b5e3bf..0aa72d236e1a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -25,7 +25,6 @@ import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.util.Utils -import org.apache.spark.deploy.rest.SubmitRestProtocolAction._ @JsonInclude(Include.NON_NULL) @JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY) @@ -34,12 +33,12 @@ abstract class SubmitRestProtocolMessage { import SubmitRestProtocolMessage._ private val messageType = Utils.getFormattedClassName(this) - protected val action: SubmitRestProtocolAction + protected val action: String = camelCaseToUnderscores(decapitalize(messageType)) protected val sparkVersion = new SubmitRestProtocolField[String] protected val message = new SubmitRestProtocolField[String] // Required for JSON de/serialization and not 
explicitly used - private def getAction: String = action.toString + private def getAction: String = action private def setAction(s: String): this.type = this // Spark version implementation depends on whether this is a request or a response @@ -124,24 +123,22 @@ abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { object SubmitRestProtocolMessage { private val mapper = new ObjectMapper + private val packagePrefix = this.getClass.getPackage.getName - def fromJson(json: String): SubmitRestProtocolMessage = { - val fields = parse(json).asInstanceOf[JObject].obj - val action = fields + def parseAction(json: String): String = { + parse(json).asInstanceOf[JObject].obj .find { case (f, _) => f == "action" } .map { case (_, v) => v.asInstanceOf[JString].s } .getOrElse { - throw new IllegalArgumentException(s"Could not find action field in message:\n$json") - } - val clazz = SubmitRestProtocolAction.fromString(action) match { - case SUBMIT_DRIVER_REQUEST => classOf[SubmitDriverRequest] - case SUBMIT_DRIVER_RESPONSE => classOf[SubmitDriverResponse] - case KILL_DRIVER_REQUEST => classOf[KillDriverRequest] - case KILL_DRIVER_RESPONSE => classOf[KillDriverResponse] - case DRIVER_STATUS_REQUEST => classOf[DriverStatusRequest] - case DRIVER_STATUS_RESPONSE => classOf[DriverStatusResponse] - case ERROR => classOf[ErrorResponse] + throw new IllegalArgumentException(s"Could not find action field in message:\n$json") } + } + + def fromJson(json: String): SubmitRestProtocolMessage = { + val action = parseAction(json) + val className = underscoresToCamelCase(action).capitalize + val clazz = Class.forName(packagePrefix + "." + className) + .asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage]) fromJson(json, clazz) } @@ -178,6 +175,14 @@ object SubmitRestProtocolMessage { } newString.toString() } + + private def decapitalize(s: String): String = { + if (s != null && s.nonEmpty) { + s(0).toLower + s.substring(1) + } else { + s + } + } } object SubmitRestProtocolRequest { diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index a7468a02dfe8..bc9639095d1b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -20,11 +20,7 @@ package org.apache.spark.deploy.rest import org.json4s.jackson.JsonMethods._ import org.scalatest.FunSuite -case object DUMMY_REQUEST extends SubmitRestProtocolAction -case object DUMMY_RESPONSE extends SubmitRestProtocolAction - class DummyRequest extends SubmitRestProtocolRequest { - protected override val action = DUMMY_REQUEST private val active = new SubmitRestProtocolField[Boolean] private val age = new SubmitRestProtocolField[Int] private val name = new SubmitRestProtocolField[String] @@ -45,9 +41,7 @@ class DummyRequest extends SubmitRestProtocolRequest { } } -class DummyResponse extends SubmitRestProtocolResponse { - protected override val action = DUMMY_RESPONSE -} +class DummyResponse extends SubmitRestProtocolResponse /** * Tests for the stable application submission REST protocol. 
@@ -325,7 +319,7 @@ class SubmitRestProtocolSuite extends FunSuite { private val dummyRequestJson = """ |{ - | "action" : "DUMMY_REQUEST", + | "action" : "dummy_request", | "active" : "true", | "age" : "25", | "client_spark_version" : "1.2.3", @@ -336,7 +330,7 @@ class SubmitRestProtocolSuite extends FunSuite { private val dummyResponseJson = """ |{ - | "action" : "DUMMY_RESPONSE", + | "action" : "dummy_response", | "server_spark_version" : "3.3.4" |} """.stripMargin @@ -344,7 +338,7 @@ class SubmitRestProtocolSuite extends FunSuite { private val submitDriverRequestJson = """ |{ - | "action" : "SUBMIT_DRIVER_REQUEST", + | "action" : "submit_driver_request", | "app_args" : "[\"two slices\",\"a hint of cinnamon\"]", | "app_name" : "SparkPie", | "app_resource" : "honey-walnut-cherry.jar", @@ -369,7 +363,7 @@ class SubmitRestProtocolSuite extends FunSuite { private val submitDriverResponseJson = """ |{ - | "action" : "SUBMIT_DRIVER_RESPONSE", + | "action" : "submit_driver_response", | "driver_id" : "driver_123", | "server_spark_version" : "1.2.3", | "success" : "true" @@ -379,7 +373,7 @@ class SubmitRestProtocolSuite extends FunSuite { private val killDriverRequestJson = """ |{ - | "action" : "KILL_DRIVER_REQUEST", + | "action" : "kill_driver_request", | "client_spark_version" : "1.2.3", | "driver_id" : "driver_123" |} @@ -388,7 +382,7 @@ class SubmitRestProtocolSuite extends FunSuite { private val killDriverResponseJson = """ |{ - | "action" : "KILL_DRIVER_RESPONSE", + | "action" : "kill_driver_response", | "driver_id" : "driver_123", | "server_spark_version" : "1.2.3", | "success" : "true" @@ -398,7 +392,7 @@ class SubmitRestProtocolSuite extends FunSuite { private val driverStatusRequestJson = """ |{ - | "action" : "DRIVER_STATUS_REQUEST", + | "action" : "driver_status_request", | "client_spark_version" : "1.2.3", | "driver_id" : "driver_123" |} @@ -407,7 +401,7 @@ class SubmitRestProtocolSuite extends FunSuite { private val driverStatusResponseJson = """ |{ - | "action" : "DRIVER_STATUS_RESPONSE", + | "action" : "driver_status_response", | "driver_id" : "driver_123", | "driver_state" : "RUNNING", | "server_spark_version" : "1.2.3", @@ -420,7 +414,7 @@ class SubmitRestProtocolSuite extends FunSuite { private val errorJson = """ |{ - | "action" : "ERROR", + | "action" : "error_response", | "message" : "Field not found in submit request: X", | "server_spark_version" : "1.2.3" |} From 3db7379e9dea29108eb5f2390b7b46c2db93c172 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 29 Jan 2015 15:16:33 -0800 Subject: [PATCH 20/48] Fix comments and name fields for better error messages This commit also includes comprehensive cleanups across the board and simplifies the serialization process by eliminating the naming constraints on the JSON fields. 
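For illustration, a minimal sketch of the naming mechanism this commit introduces (a trimmed-down, hypothetical stand-in for SubmitRestProtocolField and the message-side assertion; the real classes also carry JSON annotations and typed setters):

    // Each field now knows its own name, so a failed validation can report
    // exactly which field is missing instead of relying on a hand-written label.
    class NamedField[T](val name: String) {
      private var value: Option[T] = None
      def isSet: Boolean = value.isDefined
      def setValue(v: T): Unit = { value = Some(v) }
    }

    object Validation {
      // Fails with e.g. "Field 'driverId' is missing in KillDriverRequest!"
      def assertFieldIsSet(field: NamedField[_], messageType: String): Unit = {
        assert(field.isSet, s"Field '${field.name}' is missing in $messageType!")
      }
    }

Against a field constructed as new NamedField[String]("driverId") that was never set, the assertion now names the offending field directly in the AssertionError.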
--- .../deploy/rest/DriverStatusRequest.scala | 9 +- .../deploy/rest/DriverStatusResponse.scala | 18 +- .../spark/deploy/rest/ErrorResponse.scala | 5 +- .../spark/deploy/rest/KillDriverRequest.scala | 9 +- .../deploy/rest/KillDriverResponse.scala | 11 +- .../deploy/rest/StandaloneRestClient.scala | 60 ++++--- .../deploy/rest/StandaloneRestServer.scala | 29 ++-- .../deploy/rest/SubmitDriverRequest.scala | 48 ++++-- .../deploy/rest/SubmitDriverResponse.scala | 9 +- .../spark/deploy/rest/SubmitRestClient.scala | 21 +-- .../deploy/rest/SubmitRestProtocolField.scala | 14 +- .../rest/SubmitRestProtocolMessage.scala | 158 ++++++++---------- .../spark/deploy/rest/SubmitRestServer.scala | 28 ++-- .../deploy/rest/SubmitRestProtocolSuite.scala | 151 +++++++++-------- 14 files changed, 301 insertions(+), 269 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala index bb73c61e68ba..ec0e197cfa34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala @@ -17,12 +17,17 @@ package org.apache.spark.deploy.rest +/** + * A request to query the status of a driver in the REST application submission protocol. + */ class DriverStatusRequest extends SubmitRestProtocolRequest { - private val driverId = new SubmitRestProtocolField[String] + private val driverId = new SubmitRestProtocolField[String]("driverId") + def getDriverId: String = driverId.toString def setDriverId(s: String): this.type = setField(driverId, s) + override def validate(): Unit = { super.validate() - assertFieldIsSet(driverId, "driver_id") + assertFieldIsSet(driverId) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala index 6da41d09b3f2..2819ef50a75d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala @@ -17,12 +17,16 @@ package org.apache.spark.deploy.rest +/** + * A response to the [[DriverStatusRequest]] in the REST application submission protocol. 
+ */ class DriverStatusResponse extends SubmitRestProtocolResponse { - private val driverId = new SubmitRestProtocolField[String] - private val success = new SubmitRestProtocolField[Boolean] - private val driverState = new SubmitRestProtocolField[String] - private val workerId = new SubmitRestProtocolField[String] - private val workerHostPort = new SubmitRestProtocolField[String] + private val driverId = new SubmitRestProtocolField[String]("driverId") + private val success = new SubmitRestProtocolField[Boolean]("success") + // standalone cluster mode only + private val driverState = new SubmitRestProtocolField[String]("driverState") + private val workerId = new SubmitRestProtocolField[String]("workerId") + private val workerHostPort = new SubmitRestProtocolField[String]("workerHostPort") def getDriverId: String = driverId.toString def getSuccess: String = success.toString @@ -38,7 +42,7 @@ class DriverStatusResponse extends SubmitRestProtocolResponse { override def validate(): Unit = { super.validate() - assertFieldIsSet(driverId, "driver_id") - assertFieldIsSet(success, "success") + assertFieldIsSet(driverId) + assertFieldIsSet(success) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala index 0e08831e7b6a..5388cddb070f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala @@ -17,9 +17,12 @@ package org.apache.spark.deploy.rest +/** + * An error response message used in the REST application submission protocol. + */ class ErrorResponse extends SubmitRestProtocolResponse { override def validate(): Unit = { super.validate() - assertFieldIsSet(message, "message") + assertFieldIsSet(message) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala index 31b127876c43..97f5dd2ba822 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala @@ -17,12 +17,17 @@ package org.apache.spark.deploy.rest +/** + * A request to kill a driver in the REST application submission protocol. + */ class KillDriverRequest extends SubmitRestProtocolRequest { - private val driverId = new SubmitRestProtocolField[String] + private val driverId = new SubmitRestProtocolField[String]("driverId") + def getDriverId: String = driverId.toString def setDriverId(s: String): this.type = setField(driverId, s) + override def validate(): Unit = { super.validate() - assertFieldIsSet(driverId, "driver_id") + assertFieldIsSet(driverId) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala index 107b447c3d1c..fe68800e9980 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala @@ -17,9 +17,12 @@ package org.apache.spark.deploy.rest +/** + * A response to the [[KillDriverRequest]] in the REST application submission protocol. 
+ */ class KillDriverResponse extends SubmitRestProtocolResponse { - private val driverId = new SubmitRestProtocolField[String] - private val success = new SubmitRestProtocolField[Boolean] + private val driverId = new SubmitRestProtocolField[String]("driverId") + private val success = new SubmitRestProtocolField[Boolean]("success") def getDriverId: String = driverId.toString def getSuccess: String = success.toString @@ -29,7 +32,7 @@ override def validate(): Unit = { super.validate() - assertFieldIsSet(driverId, "driver_id") - assertFieldIsSet(success, "success") + assertFieldIsSet(driverId) + assertFieldIsSet(success) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index b564006fd745..6f2752c848a0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -21,11 +21,10 @@ import java.net.URL import org.apache.spark.{SPARK_VERSION => sparkVersion} import org.apache.spark.deploy.SparkSubmitArguments -import org.apache.spark.util.Utils /** - * A client that submits applications to the standalone Master using the stable REST protocol. - * This client is intended to communicate with the StandaloneRestServer. Cluster mode only. + * A client that submits applications to the standalone Master using the REST protocol. + * This client is intended to communicate with the [[StandaloneRestServer]]. Cluster mode only. */ private[spark] class StandaloneRestClient extends SubmitRestClient { import StandaloneRestClient._ @@ -38,7 +37,8 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { * this reports failure and logs an error message provided by the REST server. */ override def submitDriver(args: SparkSubmitArguments): SubmitDriverResponse = { - val submitResponse = super.submitDriver(args).asInstanceOf[SubmitDriverResponse] + validateSubmitArgs(args) + val submitResponse = super.submitDriver(args) val submitSuccess = submitResponse.getSuccess.toBoolean if (submitSuccess) { val driverId = submitResponse.getDriverId @@ -51,14 +51,25 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { submitResponse } + /** Request that the REST server kill the specified driver. */ + override def killDriver(master: String, driverId: String): KillDriverResponse = { + validateMaster(master) + super.killDriver(master, driverId) + } + + /** Request the status of the specified driver from the REST server. */ + override def requestDriverStatus(master: String, driverId: String): DriverStatusResponse = { + validateMaster(master) + super.requestDriverStatus(master, driverId) + } + /** - * Poll the status of the driver that was just submitted and report it. - * This retries up to a fixed number of times until giving up. + * Poll the status of the driver that was just submitted and log it. + * This retries up to a fixed number of times before giving up. 
*/ private def pollSubmittedDriverStatus(master: String, driverId: String): Unit = { (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ => val statusResponse = requestDriverStatus(master, driverId) - .asInstanceOf[DriverStatusResponse] val statusSuccess = statusResponse.getSuccess.toBoolean if (statusSuccess) { val driverState = statusResponse.getDriverState @@ -75,13 +86,13 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { exception.foreach { e => logError(e) } return } + Thread.sleep(REPORT_DRIVER_STATUS_INTERVAL) } logError(s"Error: Master did not recognize driver $driverId.") } /** Construct a submit driver request message. */ - override protected def constructSubmitRequest( - args: SparkSubmitArguments): SubmitDriverRequest = { + protected override def constructSubmitRequest(args: SparkSubmitArguments): SubmitDriverRequest = { val message = new SubmitDriverRequest() .setSparkVersion(sparkVersion) .setAppName(args.name) @@ -99,12 +110,14 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { .setTotalExecutorCores(args.totalExecutorCores) args.childArgs.foreach(message.addAppArg) args.sparkProperties.foreach { case (k, v) => message.setSparkProperty(k, v) } - // TODO: send special environment variables? + sys.env.foreach { case (k, v) => + if (k.startsWith("SPARK_")) { message.setEnvironmentVariable(k, v) } + } message } /** Construct a kill driver request message. */ - override protected def constructKillRequest( + protected override def constructKillRequest( master: String, driverId: String): KillDriverRequest = { new KillDriverRequest() @@ -113,7 +126,7 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { } /** Construct a driver status request message. */ - override protected def constructStatusRequest( + protected override def constructStatusRequest( master: String, driverId: String): DriverStatusRequest = { new DriverStatusRequest() @@ -121,25 +134,26 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { .setDriverId(driverId) } + /** Extract the URL portion of the master address. */ + protected override def getHttpUrl(master: String): URL = { + validateMaster(master) + new URL("http://" + master.stripPrefix("spark://")) + } + /** Throw an exception if this is not standalone mode. */ - override protected def validateMaster(master: String): Unit = { + private def validateMaster(master: String): Unit = { if (!master.startsWith("spark://")) { throw new IllegalArgumentException("This REST client is only supported in standalone mode.") } } - /** Throw an exception if this is not cluster deploy mode. */ - override protected def validateDeployMode(deployMode: String): Unit = { - if (deployMode != "cluster") { - throw new IllegalArgumentException("This REST client is only supported in cluster mode.") + /** Throw an exception if this is not standalone cluster mode. */ + private def validateSubmitArgs(args: SparkSubmitArguments): Unit = { + if (!args.isStandaloneCluster) { + throw new IllegalArgumentException( + "This REST client is only supported in standalone cluster mode.") } } - - /** Extract the URL portion of the master address. 
*/ override protected def getHttpUrl(master: String): URL = { - validateMaster(master) - new URL("http://" + master.stripPrefix("spark://")) - } } private object StandaloneRestClient { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 3fcfe189c6a1..1838647f6ed6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -24,23 +24,22 @@ import akka.actor.ActorRef import org.apache.spark.{SPARK_VERSION => sparkVersion} import org.apache.spark.SparkConf import org.apache.spark.util.{AkkaUtils, Utils} -import org.apache.spark.deploy.{Command, DriverDescription} +import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} import org.apache.spark.deploy.ClientArguments._ -import org.apache.spark.deploy.DeployMessages import org.apache.spark.deploy.master.Master /** - * A server that responds to requests submitted by the StandaloneRestClient. - * This is intended to be embedded in the standalone Master. Cluster mode only. + * A server that responds to requests submitted by the [[StandaloneRestClient]]. + * This is intended to be embedded in the standalone Master. Cluster mode only. */ private[spark] class StandaloneRestServer(master: Master, host: String, requestedPort: Int) extends SubmitRestServer(host, requestedPort, master.conf) { - override protected val handler = new StandaloneRestServerHandler(master) + protected override val handler = new StandaloneRestServerHandler(master) } /** - * A handler for requests submitted to the standalone Master - * via the stable application submission REST protocol. + * A handler for requests submitted to the standalone + * Master via the REST application submission protocol. */ private[spark] class StandaloneRestServerHandler( conf: SparkConf, @@ -55,8 +54,7 @@ } /** Handle a request to submit a driver. */ - override protected def handleSubmit( - request: SubmitDriverRequest): SubmitDriverResponse = { + protected override def handleSubmit(request: SubmitDriverRequest): SubmitDriverResponse = { val driverDescription = buildDriverDescription(request) val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) @@ -68,8 +66,7 @@ } /** Handle a request to kill a driver. */ - override protected def handleKill( - request: KillDriverRequest): KillDriverResponse = { + protected override def handleKill(request: KillDriverRequest): KillDriverResponse = { val driverId = request.getDriverId val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( DeployMessages.RequestKillDriver(driverId), masterActor, askTimeout) @@ -81,16 +78,11 @@ } /** Handle a request for a driver's status. 
*/ - override protected def handleStatus( - request: DriverStatusRequest): DriverStatusResponse = { + protected override def handleStatus(request: DriverStatusRequest): DriverStatusResponse = { val driverId = request.getDriverId val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( DeployMessages.RequestDriverStatus(driverId), masterActor, askTimeout) - // Format exception nicely, if it exists - val message = response.exception.map { e => - val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") - s"Exception from the cluster:\n$e\n$stackTraceString" - } + val message = response.exception.map { e => "Exception from the cluster:\n" + formatException(e) } new DriverStatusResponse() .setSparkVersion(sparkVersion) .setDriverId(driverId) @@ -103,6 +95,7 @@ /** * Build a driver description from the fields specified in the submit request. + * This does not currently consider fields used by python applications since * python is not supported in standalone cluster mode yet. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala index 87132511587a..f2154b48f7d3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala @@ -25,21 +25,24 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.util.JsonProtocol +/** + * A request to submit a driver in the REST application submission protocol. + */ class SubmitDriverRequest extends SubmitRestProtocolRequest { - private val appName = new SubmitRestProtocolField[String] - private val appResource = new SubmitRestProtocolField[String] - private val mainClass = new SubmitRestProtocolField[String] - private val jars = new SubmitRestProtocolField[String] - private val files = new SubmitRestProtocolField[String] - private val pyFiles = new SubmitRestProtocolField[String] - private val driverMemory = new SubmitRestProtocolField[String] - private val driverCores = new SubmitRestProtocolField[Int] - private val driverExtraJavaOptions = new SubmitRestProtocolField[String] - private val driverExtraClassPath = new SubmitRestProtocolField[String] - private val driverExtraLibraryPath = new SubmitRestProtocolField[String] - private val superviseDriver = new SubmitRestProtocolField[Boolean] - private val executorMemory = new SubmitRestProtocolField[String] - private val totalExecutorCores = new SubmitRestProtocolField[Int] + private val appName = new SubmitRestProtocolField[String]("appName") + private val appResource = new SubmitRestProtocolField[String]("appResource") + private val mainClass = new SubmitRestProtocolField[String]("mainClass") + private val jars = new SubmitRestProtocolField[String]("jars") + private val files = new SubmitRestProtocolField[String]("files") + private val pyFiles = new SubmitRestProtocolField[String]("pyFiles") + private val driverMemory = new SubmitRestProtocolField[String]("driverMemory") + private val driverCores = new SubmitRestProtocolField[Int]("driverCores") + private val driverExtraJavaOptions = new SubmitRestProtocolField[String]("driverExtraJavaOptions") + private val driverExtraClassPath = new SubmitRestProtocolField[String]("driverExtraClassPath") + private val driverExtraLibraryPath = new SubmitRestProtocolField[String]("driverExtraLibraryPath") + private val superviseDriver = new 
SubmitRestProtocolField[Boolean]("superviseDriver") + private val executorMemory = new SubmitRestProtocolField[String]("executorMemory") + private val totalExecutorCores = new SubmitRestProtocolField[Int]("totalExecutorCores") // Special fields private val appArgs = new ArrayBuffer[String] @@ -101,30 +104,43 @@ class SubmitDriverRequest extends SubmitRestProtocolRequest { envVars ++= JsonProtocol.mapFromJson(parse(s)) } + /** Return an array of arguments to be passed to the application. */ @JsonIgnore def getAppArgs: Array[String] = appArgs.toArray + + /** Return a map of Spark properties to be passed to the application as java options. */ @JsonIgnore def getSparkProperties: Map[String, String] = sparkProperties.toMap + + /** Return a map of environment variables to be passed to the application. */ @JsonIgnore def getEnvironmentVariables: Map[String, String] = envVars.toMap + + /** Add a command line argument to be passed to the application. */ @JsonIgnore def addAppArg(s: String): this.type = { appArgs += s; this } + + /** Set a Spark property to be passed to the application as a java option. */ @JsonIgnore def setSparkProperty(k: String, v: String): this.type = { sparkProperties(k) = v; this } + + /** Set an environment variable to be passed to the application. */ @JsonIgnore def setEnvironmentVariable(k: String, v: String): this.type = { envVars(k) = v; this } + /** Serialize the given Array to a compact JSON string. */ private def arrayToJson(arr: Array[String]): String = { if (arr.nonEmpty) { compact(render(JsonProtocol.arrayToJson(arr))) } else null } + /** Serialize the given Map to a compact JSON string. */ private def mapToJson(map: Map[String, String]): String = { if (map.nonEmpty) { compact(render(JsonProtocol.mapToJson(map))) } else null } override def validate(): Unit = { super.validate() - assertFieldIsSet(appName, "app_name") - assertFieldIsSet(appResource, "app_resource") + assertFieldIsSet(appName) + assertFieldIsSet(appResource) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala index b1825af8ce56..a9adf3634f23 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala @@ -17,9 +17,12 @@ package org.apache.spark.deploy.rest +/** + * A response to the [[SubmitDriverRequest]] in the REST application submission protocol. 
+ */ class SubmitDriverResponse extends SubmitRestProtocolResponse { - private val success = new SubmitRestProtocolField[Boolean] - private val driverId = new SubmitRestProtocolField[String] + private val success = new SubmitRestProtocolField[Boolean]("success") + private val driverId = new SubmitRestProtocolField[String]("driverId") def getSuccess: String = success.toString def getDriverId: String = driverId.toString @@ -29,6 +32,6 @@ class SubmitDriverResponse extends SubmitRestProtocolResponse { override def validate(): Unit = { super.validate() - assertFieldIsSet(success, "success") + assertFieldIsSet(success) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala index eb258290bdc7..a1be15c9fa5d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala @@ -28,14 +28,13 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.deploy.SparkSubmitArguments /** - * An abstract client that submits applications using the stable REST protocol. - * This client is intended to communicate with the SubmitRestServer. + * An abstract client that submits applications using the REST protocol. + * This client is intended to communicate with the [[SubmitRestServer]]. */ private[spark] abstract class SubmitRestClient extends Logging { - /** Request that the REST server submit a driver specified by the provided arguments. */ + /** Request that the REST server submit a driver using the provided arguments. */ def submitDriver(args: SparkSubmitArguments): SubmitDriverResponse = { - validateSubmitArguments(args) val url = getHttpUrl(args.master) val request = constructSubmitRequest(args) logInfo(s"Submitting a request to launch a driver in ${args.master}.") @@ -44,7 +43,6 @@ private[spark] abstract class SubmitRestClient extends Logging { /** Request that the REST server kill the specified driver. */ def killDriver(master: String, driverId: String): KillDriverResponse = { - validateMaster(master) val url = getHttpUrl(master) val request = constructKillRequest(master, driverId) logInfo(s"Submitting a request to kill driver $driverId in $master.") @@ -53,7 +51,6 @@ private[spark] abstract class SubmitRestClient extends Logging { /** Request the status of the specified driver from the REST server. */ def requestDriverStatus(master: String, driverId: String): DriverStatusResponse = { - validateMaster(master) val url = getHttpUrl(master) val request = constructStatusRequest(master, driverId) logInfo(s"Submitting a request for the status of driver $driverId in $master.") @@ -68,17 +65,9 @@ private[spark] abstract class SubmitRestClient extends Logging { protected def constructKillRequest(master: String, driverId: String): KillDriverRequest protected def constructStatusRequest(master: String, driverId: String): DriverStatusRequest - // If the provided arguments are not as expected, throw an exception - protected def validateMaster(master: String): Unit - protected def validateDeployMode(deployMode: String): Unit - protected def validateSubmitArguments(args: SparkSubmitArguments): Unit = { - validateMaster(args.master) - validateDeployMode(args.deployMode) - } - /** * Send the provided request in an HTTP message to the given URL. - * This assumes both the request and the response use the JSON format. + * This assumes that both the request and the response use the JSON format. 
* Return the response received from the REST server. */ private def sendHttp(url: URL, request: SubmitRestProtocolRequest): SubmitRestProtocolResponse = { @@ -96,7 +85,7 @@ private[spark] abstract class SubmitRestClient extends Logging { out.close() val responseJson = Source.fromInputStream(conn.getInputStream).mkString logDebug(s"Response from the REST server:\n$responseJson") - SubmitRestProtocolResponse.fromJson(responseJson) + SubmitRestProtocolMessage.fromJson(responseJson).asInstanceOf[SubmitRestProtocolResponse] } catch { case e: FileNotFoundException => throw new SparkException(s"Unable to connect to REST server $url", e) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala index 33e4fe4d5c2b..2b52fd6bc44a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala @@ -17,10 +17,20 @@ package org.apache.spark.deploy.rest -class SubmitRestProtocolField[T] { +/** + * A field used in [[SubmitRestProtocolMessage]]s. + */ +class SubmitRestProtocolField[T](val name: String) { protected var value: Option[T] = None + + /** Return the value or throw an [[IllegalAccessException]] if the value is not set. */ + def getValue: T = { + value.getOrElse { + throw new IllegalAccessException(s"Value not set in field '$name'!") + } + } + def isSet: Boolean = value.isDefined - def getValue: T = value.getOrElse { throw new IllegalAccessException("Value not set!") } def getValueOption: Option[T] = value def setValue(v: T): Unit = { value = Some(v) } def clearValue(): Unit = { value = None } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index 0aa72d236e1a..2e92eb926d33 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -26,22 +26,29 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.util.Utils +/** + * An abstract message exchanged in the REST application submission protocol. + * + * This message is intended to be serialized to and deserialized from JSON in the exchange. 
+ * Each message can either be a request or a response and consists of three common fields: + * (1) the action, which fully specifies the type of the message + * (2) the Spark version of the client / server + * (3) an optional message + */ @JsonInclude(Include.NON_NULL) @JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY) @JsonPropertyOrder(alphabetic = true) abstract class SubmitRestProtocolMessage { - import SubmitRestProtocolMessage._ - private val messageType = Utils.getFormattedClassName(this) - protected val action: String = camelCaseToUnderscores(decapitalize(messageType)) - protected val sparkVersion = new SubmitRestProtocolField[String] - protected val message = new SubmitRestProtocolField[String] + protected val action: String = messageType + protected val sparkVersion: SubmitRestProtocolField[String] + protected val message = new SubmitRestProtocolField[String]("message") // Required for JSON de/serialization and not explicitly used private def getAction: String = action private def setAction(s: String): this.type = this - // Spark version implementation depends on whether this is a request or a response + // Intended for the user and not for JSON de/serialization, which expects more specific keys @JsonIgnore def getSparkVersion: String @JsonIgnore @@ -50,26 +57,37 @@ abstract class SubmitRestProtocolMessage { def getMessage: String = message.toString def setMessage(s: String): this.type = setField(message, s) + /** + * Serialize the message to JSON. + * This also ensures that the message is valid and its fields are in the expected format. + */ def toJson: String = { validate() val mapper = new ObjectMapper - val json = mapper.writeValueAsString(this) - postProcessJson(json) + pretty(parse(mapper.writeValueAsString(this))) } + /** Assert the validity of the message. */ def validate(): Unit = { assert(action != null, s"The action field is missing in $messageType!") + assertFieldIsSet(sparkVersion) } - protected def assertFieldIsSet(field: SubmitRestProtocolField[_], name: String): Unit = { - assert(field.isSet, s"The $name field is missing in $messageType!") + /** Assert that the specified field is set in this message. */ + protected def assertFieldIsSet(field: SubmitRestProtocolField[_]): Unit = { + assert(field.isSet, s"Field '${field.name}' is missing in $messageType!") } + /** Set the field to the given value, or clear the field if the value is null. */ protected def setField(field: SubmitRestProtocolField[String], value: String): this.type = { if (value == null) { field.clearValue() } else { field.setValue(value) } this } + /** + * Set the field to the given boolean value, or clear the field if the value is null. + * If the provided value does not represent a boolean, throw an exception. + */ protected def setBooleanField( field: SubmitRestProtocolField[Boolean], value: String): this.type = { @@ -77,6 +95,10 @@ abstract class SubmitRestProtocolMessage { this } + /** + * Set the field to the given numeric value, or clear the field if the value is null. + * If the provided value does not represent a numeric, throw an exception. + */ protected def setNumericField( field: SubmitRestProtocolField[Int], value: String): this.type = { @@ -84,6 +106,11 @@ abstract class SubmitRestProtocolMessage { this } + /** + * Set the field to the given memory value, or clear the field if the value is null. + * If the provided value does not represent a memory value, throw an exception. + * Valid examples of memory values include "512m", "24g", and "128000". 
+ */ protected def setMemoryField( field: SubmitRestProtocolField[String], value: String): this.type = { @@ -91,116 +118,69 @@ abstract class SubmitRestProtocolMessage { setField(field, value) this } - - private def postProcessJson(json: String): String = { - val fields = parse(json).asInstanceOf[JObject].obj - val newFields = fields.map { case (k, v) => (camelCaseToUnderscores(k), v) } - pretty(render(JObject(newFields))) - } } +/** + * An abstract request sent from the client in the REST application submission protocol. + */ abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage { + protected override val sparkVersion = new SubmitRestProtocolField[String]("client_spark_version") def getClientSparkVersion: String = sparkVersion.toString def setClientSparkVersion(s: String): this.type = setField(sparkVersion, s) override def getSparkVersion: String = getClientSparkVersion override def setSparkVersion(s: String) = setClientSparkVersion(s) - override def validate(): Unit = { - super.validate() - assertFieldIsSet(sparkVersion, "client_spark_version") - } } +/** + * An abstract response sent from the server in the REST application submission protocol. + */ abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { + protected override val sparkVersion = new SubmitRestProtocolField[String]("server_spark_version") def getServerSparkVersion: String = sparkVersion.toString def setServerSparkVersion(s: String): this.type = setField(sparkVersion, s) override def getSparkVersion: String = getServerSparkVersion override def setSparkVersion(s: String) = setServerSparkVersion(s) - override def validate(): Unit = { - super.validate() - assertFieldIsSet(sparkVersion, "server_spark_version") - } } object SubmitRestProtocolMessage { private val mapper = new ObjectMapper private val packagePrefix = this.getClass.getPackage.getName - def parseAction(json: String): String = { + /** Parse the value of the action field from the given JSON. */ + def parseAction(json: String): String = parseField(json, "action") + + /** Parse the value of the specified field from the given JSON. */ + def parseField(json: String, field: String): String = { parse(json).asInstanceOf[JObject].obj - .find { case (f, _) => f == "action" } + .find { case (f, _) => f == field } .map { case (_, v) => v.asInstanceOf[JString].s } .getOrElse { - throw new IllegalArgumentException(s"Could not find action field in message:\n$json") - } + throw new IllegalArgumentException(s"Could not find field '$field' in message:\n$json") + } } + /** + * Construct a [[SubmitRestProtocolMessage]] from its JSON representation. + * + * This method first parses the action from the JSON and uses it to infers the message type. + * Note that the action must represent one of the [[SubmitRestProtocolMessage]]s defined in + * this package. Otherwise, a [[ClassNotFoundException]] will be thrown. + */ def fromJson(json: String): SubmitRestProtocolMessage = { - val action = parseAction(json) - val className = underscoresToCamelCase(action).capitalize + val className = parseAction(json) val clazz = Class.forName(packagePrefix + "." + className) .asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage]) fromJson(json, clazz) } + /** + * Construct a [[SubmitRestProtocolMessage]] from its JSON representation. + * + * This method determines the type of the message from the class provided instead of + * inferring it from the action field. 
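As a sketch of this overload, a hypothetical user-defined message (PingRequest and its JSON payload are made up here) would be deserialized by passing its class explicitly, since its action cannot be resolved to a class within this package:

    // The action field alone cannot locate a class outside this package,
    // so the expected class is supplied by the caller.
    class PingRequest extends SubmitRestProtocolRequest
    val json = """{ "action" : "PingRequest", "clientSparkVersion" : "1.2.3" }"""
    val ping = SubmitRestProtocolMessage.fromJson(json, classOf[PingRequest])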
This is useful for deserializing JSON that + * represents custom user-defined messages. + */ def fromJson[T <: SubmitRestProtocolMessage](json: String, clazz: Class[T]): T = { - val fields = parse(json).asInstanceOf[JObject].obj - val processedFields = fields.map { case (k, v) => (underscoresToCamelCase(k), v) } - val processedJson = compact(render(JObject(processedFields))) - mapper.readValue(processedJson, clazz) - } - - private def camelCaseToUnderscores(s: String): String = { - val newString = new StringBuilder - s.foreach { c => - if (c.isUpper) { - newString.append("_" + c.toLower) - } else { - newString.append(c) - } - } - newString.toString() - } - - private def underscoresToCamelCase(s: String): String = { - val newString = new StringBuilder - var capitalizeNext = false - s.foreach { c => - if (c == '_') { - capitalizeNext = true - } else { - val nextChar = if (capitalizeNext) c.toUpper else c - newString.append(nextChar) - capitalizeNext = false - } - } - newString.toString() - } - - private def decapitalize(s: String): String = { - if (s != null && s.nonEmpty) { - s(0).toLower + s.substring(1) - } else { - s - } - } -} - -object SubmitRestProtocolRequest { - def fromJson(s: String): SubmitRestProtocolRequest = { - SubmitRestProtocolMessage.fromJson(s) match { - case req: SubmitRestProtocolRequest => req - case res: SubmitRestProtocolResponse => - throw new IllegalArgumentException(s"Message was not a request:\n$s") - } - } -} - -object SubmitRestProtocolResponse { - def fromJson(s: String): SubmitRestProtocolResponse = { - SubmitRestProtocolMessage.fromJson(s) match { - case req: SubmitRestProtocolRequest => - throw new IllegalArgumentException(s"Message was not a response:\n$s") - case res: SubmitRestProtocolResponse => res - } + mapper.readValue(json, clazz) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala index 89a2b83d2cde..5d3fb70f8bcc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala @@ -32,8 +32,8 @@ import org.apache.spark.{SPARK_VERSION => sparkVersion, Logging, SparkConf} import org.apache.spark.util.Utils /** - * An abstract server that responds to requests submitted by the SubmitRestClient - * in the stable application submission REST protocol. + * An abstract server that responds to requests submitted by the + * [[SubmitRestClient]] in the REST application submission protocol. */ private[spark] abstract class SubmitRestServer(host: String, requestedPort: Int, conf: SparkConf) extends Logging { @@ -66,8 +66,8 @@ private[spark] abstract class SubmitRestServer(host: String, requestedPort: Int, } /** - * An abstract handler for requests submitted via the stable application submission REST protocol. - * This represents the main handler used in the SubmitRestServer. + * An abstract handler for requests submitted via the REST application submission protocol. + * This represents the main handler used in the [[SubmitRestServer]]. */ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler with Logging { protected def handleSubmit(request: SubmitDriverRequest): SubmitDriverResponse @@ -75,8 +75,8 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi protected def handleStatus(request: DriverStatusRequest): DriverStatusResponse /** - * Handle a request submitted by the SubmitRestClient. 
- * This assumes both the request and the response use the JSON format. + * Handle a request submitted by the [[SubmitRestClient]]. + * This assumes that both the request and the response use the JSON format. */ override def handle( target: String, @@ -85,7 +85,8 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi response: HttpServletResponse): Unit = { try { val requestMessageJson = Source.fromInputStream(request.getInputStream).mkString - val requestMessage = SubmitRestProtocolRequest.fromJson(requestMessageJson) + val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson) + .asInstanceOf[SubmitRestProtocolRequest] val responseMessage = constructResponseMessage(requestMessage) response.setContentType("application/json") response.setCharacterEncoding("utf-8") @@ -102,7 +103,7 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi /** * Construct the appropriate response message based on the type of the request message. - * If an IllegalArgumentException is thrown in the process, construct an error message instead. + * If an [[IllegalArgumentException]] is thrown, construct an error message instead. */ private def constructResponseMessage( request: SubmitRestProtocolRequest): SubmitRestProtocolResponse = { @@ -121,14 +122,15 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") } } catch { - case e: IllegalArgumentException => handleError(e.getMessage) + case e: IllegalArgumentException => handleError(formatException(e)) } // Validate the response message to ensure that it is correctly constructed. If it is not, // propagate the exception back to the client and signal that it is a server error. try { response.validate() } catch { - case e: IllegalArgumentException => handleError(s"Internal server error: ${e.getMessage}") + case e: IllegalArgumentException => + handleError("Internal server error: " + formatException(e)) } response } @@ -139,4 +141,10 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi .setSparkVersion(sparkVersion) .setMessage(message) } + + /** Return a human readable String representation of the exception. 
*/ + protected def formatException(e: Exception): String = { + val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n") + s"$e\n$stackTraceString" + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index bc9639095d1b..4c81f5fabdc1 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -20,45 +20,11 @@ package org.apache.spark.deploy.rest import org.json4s.jackson.JsonMethods._ import org.scalatest.FunSuite -class DummyRequest extends SubmitRestProtocolRequest { - private val active = new SubmitRestProtocolField[Boolean] - private val age = new SubmitRestProtocolField[Int] - private val name = new SubmitRestProtocolField[String] - - def getActive: String = active.toString - def getAge: String = age.toString - def getName: String = name.toString - - def setActive(s: String): this.type = setBooleanField(active, s) - def setAge(s: String): this.type = setNumericField(age, s) - def setName(s: String): this.type = setField(name, s) - - override def validate(): Unit = { - super.validate() - assertFieldIsSet(name, "name") - assertFieldIsSet(age, "age") - assert(age.getValue > 5, "Not old enough!") - } -} - -class DummyResponse extends SubmitRestProtocolResponse - /** - * Tests for the stable application submission REST protocol. + * Tests for the REST application submission protocol. */ class SubmitRestProtocolSuite extends FunSuite { - /** Assert that the contents in the two JSON strings are equal after ignoring whitespace. */ - private def assertJsonEquals(jsonString1: String, jsonString2: String): Unit = { - val trimmedJson1 = jsonString1.trim - val trimmedJson2 = jsonString2.trim - val json1 = compact(render(parse(trimmedJson1))) - val json2 = compact(render(parse(trimmedJson2))) - // Put this on a separate line to avoid printing comparison twice when test fails - val equals = json1 == json2 - assert(equals, "\"[%s]\" did not equal \"[%s]\"".format(trimmedJson1, trimmedJson2)) - } - test("get and set fields") { val request = new DummyRequest assert(request.getSparkVersion === null) @@ -319,10 +285,10 @@ class SubmitRestProtocolSuite extends FunSuite { private val dummyRequestJson = """ |{ - | "action" : "dummy_request", + | "action" : "DummyRequest", | "active" : "true", | "age" : "25", - | "client_spark_version" : "1.2.3", + | "clientSparkVersion" : "1.2.3", | "name" : "jung" |} """.stripMargin @@ -330,42 +296,42 @@ class SubmitRestProtocolSuite extends FunSuite { private val dummyResponseJson = """ |{ - | "action" : "dummy_response", - | "server_spark_version" : "3.3.4" + | "action" : "DummyResponse", + | "serverSparkVersion" : "3.3.4" |} """.stripMargin private val submitDriverRequestJson = """ |{ - | "action" : "submit_driver_request", - | "app_args" : "[\"two slices\",\"a hint of cinnamon\"]", - | "app_name" : "SparkPie", - | "app_resource" : "honey-walnut-cherry.jar", - | "client_spark_version" : "1.2.3", - | "driver_cores" : "180", - | "driver_extra_class_path" : "food-coloring.jar", - | "driver_extra_java_options" : " -Dslices=5 -Dcolor=mostly_red", - | "driver_extra_library_path" : "pickle.jar", - | "driver_memory" : "512m", - | "environment_variables" : "{\"PATH\":\"/dev/null\",\"PYTHONPATH\":\"/dev/null\"}", - | "executor_memory" : "256m", + | "action" : "SubmitDriverRequest", + | "appArgs" : "[\"two slices\",\"a hint of 
cinnamon\"]", + | "appName" : "SparkPie", + | "appResource" : "honey-walnut-cherry.jar", + | "clientSparkVersion" : "1.2.3", + | "driverCores" : "180", + | "driverExtraClassPath" : "food-coloring.jar", + | "driverExtraJavaOptions" : " -Dslices=5 -Dcolor=mostly_red", + | "driverExtraLibraryPath" : "pickle.jar", + | "driverMemory" : "512m", + | "environmentVariables" : "{\"PATH\":\"/dev/null\",\"PYTHONPATH\":\"/dev/null\"}", + | "executorMemory" : "256m", | "files" : "fireball.png", | "jars" : "mayonnaise.jar,ketchup.jar", - | "main_class" : "org.apache.spark.examples.SparkPie", - | "py_files" : "do-not-eat-my.py", - | "spark_properties" : "{\"spark.live.long\":\"true\",\"spark.shuffle.enabled\":\"false\"}", - | "supervise_driver" : "false", - | "total_executor_cores" : "10000" + | "mainClass" : "org.apache.spark.examples.SparkPie", + | "pyFiles" : "do-not-eat-my.py", + | "sparkProperties" : "{\"spark.live.long\":\"true\",\"spark.shuffle.enabled\":\"false\"}", + | "superviseDriver" : "false", + | "totalExecutorCores" : "10000" |} """.stripMargin private val submitDriverResponseJson = """ |{ - | "action" : "submit_driver_response", - | "driver_id" : "driver_123", - | "server_spark_version" : "1.2.3", + | "action" : "SubmitDriverResponse", + | "driverId" : "driver_123", + | "serverSparkVersion" : "1.2.3", | "success" : "true" |} """.stripMargin @@ -373,18 +339,18 @@ class SubmitRestProtocolSuite extends FunSuite { private val killDriverRequestJson = """ |{ - | "action" : "kill_driver_request", - | "client_spark_version" : "1.2.3", - | "driver_id" : "driver_123" + | "action" : "KillDriverRequest", + | "clientSparkVersion" : "1.2.3", + | "driverId" : "driver_123" |} """.stripMargin private val killDriverResponseJson = """ |{ - | "action" : "kill_driver_response", - | "driver_id" : "driver_123", - | "server_spark_version" : "1.2.3", + | "action" : "KillDriverResponse", + | "driverId" : "driver_123", + | "serverSparkVersion" : "1.2.3", | "success" : "true" |} """.stripMargin @@ -392,31 +358,64 @@ class SubmitRestProtocolSuite extends FunSuite { private val driverStatusRequestJson = """ |{ - | "action" : "driver_status_request", - | "client_spark_version" : "1.2.3", - | "driver_id" : "driver_123" + | "action" : "DriverStatusRequest", + | "clientSparkVersion" : "1.2.3", + | "driverId" : "driver_123" |} """.stripMargin private val driverStatusResponseJson = """ |{ - | "action" : "driver_status_response", - | "driver_id" : "driver_123", - | "driver_state" : "RUNNING", - | "server_spark_version" : "1.2.3", + | "action" : "DriverStatusResponse", + | "driverId" : "driver_123", + | "driverState" : "RUNNING", + | "serverSparkVersion" : "1.2.3", | "success" : "true", - | "worker_host_port" : "1.2.3.4:7780", - | "worker_id" : "worker_123" + | "workerHostPort" : "1.2.3.4:7780", + | "workerId" : "worker_123" |} """.stripMargin private val errorJson = """ |{ - | "action" : "error_response", + | "action" : "ErrorResponse", | "message" : "Field not found in submit request: X", - | "server_spark_version" : "1.2.3" + | "serverSparkVersion" : "1.2.3" |} """.stripMargin + + /** Assert that the contents in the two JSON strings are equal after ignoring whitespace. 
*/ + private def assertJsonEquals(jsonString1: String, jsonString2: String): Unit = { + val trimmedJson1 = jsonString1.trim + val trimmedJson2 = jsonString2.trim + val json1 = compact(render(parse(trimmedJson1))) + val json2 = compact(render(parse(trimmedJson2))) + // Put this on a separate line to avoid printing comparison twice when test fails + val equals = json1 == json2 + assert(equals, "\"[%s]\" did not equal \"[%s]\"".format(trimmedJson1, trimmedJson2)) + } +} + +private class DummyResponse extends SubmitRestProtocolResponse +private class DummyRequest extends SubmitRestProtocolRequest { + private val active = new SubmitRestProtocolField[Boolean]("active") + private val age = new SubmitRestProtocolField[Int]("age") + private val name = new SubmitRestProtocolField[String]("name") + + def getActive: String = active.toString + def getAge: String = age.toString + def getName: String = name.toString + + def setActive(s: String): this.type = setBooleanField(active, s) + def setAge(s: String): this.type = setNumericField(age, s) + def setName(s: String): this.type = setField(name, s) + + override def validate(): Unit = { + super.validate() + assertFieldIsSet(name) + assertFieldIsSet(age) + assert(age.getValue > 5, "Not old enough!") + } } From e2104e6943ffcfe003c583b3b88564f6ff48a5e2 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 29 Jan 2015 15:26:33 -0800 Subject: [PATCH 21/48] stable -> rest --- .../resources/org/apache/spark/ui/static/webui.css | 2 +- .../org/apache/spark/deploy/DeployMessage.scala | 4 ++-- .../scala/org/apache/spark/deploy/SparkSubmit.scala | 11 ++++++----- .../apache/spark/deploy/SparkSubmitArguments.scala | 2 +- .../org/apache/spark/deploy/master/Master.scala | 4 ++-- .../apache/spark/deploy/master/MasterMessages.scala | 2 +- .../apache/spark/deploy/master/ui/MasterPage.scala | 6 +++--- .../deploy/rest/StandaloneRestProtocolSuite.scala | 12 ++++++------ 8 files changed, 22 insertions(+), 21 deletions(-) diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 7d3abee8c956..68b33b5f0d7c 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -103,7 +103,7 @@ span.expand-details { float: right; } -span.stable-uri { +span.rest-uri { font-size: 10pt; font-style: italic; color: gray; diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index d95830d515ca..7f600d89604a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -151,7 +151,7 @@ private[deploy] object DeployMessages { case class MasterStateResponse( host: String, port: Int, - stablePort: Option[Int], + restPort: Option[Int], workers: Array[WorkerInfo], activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo], @@ -163,7 +163,7 @@ private[deploy] object DeployMessages { assert (port > 0) def uri = "spark://" + host + ":" + port - def stableUri: Option[String] = stablePort.map { p => "spark://" + host + ":" + p } + def restUri: Option[String] = restPort.map { p => "spark://" + host + ":" + p } } // WorkerWebUI to Worker diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index fac37d5a748e..328ecc9e768c 100644 --- 
a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -89,14 +89,15 @@ object SparkSubmit { } /** - * Kill an existing driver using the stable REST protocol. Standalone cluster mode only. + * Kill an existing driver using the REST application submission protocol. + * Standalone cluster mode only. */ private def kill(args: SparkSubmitArguments): Unit = { new StandaloneRestClient().killDriver(args.master, args.driverToKill) } /** - * Request the status of an existing driver using the stable REST protocol. + * Request the status of an existing driver using the REST application submission protocol. * Standalone cluster mode only. */ private def requestStatus(args: SparkSubmitArguments): Unit = { @@ -112,7 +113,7 @@ object SparkSubmit { * Second, we use this launch environment to invoke the main method of the child * main class. * - * As of Spark 1.3, a stable REST-based application submission gateway is introduced. + * As of Spark 1.3, a REST-based application submission gateway is introduced. * If this is enabled, then we will run standalone cluster mode by passing the submit * parameters directly to a REST client, which will submit the application using the * REST protocol instead. @@ -120,7 +121,7 @@ object SparkSubmit { private[spark] def submit(args: SparkSubmitArguments): Unit = { val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args) if (args.isStandaloneCluster && args.isRestEnabled) { - printStream.println("Running standalone cluster mode using the stable REST protocol.") + printStream.println("Running Spark using the REST application submission protocol.") new StandaloneRestClient().submitDriver(args) } else { runMain(childArgs, childClasspath, sysProps, childMainClass) @@ -305,7 +306,7 @@ object SparkSubmit { } // In standalone-cluster mode, use Client as a wrapper around the user class - // Note that we won't actually launch this class if we're using the stable REST protocol + // Note that we won't actually launch this class if we're using the REST protocol if (args.isStandaloneCluster && !args.isRestEnabled) { childMainClass = "org.apache.spark.deploy.Client" if (args.supervise) { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 0032759d7577..19ac58df04e8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -256,7 +256,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St master.startsWith("spark://") && deployMode == "cluster" } - /** Return whether the stable application submission REST gateway is enabled. */ + /** Return whether the REST application submission protocol is enabled. 
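The flag read by the method just below defaults to off; as a sketch of its semantics, restated over a plain Map in place of the parsed sparkProperties:

    // The REST gateway stays disabled unless the property is explicitly "true".
    val sparkProperties = Map("spark.submit.rest.enabled" -> "true")
    val restEnabled = sparkProperties.get("spark.submit.rest.enabled").getOrElse("false").toBoolean

A user would typically set this through spark-submit, e.g. with --conf spark.submit.rest.enabled=true.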
*/ def isRestEnabled: Boolean = { sparkProperties.get("spark.submit.rest.enabled").getOrElse("false").toBoolean } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 9b5286255985..8525c03fe33d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -126,7 +126,7 @@ private[spark] class Master( private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true) private val restServer = if (restServerEnabled) { - val port = conf.getInt("spark.master.rest.port", 17077) + val port = conf.getInt("spark.master.rest.port", 6066) Some(new StandaloneRestServer(this, host, port)) } else { None @@ -910,6 +910,6 @@ private[spark] object Master extends Logging { val timeout = AkkaUtils.askTimeout(conf) val portsRequest = actor.ask(BoundPortsRequest)(timeout) val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse] - (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.stablePort) + (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index ca9f93edca25..15c6296888f7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -38,5 +38,5 @@ private[master] object MasterMessages { case object BoundPortsRequest - case class BoundPortsResponse(actorPort: Int, webUIPort: Int, stablePort: Option[Int]) + case class BoundPortsResponse(actorPort: Int, webUIPort: Int, restPort: Option[Int]) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 16530d60fd03..49628d704696 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -74,10 +74,10 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
<li><strong>URL:</strong> {state.uri}</li>
{
-              state.stableUri.map { uri =>
+              state.restUri.map { uri =>
<li>
-                  <strong>Stable URL:</strong> {uri}
-                  <span class="stable-uri"> (for standalone cluster mode in Spark 1.3+)</span>
+                  <strong>REST URL:</strong> {uri}
+                  <span class="rest-uri"> (for standalone cluster mode in Spark 1.3+)</span>
</li>
}.getOrElse { Seq.empty } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala index 11e49077d893..414122bff09f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala @@ -36,7 +36,7 @@ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.Worker /** - * End-to-end tests for the stable application submission protocol in standalone mode. + * End-to-end tests for the REST application submission protocol in standalone mode. */ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfterEach { private val systemsToStop = new ArrayBuffer[ActorSystem] @@ -89,7 +89,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B /** * Start a local cluster containing one Master and a few Workers. - * Do not use org.apache.spark.deploy.LocalCluster here because we want the REST URL. + * Do not use [[org.apache.spark.deploy.LocalSparkCluster]] here because we want the REST URL. * Return the Master's REST URL to which applications should be submitted. */ private def startLocalCluster(): String = { @@ -112,7 +112,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B masterRestUrl } - /** Submit the StandaloneRestApp and return the corresponding driver ID. */ + /** Submit the [[StandaloneRestApp]] and return the corresponding driver ID. */ private def submitApplication(resultsFile: File, numbers: Seq[Int], size: Int): String = { val appArgs = Seq(resultsFile.getAbsolutePath) ++ numbers.map(_.toString) ++ Seq(size.toString) val commandLineArgs = Array( @@ -164,7 +164,7 @@ private object StandaloneRestProtocolSuite { private val pathPrefix = "org/apache/spark/deploy/rest" /** - * Create a jar that contains all the class files needed for running the StandaloneRestApp. + * Create a jar that contains all the class files needed for running the [[StandaloneRestApp]]. * Return the absolute path to that jar. */ def createJar(): String = { @@ -184,7 +184,7 @@ private object StandaloneRestProtocolSuite { } /** - * Return a list of class files compiled for StandaloneRestApp. + * Return a list of class files compiled for [[StandaloneRestApp]]. * This includes all the anonymous classes used in the application. */ private def getClassFiles: Seq[File] = { @@ -197,7 +197,7 @@ private object StandaloneRestProtocolSuite { } /** - * Sample application to be submitted to the cluster using the stable gateway. + * Sample application to be submitted to the cluster using the REST gateway. * All relevant classes will be packaged into a jar at run time.
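Since the suite builds its test jar on the fly, a generic sketch of packaging compiled class files into a jar is shown below; the helper name, signature, and buffer size are illustrative and need not match the suite's actual implementation:

    import java.io.{File, FileInputStream, FileOutputStream}
    import java.util.jar.{JarEntry, JarOutputStream}

    // Write each class file into the jar under its own name.
    def packageJar(classFiles: Seq[File], jarPath: String): String = {
      val out = new JarOutputStream(new FileOutputStream(jarPath))
      classFiles.foreach { file =>
        out.putNextEntry(new JarEntry(file.getName))
        val in = new FileInputStream(file)
        val buffer = new Array[Byte](4096)
        Iterator.continually(in.read(buffer)).takeWhile(_ != -1).foreach { n =>
          out.write(buffer, 0, n)
        }
        in.close()
        out.closeEntry()
      }
      out.close()
      new File(jarPath).getAbsolutePath
    }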
*/ object StandaloneRestApp { From 9581df73d1fa8134c349e13251bf3c2a4daa4055 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 29 Jan 2015 16:41:54 -0800 Subject: [PATCH 22/48] Clean up uses of exceptions --- .../deploy/rest/DriverStatusRequest.scala | 4 +- .../deploy/rest/DriverStatusResponse.scala | 4 +- .../spark/deploy/rest/ErrorResponse.scala | 4 +- .../spark/deploy/rest/KillDriverRequest.scala | 4 +- .../deploy/rest/KillDriverResponse.scala | 4 +- .../deploy/rest/SubmitDriverRequest.scala | 4 +- .../deploy/rest/SubmitDriverResponse.scala | 4 +- .../deploy/rest/SubmitRestProtocolField.scala | 10 +-- .../rest/SubmitRestProtocolMessage.scala | 65 +++++++++++++------ .../spark/deploy/rest/SubmitRestServer.scala | 7 +- .../rest/StandaloneRestProtocolSuite.scala | 1 + .../deploy/rest/SubmitRestProtocolSuite.scala | 34 +++++----- 12 files changed, 80 insertions(+), 65 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala index ec0e197cfa34..9c925548b0e4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala @@ -26,8 +26,8 @@ class DriverStatusRequest extends SubmitRestProtocolRequest { def getDriverId: String = driverId.toString def setDriverId(s: String): this.type = setField(driverId, s) - override def validate(): Unit = { - super.validate() + protected override def doValidate(): Unit = { + super.doValidate() assertFieldIsSet(driverId) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala index 2819ef50a75d..de6b24b0e80c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala @@ -40,8 +40,8 @@ class DriverStatusResponse extends SubmitRestProtocolResponse { def setWorkerId(s: String): this.type = setField(workerId, s) def setWorkerHostPort(s: String): this.type = setField(workerHostPort, s) - override def validate(): Unit = { - super.validate() + protected override def doValidate(): Unit = { + super.doValidate() assertFieldIsSet(driverId) assertFieldIsSet(success) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala index 5388cddb070f..b7fcc97ea2a8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala @@ -21,8 +21,8 @@ package org.apache.spark.deploy.rest * An error response message used in the REST application submission protocol. 
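For reference, a sketch of constructing a well-formed error response; the version and message strings are illustrative and mirror the test fixtures earlier in this series:

    // An ErrorResponse is only valid once both the server Spark version
    // and the error message are set.
    val error = new ErrorResponse
    error.setSparkVersion("1.2.3")
    error.setMessage("Field not found in submit request: X")
    error.validate()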
*/ class ErrorResponse extends SubmitRestProtocolResponse { - override def validate(): Unit = { - super.validate() + protected override def doValidate(): Unit = { + super.doValidate() assertFieldIsSet(message) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala index 97f5dd2ba822..764f3e7753ae 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala @@ -26,8 +26,8 @@ class KillDriverRequest extends SubmitRestProtocolRequest { def getDriverId: String = driverId.toString def setDriverId(s: String): this.type = setField(driverId, s) - override def validate(): Unit = { - super.validate() + protected override def doValidate(): Unit = { + super.doValidate() assertFieldIsSet(driverId) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala index fe68800e9980..790527cc67a9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala @@ -30,8 +30,8 @@ class KillDriverResponse extends SubmitRestProtocolResponse { def setDriverId(s: String): this.type = setField(driverId, s) def setSuccess(s: String): this.type = setBooleanField(success, s) - override def validate(): Unit = { - super.validate() + protected override def doValidate(): Unit = { + super.doValidate() assertFieldIsSet(driverId) assertFieldIsSet(success) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala index f2154b48f7d3..b083d27e7901 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala @@ -138,8 +138,8 @@ class SubmitDriverRequest extends SubmitRestProtocolRequest { if (map.nonEmpty) { compact(render(JsonProtocol.mapToJson(map))) } else null } - override def validate(): Unit = { - super.validate() + protected override def doValidate(): Unit = { + super.doValidate() assertFieldIsSet(appName) assertFieldIsSet(appResource) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala index a9adf3634f23..39f8a67aea98 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala @@ -30,8 +30,8 @@ class SubmitDriverResponse extends SubmitRestProtocolResponse { def setSuccess(s: String): this.type = setBooleanField(success, s) def setDriverId(s: String): this.type = setField(driverId, s) - override def validate(): Unit = { - super.validate() + protected override def doValidate(): Unit = { + super.doValidate() assertFieldIsSet(success) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala index 2b52fd6bc44a..3932e68fcd2b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala @@ -22,16 +22,8 @@ package org.apache.spark.deploy.rest */ 
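The hunk below drops the throwing accessor in favor of returning the Option directly; as a sketch of the resulting call pattern (field name and values are illustrative):

    // Callers now unwrap the Option themselves instead of relying on a
    // throwing getter.
    val age = new SubmitRestProtocolField[Int]("age")
    age.setValue(10)
    val oldEnough = age.getValue.exists(_ > 5)   // true; an unset field yields false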
class SubmitRestProtocolField[T](val name: String) { protected var value: Option[T] = None - - /** Return the value or throw an [[IllegalArgumentException]] if the value is not set. */ - def getValue: T = { - value.getOrElse { - throw new IllegalAccessException(s"Value not set in field '$name'!") - } - } - def isSet: Boolean = value.isDefined - def getValueOption: Option[T] = value + def getValue: Option[T] = value def setValue(v: T): Unit = { value = Some(v) } def clearValue(): Unit = { value = None } override def toString: String = value.map(_.toString).orNull diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index 2e92eb926d33..f527c76a02ed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -67,20 +67,42 @@ abstract class SubmitRestProtocolMessage { pretty(parse(mapper.writeValueAsString(this))) } - /** Assert the validity of the message. */ - def validate(): Unit = { - assert(action != null, s"The action field is missing in $messageType!") + /** + * Assert the validity of the message. + * If the validation fails, throw a [[SubmitRestValidationException]]. + */ + final def validate(): Unit = { + try { + doValidate() + } catch { + case e: Exception => + throw new SubmitRestValidationException( + s"Validation of message $messageType failed!", e) + } + } + + /** Assert the validity of the message */ + protected def doValidate(): Unit = { + assert(action != null, s"The action field is missing.") assertFieldIsSet(sparkVersion) } /** Assert that the specified field is set in this message. */ protected def assertFieldIsSet(field: SubmitRestProtocolField[_]): Unit = { - assert(field.isSet, s"Field '${field.name}' is missing in $messageType!") + assert(field.isSet, s"Field '${field.name}' is missing.") + } + + /** + * Assert a condition when validating this message. + * If the assertion fails, throw a [[SubmitRestValidationException]]. + */ + protected def assert(condition: Boolean, failMessage: String): Unit = { + if (!condition) { throw new SubmitRestValidationException(failMessage) } } /** Set the field to the given value, or clear the field if the value is null. */ - protected def setField(field: SubmitRestProtocolField[String], value: String): this.type = { - if (value == null) { field.clearValue() } else { field.setValue(value) } + protected def setField(f: SubmitRestProtocolField[String], v: String): this.type = { + if (v == null) { f.clearValue() } else { f.setValue(v) } this } @@ -88,10 +110,8 @@ abstract class SubmitRestProtocolMessage { * Set the field to the given boolean value, or clear the field if the value is null. * If the provided value does not represent a boolean, throw an exception. */ - protected def setBooleanField( - field: SubmitRestProtocolField[Boolean], - value: String): this.type = { - if (value == null) { field.clearValue() } else { field.setValue(value.toBoolean) } + protected def setBooleanField(f: SubmitRestProtocolField[Boolean], v: String): this.type = { + if (v == null) { f.clearValue() } else { f.setValue(v.toBoolean) } this } @@ -99,10 +119,8 @@ abstract class SubmitRestProtocolMessage { * Set the field to the given numeric value, or clear the field if the value is null. * If the provided value does not represent a numeric, throw an exception. 
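With validate() now final and delegating to doValidate() (see the hunk above), a custom message hooks into validation as in this sketch; EchoRequest and its field are hypothetical, but the pattern mirrors DriverStatusRequest:

    // Subclasses override doValidate() rather than validate(), so every
    // failure is uniformly wrapped in a SubmitRestValidationException.
    class EchoRequest extends SubmitRestProtocolRequest {
      private val payload = new SubmitRestProtocolField[String]("payload")
      def getPayload: String = payload.toString
      def setPayload(s: String): this.type = setField(payload, s)

      protected override def doValidate(): Unit = {
        super.doValidate()
        assertFieldIsSet(payload)
      }
    }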
*/ - protected def setNumericField( - field: SubmitRestProtocolField[Int], - value: String): this.type = { - if (value == null) { field.clearValue() } else { field.setValue(value.toInt) } + protected def setNumericField(f: SubmitRestProtocolField[Int], v: String): this.type = { + if (v == null) { f.clearValue() } else { f.setValue(v.toInt) } this } @@ -111,12 +129,9 @@ abstract class SubmitRestProtocolMessage { * If the provided value does not represent a memory value, throw an exception. * Valid examples of memory values include "512m", "24g", and "128000". */ - protected def setMemoryField( - field: SubmitRestProtocolField[String], - value: String): this.type = { - Utils.memoryStringToMb(value) - setField(field, value) - this + protected def setMemoryField(f: SubmitRestProtocolField[String], v: String): this.type = { + Utils.memoryStringToMb(v) + setField(f, v) } } @@ -142,6 +157,14 @@ abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { override def setSparkVersion(s: String) = setServerSparkVersion(s) } +/** + * An exception thrown if the validation of a [[SubmitRestProtocolMessage]] fails. + */ +class SubmitRestValidationException( + message: String, + cause: Exception = null) + extends Exception(message, cause) + object SubmitRestProtocolMessage { private val mapper = new ObjectMapper private val packagePrefix = this.getClass.getPackage.getName @@ -162,7 +185,7 @@ object SubmitRestProtocolMessage { /** * Construct a [[SubmitRestProtocolMessage]] from its JSON representation. * - * This method first parses the action from the JSON and uses it to infers the message type. + * This method first parses the action from the JSON and uses it to infer the message type. * Note that the action must represent one of the [[SubmitRestProtocolMessage]]s defined in * this package. Otherwise, a [[ClassNotFoundException]] will be thrown. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala index 5d3fb70f8bcc..4856381a3e09 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala @@ -103,7 +103,7 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi /** * Construct the appropriate response message based on the type of the request message. - * If an [[IllegalArgumentException]] is thrown, construct an error message instead. + * If an exception is thrown, construct an error message instead. */ private def constructResponseMessage( request: SubmitRestProtocolRequest): SubmitRestProtocolResponse = { @@ -122,15 +122,14 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.") } } catch { - case e: IllegalArgumentException => handleError(formatException(e)) + case e: Exception => handleError(formatException(e)) } // Validate the response message to ensure that it is correctly constructed. If it is not, // propagate the exception back to the client and signal that it is a server error. 
try { response.validate() } catch { - case e: IllegalArgumentException => - handleError("Internal server error: " + formatException(e)) + case e: Exception => handleError("Internal server error: " + formatException(e)) } response } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala index 414122bff09f..59341ebd5d60 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala @@ -78,6 +78,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B assert(killSuccess === "true") assert(statusSuccess === "true") assert(driverState === DriverState.KILLED.toString) + // we should not see the expected results because we killed the driver intercept[TestFailedException] { validateResult(resultsFile, numbers, size) } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 4c81f5fabdc1..2820f1343b4b 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -72,22 +72,22 @@ class SubmitRestProtocolSuite extends FunSuite { test("validate") { val request = new DummyRequest - intercept[AssertionError] { request.validate() } // missing everything + intercept[SubmitRestValidationException] { request.validate() } // missing everything request.setSparkVersion("1.4.8") - intercept[AssertionError] { request.validate() } // missing name and age + intercept[SubmitRestValidationException] { request.validate() } // missing name and age request.setName("something") - intercept[AssertionError] { request.validate() } // missing only age + intercept[SubmitRestValidationException] { request.validate() } // missing only age request.setAge("2") - intercept[AssertionError] { request.validate() } // age too low + intercept[SubmitRestValidationException] { request.validate() } // age too low request.setAge("10") request.validate() // everything is set request.setSparkVersion(null) - intercept[AssertionError] { request.validate() } // missing only Spark version + intercept[SubmitRestValidationException] { request.validate() } // missing only Spark version request.setSparkVersion("1.2.3") request.setName(null) - intercept[AssertionError] { request.validate() } // missing only name + intercept[SubmitRestValidationException] { request.validate() } // missing only name request.setMessage("not-setting-name") - intercept[AssertionError] { request.validate() } // still missing name + intercept[SubmitRestValidationException] { request.validate() } // still missing name } test("request to and from JSON") { @@ -119,7 +119,7 @@ class SubmitRestProtocolSuite extends FunSuite { test("SubmitDriverRequest") { val message = new SubmitDriverRequest - intercept[AssertionError] { message.validate() } + intercept[SubmitRestValidationException] { message.validate() } intercept[IllegalArgumentException] { message.setDriverCores("one hundred feet") } intercept[IllegalArgumentException] { message.setSuperviseDriver("nope, never") } intercept[IllegalArgumentException] { message.setTotalExecutorCores("two men") } @@ -181,7 +181,7 @@ class SubmitRestProtocolSuite extends FunSuite { test("SubmitDriverResponse") { val message = new 
SubmitDriverResponse - intercept[AssertionError] { message.validate() } + intercept[SubmitRestValidationException] { message.validate() } intercept[IllegalArgumentException] { message.setSuccess("maybe not") } message.setSparkVersion("1.2.3") message.setDriverId("driver_123") @@ -199,7 +199,7 @@ class SubmitRestProtocolSuite extends FunSuite { test("KillDriverRequest") { val message = new KillDriverRequest - intercept[AssertionError] { message.validate() } + intercept[SubmitRestValidationException] { message.validate() } message.setSparkVersion("1.2.3") message.setDriverId("driver_123") message.validate() @@ -214,7 +214,7 @@ class SubmitRestProtocolSuite extends FunSuite { test("KillDriverResponse") { val message = new KillDriverResponse - intercept[AssertionError] { message.validate() } + intercept[SubmitRestValidationException] { message.validate() } intercept[IllegalArgumentException] { message.setSuccess("maybe not") } message.setSparkVersion("1.2.3") message.setDriverId("driver_123") @@ -232,7 +232,7 @@ class SubmitRestProtocolSuite extends FunSuite { test("DriverStatusRequest") { val message = new DriverStatusRequest - intercept[AssertionError] { message.validate() } + intercept[SubmitRestValidationException] { message.validate() } message.setSparkVersion("1.2.3") message.setDriverId("driver_123") message.validate() @@ -247,7 +247,7 @@ class SubmitRestProtocolSuite extends FunSuite { test("DriverStatusResponse") { val message = new DriverStatusResponse - intercept[AssertionError] { message.validate() } + intercept[SubmitRestValidationException] { message.validate() } intercept[IllegalArgumentException] { message.setSuccess("maybe") } message.setSparkVersion("1.2.3") message.setDriverId("driver_123") @@ -269,7 +269,7 @@ class SubmitRestProtocolSuite extends FunSuite { test("ErrorResponse") { val message = new ErrorResponse - intercept[AssertionError] { message.validate() } + intercept[SubmitRestValidationException] { message.validate() } message.setSparkVersion("1.2.3") message.setMessage("Field not found in submit request: X") message.validate() @@ -412,10 +412,10 @@ private class DummyRequest extends SubmitRestProtocolRequest { def setAge(s: String): this.type = setNumericField(age, s) def setName(s: String): this.type = setField(name, s) - override def validate(): Unit = { - super.validate() + protected override def doValidate(): Unit = { + super.doValidate() assertFieldIsSet(name) assertFieldIsSet(age) - assert(age.getValue > 5, "Not old enough!") + assert(age.getValue.get > 5, "Not old enough!") } } From e2f7f5fa16f64b0e6f1c5c551074fb944139c826 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 30 Jan 2015 10:10:46 -0800 Subject: [PATCH 23/48] Provide more safeguard against missing fields This commit adds more null checks for fields that are not required by validation. It also cleans up the uses of exceptions further by introducing exceptions specific to the REST protocol. Additionally, the client used to just cast responses from the server to the expected type. This is not correct if the server responds with an error, however. This is now fixed. 
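One subtlety behind the casting fix: asInstanceOf on an erased type parameter is unchecked in Scala, so the cast cannot fail where it is written and a mismatch only surfaces later at the call site. A checked alternative can be sketched with a ClassTag (an illustration, not the code in this patch):

    import scala.reflect.ClassTag

    // The ClassTag makes the type test checked at runtime, so a mismatched
    // response yields None instead of a deferred ClassCastException.
    def as[T <: SubmitRestProtocolResponse : ClassTag](
        response: SubmitRestProtocolResponse): Option[T] = {
      response match {
        case t: T => Some(t)
        case _ => None
      }
    }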
--- .../deploy/rest/StandaloneRestClient.scala | 37 ++++++++++++---- .../deploy/rest/StandaloneRestServer.scala | 3 ++ .../spark/deploy/rest/SubmitRestClient.scala | 36 ++++++++++++---- .../rest/SubmitRestProtocolException.scala | 29 +++++++++++++ .../rest/SubmitRestProtocolMessage.scala | 42 +++++++++---------- .../spark/deploy/rest/SubmitRestServer.scala | 3 +- .../rest/StandaloneRestProtocolSuite.scala | 33 +++++++++++---- .../deploy/rest/SubmitRestProtocolSuite.scala | 31 +++++++------- 8 files changed, 155 insertions(+), 59 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index 6f2752c848a0..4f893febf744 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -36,9 +36,10 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { * just submitted and reports it to the user. Otherwise, if the submission was unsuccessful, * this reports failure and logs an error message provided by the REST server. */ - override def submitDriver(args: SparkSubmitArguments): SubmitDriverResponse = { + override def submitDriver(args: SparkSubmitArguments): SubmitRestProtocolResponse = { validateSubmitArgs(args) - val submitResponse = super.submitDriver(args) + val response = super.submitDriver(args) + val submitResponse = getResponse[SubmitDriverResponse](response).getOrElse { return response } val submitSuccess = submitResponse.getSuccess.toBoolean if (submitSuccess) { val driverId = submitResponse.getDriverId @@ -52,13 +53,13 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { } /** Request that the REST server kill the specified driver. */ - override def killDriver(master: String, driverId: String): KillDriverResponse = { + override def killDriver(master: String, driverId: String): SubmitRestProtocolResponse = { validateMaster(master) super.killDriver(master, driverId) } /** Request the status of the specified driver from the REST server. 
*/ - override def requestDriverStatus(master: String, driverId: String): DriverStatusResponse = { + override def requestDriverStatus(master: String, driverId: String): SubmitRestProtocolResponse = { validateMaster(master) super.requestDriverStatus(master, driverId) } @@ -69,14 +70,19 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { */ private def pollSubmittedDriverStatus(master: String, driverId: String): Unit = { (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ => - val statusResponse = requestDriverStatus(master, driverId) + val response = requestDriverStatus(master, driverId) + val statusResponse = getResponse[DriverStatusResponse](response).getOrElse { return } val statusSuccess = statusResponse.getSuccess.toBoolean if (statusSuccess) { - val driverState = statusResponse.getDriverState + val driverState = Option(statusResponse.getDriverState) val workerId = Option(statusResponse.getWorkerId) val workerHostPort = Option(statusResponse.getWorkerHostPort) val exception = Option(statusResponse.getMessage) - logInfo(s"State of driver $driverId is now $driverState.") + // Log driver state, if present + driverState match { + case Some(state) => logInfo(s"State of driver $driverId is now $state.") + case _ => logError(s"State of driver $driverId was not found!") + } // Log worker node, if present (workerId, workerHostPort) match { case (Some(id), Some(hp)) => logInfo(s"Driver is running on worker $id at $hp.") @@ -154,6 +160,23 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { "This REST client is only supported in standalone cluster mode.") } } + + /** + * Return the response as the expected type, or fail with an informative error message. + * Exposed for testing. + */ + private[spark] def getResponse[T <: SubmitRestProtocolResponse]( + response: SubmitRestProtocolResponse): Option[T] = { + try { + // Do not match on type T because types are erased at runtime + // Instead, manually try to cast it to type T ourselves + Some(response.asInstanceOf[T]) + } catch { + case e: ClassCastException => + logError(s"Server returned response of unexpected type:\n${response.toJson}") + None + } + } } private object StandaloneRestClient { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 1838647f6ed6..802933bf3f2a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -104,6 +104,9 @@ private[spark] class StandaloneRestServerHandler( val appName = request.getAppName val appResource = request.getAppResource val mainClass = request.getMainClass + if (mainClass == null) { + throw new SubmitRestMissingFieldException("Main class must be set in submit request.") + } // Optional fields val jars = Option(request.getJars) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala index a1be15c9fa5d..b3864db7dc67 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala @@ -34,27 +34,30 @@ import org.apache.spark.deploy.SparkSubmitArguments private[spark] abstract class SubmitRestClient extends Logging { /** Request that the REST server submit a driver using the provided arguments. 
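Because these entry points now return the base response type, a caller narrows the result itself, as in this sketch (the master URL and driver ID are illustrative):

    // Pattern-match on the concrete response type; an ErrorResponse from the
    // server no longer trips an unchecked cast.
    val client = new StandaloneRestClient
    client.requestDriverStatus("spark://127.0.0.1:7077", "driver-xyz") match {
      case status: DriverStatusResponse => println(s"Driver state: ${status.getDriverState}")
      case error: ErrorResponse => println(s"Server error: ${error.getMessage}")
      case other => println(s"Unexpected response:\n${other.toJson}")
    }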
*/ - def submitDriver(args: SparkSubmitArguments): SubmitDriverResponse = { + def submitDriver(args: SparkSubmitArguments): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request to launch a driver in ${args.master}.") val url = getHttpUrl(args.master) val request = constructSubmitRequest(args) - logInfo(s"Submitting a request to launch a driver in ${args.master}.") - sendHttp(url, request).asInstanceOf[SubmitDriverResponse] + val response = sendHttp(url, request) + handleResponse(response) } /** Request that the REST server kill the specified driver. */ - def killDriver(master: String, driverId: String): KillDriverResponse = { + def killDriver(master: String, driverId: String): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request to kill driver $driverId in $master.") val url = getHttpUrl(master) val request = constructKillRequest(master, driverId) - logInfo(s"Submitting a request to kill driver $driverId in $master.") - sendHttp(url, request).asInstanceOf[KillDriverResponse] + val response = sendHttp(url, request) + handleResponse(response) } /** Request the status of the specified driver from the REST server. */ - def requestDriverStatus(master: String, driverId: String): DriverStatusResponse = { + def requestDriverStatus(master: String, driverId: String): SubmitRestProtocolResponse = { + logInfo(s"Submitting a request for the status of driver $driverId in $master.") val url = getHttpUrl(master) val request = constructStatusRequest(master, driverId) - logInfo(s"Submitting a request for the status of driver $driverId in $master.") - sendHttp(url, request).asInstanceOf[DriverStatusResponse] + val response = sendHttp(url, request) + handleResponse(response) } /** Return the HTTP URL of the REST server that corresponds to the given master URL. */ @@ -91,4 +94,19 @@ private[spark] abstract class SubmitRestClient extends Logging { throw new SparkException(s"Unable to connect to REST server $url", e) } } + + /** Validate the response and log any error messages produced by the server. */ + private def handleResponse(response: SubmitRestProtocolResponse): SubmitRestProtocolResponse = { + try { + response.validate() + response match { + case error: ErrorResponse => logError(s"Server returned error:\n${error.getMessage}") + case _ => + } + } catch { + case e: SubmitRestProtocolException => + throw new SubmitRestProtocolException("Malformed response received from server", e) + } + response + } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala new file mode 100644 index 000000000000..f1aa68e790f5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.rest + +/** + * An exception thrown in the REST application submission protocol. + */ +class SubmitRestProtocolException(message: String, cause: Exception = null) + extends Exception(message, cause) + +/** + * An exception thrown if a field is missing from a [[SubmitRestProtocolMessage]]. + */ +class SubmitRestMissingFieldException(message: String) extends SubmitRestProtocolException(message) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index f527c76a02ed..ff4e74f1ad91 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -69,35 +69,39 @@ abstract class SubmitRestProtocolMessage { /** * Assert the validity of the message. - * If the validation fails, throw a [[SubmitRestValidationException]]. + * If the validation fails, throw a [[SubmitRestProtocolException]]. */ final def validate(): Unit = { try { doValidate() } catch { case e: Exception => - throw new SubmitRestValidationException( - s"Validation of message $messageType failed!", e) + throw new SubmitRestProtocolException(s"Validation of message $messageType failed!", e) } } /** Assert the validity of the message */ protected def doValidate(): Unit = { - assert(action != null, s"The action field is missing.") + if (action == null) { + throw new SubmitRestMissingFieldException(s"The action field is missing in $messageType") + } assertFieldIsSet(sparkVersion) } /** Assert that the specified field is set in this message. */ protected def assertFieldIsSet(field: SubmitRestProtocolField[_]): Unit = { - assert(field.isSet, s"Field '${field.name}' is missing.") + if (!field.isSet) { + throw new SubmitRestMissingFieldException( + s"Field '${field.name}' is missing in message $messageType.") + } } /** * Assert a condition when validating this message. - * If the assertion fails, throw a [[SubmitRestValidationException]]. + * If the assertion fails, throw a [[SubmitRestProtocolException]]. */ protected def assert(condition: Boolean, failMessage: String): Unit = { - if (!condition) { throw new SubmitRestValidationException(failMessage) } + if (!condition) { throw new SubmitRestProtocolException(failMessage) } } /** Set the field to the given value, or clear the field if the value is null. */ @@ -157,29 +161,25 @@ abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { override def setSparkVersion(s: String) = setServerSparkVersion(s) } -/** - * An exception thrown if the validation of a [[SubmitRestProtocolMessage]] fails. - */ -class SubmitRestValidationException( - message: String, - cause: Exception = null) - extends Exception(message, cause) - object SubmitRestProtocolMessage { private val mapper = new ObjectMapper private val packagePrefix = this.getClass.getPackage.getName - /** Parse the value of the action field from the given JSON. */ - def parseAction(json: String): String = parseField(json, "action") + /** + * Parse the value of the action field from the given JSON. + * If the action field is not found, throw a [[SubmitRestMissingFieldException]]. 
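As a standalone illustration of the json4s lookup underneath parseField (the JSON payload here is made up):

    import org.json4s._
    import org.json4s.jackson.JsonMethods._

    // Extract one top-level string field from a JSON object, yielding None
    // when the key is absent rather than throwing.
    val json = """{ "action" : "KillDriverRequest", "driverId" : "driver_123" }"""
    val action: Option[String] = parse(json).asInstanceOf[JObject].obj
      .find { case (key, _) => key == "action" }
      .map { case (_, value) => value.asInstanceOf[JString].s }
    // action == Some("KillDriverRequest")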
+ */ + def parseAction(json: String): String = { + parseField(json, "action").getOrElse { + throw new SubmitRestMissingFieldException(s"Action field not found in JSON:\n$json") + } + } /** Parse the value of the specified field from the given JSON. */ - def parseField(json: String, field: String): String = { + def parseField(json: String, field: String): Option[String] = { parse(json).asInstanceOf[JObject].obj .find { case (f, _) => f == field } .map { case (_, v) => v.asInstanceOf[JString].s } - .getOrElse { - throw new IllegalArgumentException(s"Could not find field '$field' in message:\n$json") - } } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala index 4856381a3e09..027f641aa8f9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala @@ -129,7 +129,8 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi try { response.validate() } catch { - case e: Exception => handleError("Internal server error: " + formatException(e)) + case e: Exception => + return handleError("Internal server error: " + formatException(e)) } response } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala index 59341ebd5d60..9a5aef0a4221 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala @@ -59,7 +59,8 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B } test("kill empty driver") { - val killResponse = client.killDriver(masterRestUrl, "driver-that-does-not-exist") + val response = client.killDriver(masterRestUrl, "driver-that-does-not-exist") + val killResponse = getResponse[KillDriverResponse](response, client) val killSuccess = killResponse.getSuccess assert(killSuccess === "false") } @@ -69,10 +70,12 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B val numbers = Seq(1, 2, 3) val size = 500 val driverId = submitApplication(resultsFile, numbers, size) - val killResponse = client.killDriver(masterRestUrl, driverId) + val response = client.killDriver(masterRestUrl, driverId) + val killResponse = getResponse[KillDriverResponse](response, client) val killSuccess = killResponse.getSuccess waitUntilFinished(driverId) - val statusResponse = client.requestDriverStatus(masterRestUrl, driverId) + val response2 = client.requestDriverStatus(masterRestUrl, driverId) + val statusResponse = getResponse[DriverStatusResponse](response2, client) val statusSuccess = statusResponse.getSuccess val driverState = statusResponse.getDriverState assert(killSuccess === "true") @@ -83,7 +86,8 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B } test("request status for empty driver") { - val statusResponse = client.requestDriverStatus(masterRestUrl, "driver-that-does-not-exist") + val response = client.requestDriverStatus(masterRestUrl, "driver-that-does-not-exist") + val statusResponse = getResponse[DriverStatusResponse](response, client) val statusSuccess = statusResponse.getSuccess assert(statusSuccess === "false") } @@ -125,7 +129,8 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B mainJar) ++ appArgs 
val args = new SparkSubmitArguments(commandLineArgs) SparkSubmit.prepareSubmitEnvironment(args) - val submitResponse = client.submitDriver(args) + val response = client.submitDriver(args) + val submitResponse = getResponse[SubmitDriverResponse](response, client) submitResponse.getDriverId } @@ -134,7 +139,8 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B var finished = false val expireTime = System.currentTimeMillis + maxSeconds * 1000 while (!finished) { - val statusResponse = client.requestDriverStatus(masterRestUrl, driverId) + val response = client.requestDriverStatus(masterRestUrl, driverId) + val statusResponse = getResponse[DriverStatusResponse](response, client) val driverState = statusResponse.getDriverState finished = driverState != DriverState.SUBMITTED.toString && @@ -142,7 +148,20 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B if (System.currentTimeMillis > expireTime) { fail(s"Driver $driverId did not finish within $maxSeconds seconds.") } - Thread.sleep(1000) + } + } + + /** Return the response as the expected type, or fail with an informative error message. */ + private def getResponse[T <: SubmitRestProtocolResponse]( + response: SubmitRestProtocolResponse, + client: StandaloneRestClient): T = { + response match { + case error: ErrorResponse => + fail(s"Error from the server:\n${error.getMessage}") + case _ => + client.getResponse[T](response).getOrElse { + fail(s"Response type was unexpected: ${response.toJson}") + } } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 2820f1343b4b..8c5cf6baee45 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -72,22 +72,22 @@ class SubmitRestProtocolSuite extends FunSuite { test("validate") { val request = new DummyRequest - intercept[SubmitRestValidationException] { request.validate() } // missing everything + intercept[SubmitRestProtocolException] { request.validate() } // missing everything request.setSparkVersion("1.4.8") - intercept[SubmitRestValidationException] { request.validate() } // missing name and age + intercept[SubmitRestProtocolException] { request.validate() } // missing name and age request.setName("something") - intercept[SubmitRestValidationException] { request.validate() } // missing only age + intercept[SubmitRestProtocolException] { request.validate() } // missing only age request.setAge("2") - intercept[SubmitRestValidationException] { request.validate() } // age too low + intercept[SubmitRestProtocolException] { request.validate() } // age too low request.setAge("10") request.validate() // everything is set request.setSparkVersion(null) - intercept[SubmitRestValidationException] { request.validate() } // missing only Spark version + intercept[SubmitRestProtocolException] { request.validate() } // missing only Spark version request.setSparkVersion("1.2.3") request.setName(null) - intercept[SubmitRestValidationException] { request.validate() } // missing only name + intercept[SubmitRestProtocolException] { request.validate() } // missing only name request.setMessage("not-setting-name") - intercept[SubmitRestValidationException] { request.validate() } // still missing name + intercept[SubmitRestProtocolException] { request.validate() } // still missing name } test("request to and from 
JSON") { @@ -119,7 +119,7 @@ class SubmitRestProtocolSuite extends FunSuite { test("SubmitDriverRequest") { val message = new SubmitDriverRequest - intercept[SubmitRestValidationException] { message.validate() } + intercept[SubmitRestProtocolException] { message.validate() } intercept[IllegalArgumentException] { message.setDriverCores("one hundred feet") } intercept[IllegalArgumentException] { message.setSuperviseDriver("nope, never") } intercept[IllegalArgumentException] { message.setTotalExecutorCores("two men") } @@ -181,7 +181,7 @@ class SubmitRestProtocolSuite extends FunSuite { test("SubmitDriverResponse") { val message = new SubmitDriverResponse - intercept[SubmitRestValidationException] { message.validate() } + intercept[SubmitRestProtocolException] { message.validate() } intercept[IllegalArgumentException] { message.setSuccess("maybe not") } message.setSparkVersion("1.2.3") message.setDriverId("driver_123") @@ -199,7 +199,7 @@ class SubmitRestProtocolSuite extends FunSuite { test("KillDriverRequest") { val message = new KillDriverRequest - intercept[SubmitRestValidationException] { message.validate() } + intercept[SubmitRestProtocolException] { message.validate() } message.setSparkVersion("1.2.3") message.setDriverId("driver_123") message.validate() @@ -214,7 +214,7 @@ class SubmitRestProtocolSuite extends FunSuite { test("KillDriverResponse") { val message = new KillDriverResponse - intercept[SubmitRestValidationException] { message.validate() } + intercept[SubmitRestProtocolException] { message.validate() } intercept[IllegalArgumentException] { message.setSuccess("maybe not") } message.setSparkVersion("1.2.3") message.setDriverId("driver_123") @@ -232,7 +232,7 @@ class SubmitRestProtocolSuite extends FunSuite { test("DriverStatusRequest") { val message = new DriverStatusRequest - intercept[SubmitRestValidationException] { message.validate() } + intercept[SubmitRestProtocolException] { message.validate() } message.setSparkVersion("1.2.3") message.setDriverId("driver_123") message.validate() @@ -247,7 +247,7 @@ class SubmitRestProtocolSuite extends FunSuite { test("DriverStatusResponse") { val message = new DriverStatusResponse - intercept[SubmitRestValidationException] { message.validate() } + intercept[SubmitRestProtocolException] { message.validate() } intercept[IllegalArgumentException] { message.setSuccess("maybe") } message.setSparkVersion("1.2.3") message.setDriverId("driver_123") @@ -264,12 +264,15 @@ class SubmitRestProtocolSuite extends FunSuite { assert(newMessage.getSparkVersion === "1.2.3") assert(newMessage.getServerSparkVersion === "1.2.3") assert(newMessage.getDriverId === "driver_123") + assert(newMessage.getDriverState === "RUNNING") assert(newMessage.getSuccess === "true") + assert(newMessage.getWorkerId === "worker_123") + assert(newMessage.getWorkerHostPort === "1.2.3.4:7780") } test("ErrorResponse") { val message = new ErrorResponse - intercept[SubmitRestValidationException] { message.validate() } + intercept[SubmitRestProtocolException] { message.validate() } message.setSparkVersion("1.2.3") message.setMessage("Field not found in submit request: X") message.validate() From bf696ff0b7135883e53e5fb275b4afa0db6c4a4a Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 30 Jan 2015 10:27:38 -0800 Subject: [PATCH 24/48] Add checks for enabling REST when using kill/status --- .../apache/spark/deploy/SparkSubmitArguments.scala | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git 
a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 19ac58df04e8..8d070523f059 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -60,6 +60,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var driverToKill: String = null var driverToRequestStatusFor: String = null + private val restEnabledKey = "spark.submit.rest.enabled" + def action: SparkSubmitAction = { (driverToKill, driverToRequestStatusFor) match { case (null, null) => SparkSubmitAction.SUBMIT @@ -237,6 +239,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St if (!isStandaloneCluster) { SparkSubmit.printErrorAndExit("Killing drivers is only supported in standalone cluster mode") } + if (!isRestEnabled) { + SparkSubmit.printErrorAndExit("Killing drivers is currently only supported " + + s"through the REST interface. Please set $restEnabledKey to true.") + } if (driverToKill == null) { SparkSubmit.printErrorAndExit("Please specify a driver to kill") } @@ -247,6 +253,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St SparkSubmit.printErrorAndExit( "Requesting driver statuses is only supported in standalone cluster mode") } + if (!isRestEnabled) { + SparkSubmit.printErrorAndExit("Requesting driver statuses is currently only " + + s"supported through the REST interface. Please set $restEnabledKey to true.") + } if (driverToRequestStatusFor == null) { SparkSubmit.printErrorAndExit("Please specify a driver to request status for") } @@ -258,7 +268,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St /** Return whether the REST application submission protocol is enabled. */ def isRestEnabled: Boolean = { - sparkProperties.get("spark.submit.rest.enabled").getOrElse("false").toBoolean + sparkProperties.get(restEnabledKey).getOrElse("false").toBoolean } override def toString = { From 6c57b4bffdb76e0097007d34de10f66907a41806 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 30 Jan 2015 11:32:46 -0800 Subject: [PATCH 25/48] Increase timeout in end-to-end tests ...to avoid a potential YAFT, yet another flaky test. --- .../apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala index 9a5aef0a4221..3984fb325083 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala @@ -135,7 +135,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B } /** Wait until the given driver has finished running up to the specified timeout. 
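* Polls the driver's status repeatedly, failing the test if the driver has not finished when the timeout expires.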
*/ - private def waitUntilFinished(driverId: String, maxSeconds: Int = 10): Unit = { + private def waitUntilFinished(driverId: String, maxSeconds: Int = 30): Unit = { var finished = false val expireTime = System.currentTimeMillis + maxSeconds * 1000 while (!finished) { From b2fef8bc64e954fa42efe10ada71c8d57cc96ce0 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 30 Jan 2015 14:40:04 -0800 Subject: [PATCH 26/48] Abstract the success field to the general response This was common basically across all response messages. --- .../org/apache/spark/deploy/SparkSubmit.scala | 8 +++-- .../spark/deploy/SparkSubmitArguments.scala | 31 ++++++++++--------- .../deploy/rest/DriverStatusRequest.scala | 2 -- .../deploy/rest/DriverStatusResponse.scala | 4 --- .../spark/deploy/rest/ErrorResponse.scala | 4 +++ .../spark/deploy/rest/KillDriverRequest.scala | 2 -- .../deploy/rest/KillDriverResponse.scala | 7 ----- .../deploy/rest/SubmitDriverResponse.scala | 10 ------ .../spark/deploy/rest/SubmitRestClient.scala | 14 +++------ .../rest/SubmitRestProtocolMessage.scala | 14 +++++++-- 10 files changed, 43 insertions(+), 53 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 328ecc9e768c..33aa4f493e32 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -93,7 +93,9 @@ object SparkSubmit { * Standalone cluster mode only. */ private def kill(args: SparkSubmitArguments): Unit = { - new StandaloneRestClient().killDriver(args.master, args.driverToKill) + val client = new StandaloneRestClient + val response = client.killDriver(args.master, args.driverToKill) + printStream.println(response.toJson) } /** @@ -101,7 +103,9 @@ object SparkSubmit { * Standalone cluster mode only. 
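* The server's response is printed to the submit client's output stream as JSON.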
*/ private def requestStatus(args: SparkSubmitArguments): Unit = { - new StandaloneRestClient().requestDriverStatus(args.master, args.driverToRequestStatusFor) + val client = new StandaloneRestClient + val response = client.requestDriverStatus(args.master, args.driverToRequestStatusFor) + printStream.println(response.toJson) } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 8d070523f059..cbfb93b2e9e6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -22,7 +22,7 @@ import java.util.jar.JarFile import scala.collection.mutable.{ArrayBuffer, HashMap} -import org.apache.spark.deploy.SparkSubmitAction.SparkSubmitAction +import org.apache.spark.deploy.SparkSubmitAction._ import org.apache.spark.util.Utils /** @@ -52,6 +52,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var verbose: Boolean = false var isPython: Boolean = false var pyFiles: String = null + var action: SparkSubmitAction = null val sparkProperties: HashMap[String, String] = new HashMap[String, String]() // Standalone cluster mode only @@ -62,17 +63,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St var driverToKill: String = null var driverToRequestStatusFor: String = null private val restEnabledKey = "spark.submit.rest.enabled" - def action: SparkSubmitAction = { - (driverToKill, driverToRequestStatusFor) match { - case (null, null) => SparkSubmitAction.SUBMIT - case (_, null) => SparkSubmitAction.KILL - case (null, _) => SparkSubmitAction.REQUEST_STATUS - case _ => SparkSubmit.printErrorAndExit( - "Requested to both kill and request status for a driver. Choose only one.") - null // never reached - } - } - /** Default properties present in the currently defined defaults file. */ lazy val defaultSparkProperties: HashMap[String, String] = { val defaultProperties = new HashMap[String, String]() @@ -189,14 +179,17 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St if (name == null && primaryResource != null) { name = Utils.stripDirectory(primaryResource) } + + // Action should be SUBMIT unless otherwise specified + action = Option(action).getOrElse(SUBMIT) } /** Ensure that required fields exist. Call this only once all defaults are loaded.
*/ private def validateArguments(): Unit = { action match { - case SparkSubmitAction.SUBMIT => validateSubmitArguments() - case SparkSubmitAction.KILL => validateKillArguments() - case SparkSubmitAction.REQUEST_STATUS => validateStatusRequestArguments() + case SUBMIT => validateSubmitArguments() + case KILL => validateKillArguments() + case REQUEST_STATUS => validateStatusRequestArguments() } } @@ -379,10 +372,18 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St case ("--kill") :: value :: tail => driverToKill = value + if (action != null) { + SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $KILL.") + } + action = KILL parse(tail) case ("--status") :: value :: tail => driverToRequestStatusFor = value + if (action != null) { + SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $REQUEST_STATUS.") + } + action = REQUEST_STATUS parse(tail) case ("--supervise") :: tail => diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala index 9c925548b0e4..3f8eb4d54b03 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala @@ -22,10 +22,8 @@ package org.apache.spark.deploy.rest */ class DriverStatusRequest extends SubmitRestProtocolRequest { private val driverId = new SubmitRestProtocolField[String]("driverId") - def getDriverId: String = driverId.toString def setDriverId(s: String): this.type = setField(driverId, s) - protected override def doValidate(): Unit = { super.doValidate() assertFieldIsSet(driverId) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala index de6b24b0e80c..682976ab9c9c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala @@ -22,20 +22,17 @@ package org.apache.spark.deploy.rest */ class DriverStatusResponse extends SubmitRestProtocolResponse { private val driverId = new SubmitRestProtocolField[String]("driverId") - private val success = new SubmitRestProtocolField[Boolean]("success") // standalone cluster mode only private val driverState = new SubmitRestProtocolField[String]("driverState") private val workerId = new SubmitRestProtocolField[String]("workerId") private val workerHostPort = new SubmitRestProtocolField[String]("workerHostPort") def getDriverId: String = driverId.toString - def getSuccess: String = success.toString def getDriverState: String = driverState.toString def getWorkerId: String = workerId.toString def getWorkerHostPort: String = workerHostPort.toString def setDriverId(s: String): this.type = setField(driverId, s) - def setSuccess(s: String): this.type = setBooleanField(success, s) def setDriverState(s: String): this.type = setField(driverState, s) def setWorkerId(s: String): this.type = setField(workerId, s) def setWorkerHostPort(s: String): this.type = setField(workerHostPort, s) @@ -43,6 +40,5 @@ class DriverStatusResponse extends SubmitRestProtocolResponse { protected override def doValidate(): Unit = { super.doValidate() assertFieldIsSet(driverId) - assertFieldIsSet(success) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala index 
b7fcc97ea2a8..47399e0a6799 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala @@ -21,8 +21,12 @@ package org.apache.spark.deploy.rest * An error response message used in the REST application submission protocol. */ class ErrorResponse extends SubmitRestProtocolResponse { + // request was unsuccessful + setSuccess("false") + protected override def doValidate(): Unit = { super.doValidate() assertFieldIsSet(message) + assert(!getSuccess.toBoolean, "The 'success' field cannot be true in an error response.") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala index 764f3e7753ae..b7a8ea02815e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala @@ -22,10 +22,8 @@ package org.apache.spark.deploy.rest */ class KillDriverRequest extends SubmitRestProtocolRequest { private val driverId = new SubmitRestProtocolField[String]("driverId") - def getDriverId: String = driverId.toString def setDriverId(s: String): this.type = setField(driverId, s) - protected override def doValidate(): Unit = { super.doValidate() assertFieldIsSet(driverId) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala index 790527cc67a9..0224402a5af3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala @@ -22,17 +22,10 @@ package org.apache.spark.deploy.rest */ class KillDriverResponse extends SubmitRestProtocolResponse { private val driverId = new SubmitRestProtocolField[String]("driverId") - private val success = new SubmitRestProtocolField[Boolean]("success") - def getDriverId: String = driverId.toString - def getSuccess: String = success.toString - def setDriverId(s: String): this.type = setField(driverId, s) - def setSuccess(s: String): this.type = setBooleanField(success, s) - protected override def doValidate(): Unit = { super.doValidate() assertFieldIsSet(driverId) - assertFieldIsSet(success) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala index 39f8a67aea98..a24b26067319 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala @@ -21,17 +21,7 @@ package org.apache.spark.deploy.rest * A response to the [[SubmitDriverRequest]] in the REST application submission protocol. 
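* * A successful response serialized to JSON might look like the following (a hypothetical example; the driver ID below is invented): {{{ { "action" : "SubmitDriverResponse", "driverId" : "driver-20150201-0001", "serverSparkVersion" : "1.3.0", "success" : "true" } }}}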
*/ class SubmitDriverResponse extends SubmitRestProtocolResponse { - private val success = new SubmitRestProtocolField[Boolean]("success") private val driverId = new SubmitRestProtocolField[String]("driverId") - - def getSuccess: String = success.toString def getDriverId: String = driverId.toString - - def setSuccess(s: String): this.type = setBooleanField(success, s) def setDriverId(s: String): this.type = setField(driverId, s) - - protected override def doValidate(): Unit = { - super.doValidate() - assertFieldIsSet(success) - } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala index b3864db7dc67..efea5b6f324d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala @@ -39,7 +39,7 @@ private[spark] abstract class SubmitRestClient extends Logging { val url = getHttpUrl(args.master) val request = constructSubmitRequest(args) val response = sendHttp(url, request) - handleResponse(response) + validateResponse(response) } /** Request that the REST server kill the specified driver. */ @@ -48,7 +48,7 @@ private[spark] abstract class SubmitRestClient extends Logging { val url = getHttpUrl(master) val request = constructKillRequest(master, driverId) val response = sendHttp(url, request) - handleResponse(response) + validateResponse(response) } /** Request the status of the specified driver from the REST server. */ @@ -57,7 +57,7 @@ private[spark] abstract class SubmitRestClient extends Logging { val url = getHttpUrl(master) val request = constructStatusRequest(master, driverId) val response = sendHttp(url, request) - handleResponse(response) + validateResponse(response) } /** Return the HTTP URL of the REST server that corresponds to the given master URL. */ @@ -95,14 +95,10 @@ private[spark] abstract class SubmitRestClient extends Logging { } } - /** Validate the response and log any error messages produced by the server. */ - private def handleResponse(response: SubmitRestProtocolResponse): SubmitRestProtocolResponse = { + /** Validate the response... 
*/ + private def validateResponse(response: SubmitRestProtocolResponse): SubmitRestProtocolResponse = { try { response.validate() - response match { - case error: ErrorResponse => logError(s"Server returned error:\n${error.getMessage}") - case _ => - } } catch { case e: SubmitRestProtocolException => throw new SubmitRestProtocolException("Malformed response received from server", e) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index ff4e74f1ad91..dd325108d38b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -155,10 +155,20 @@ abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage { */ abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { protected override val sparkVersion = new SubmitRestProtocolField[String]("server_spark_version") - def getServerSparkVersion: String = sparkVersion.toString - def setServerSparkVersion(s: String): this.type = setField(sparkVersion, s) + private val success = new SubmitRestProtocolField[Boolean]("success") + override def getSparkVersion: String = getServerSparkVersion + def getServerSparkVersion: String = sparkVersion.toString + def getSuccess: String = success.toString + override def setSparkVersion(s: String) = setServerSparkVersion(s) + def setServerSparkVersion(s: String): this.type = setField(sparkVersion, s) + def setSuccess(s: String): this.type = setBooleanField(success, s) + + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(success) + } } object SubmitRestProtocolMessage { From ade28fd14a2589051f6f6288f5952299ce0abdef Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sat, 31 Jan 2015 16:01:12 -0800 Subject: [PATCH 27/48] Clean up REST response output in Spark submit Now we don't log a response twice or log an error message twice. Also, previously we would throw a ClassCastException if the server returned an error, due to type erasure. This commit simplifies the relevant response handling. --- .../org/apache/spark/deploy/SparkSubmit.scala | 30 ++++++++++++++++--- .../deploy/rest/StandaloneRestClient.scala | 27 +++++------------ .../spark/deploy/rest/SubmitRestClient.scala | 14 +++++---- .../rest/SubmitRestProtocolMessage.scala | 2 +- 4 files changed, 44 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 33aa4f493e32..c55f5c9716a2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -23,9 +23,9 @@ import java.net.URL import scala.collection.mutable.{ArrayBuffer, HashMap, Map} +import org.apache.spark.deploy.rest._ import org.apache.spark.executor.ExecutorURLClassLoader import org.apache.spark.util.Utils -import org.apache.spark.deploy.rest.StandaloneRestClient /** * Whether to submit, kill, or request the status of an application.
@@ -95,7 +95,10 @@ object SparkSubmit { private def kill(args: SparkSubmitArguments): Unit = { val client = new StandaloneRestClient val response = client.killDriver(args.master, args.driverToKill) - printStream.println(response.toJson) + response match { + case k: KillDriverResponse => handleRestResponse(k) + case r => handleUnexpectedRestResponse(r) + } } /** @@ -105,7 +108,10 @@ object SparkSubmit { private def requestStatus(args: SparkSubmitArguments): Unit = { val client = new StandaloneRestClient val response = client.requestDriverStatus(args.master, args.driverToRequestStatusFor) - printStream.println(response.toJson) + response match { + case s: DriverStatusResponse => handleRestResponse(s) + case r => handleUnexpectedRestResponse(r) + } } /** @@ -126,7 +132,12 @@ object SparkSubmit { val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args) if (args.isStandaloneCluster && args.isRestEnabled) { printStream.println("Running Spark using the REST application submission protocol.") - new StandaloneRestClient().submitDriver(args) + val client = new StandaloneRestClient + val response = client.submitDriver(args) + response match { + case s: SubmitDriverResponse => handleRestResponse(s) + case r => handleUnexpectedRestResponse(r) + } } else { runMain(childArgs, childClasspath, sysProps, childMainClass) } @@ -461,6 +472,17 @@ object SparkSubmit { } } + /** Log the response sent by the server in the REST application submission protocol. */ + private def handleRestResponse(response: SubmitRestProtocolResponse): Unit = { + printStream.println(s"Server responded with ${response.messageType}:\n${response.toJson}") + } + + /** Log an appropriate error if the response sent by the server is not of the expected type. */ + private def handleUnexpectedRestResponse(unexpected: SubmitRestProtocolResponse): Unit = { + printStream.println( + s"Error: Server responded with message of unexpected type ${unexpected.messageType}.") + } + /** * Return whether the given primary resource represents a user jar. 
*/ diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index 4f893febf744..eb671978e2af 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -39,7 +39,10 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { override def submitDriver(args: SparkSubmitArguments): SubmitRestProtocolResponse = { validateSubmitArgs(args) val response = super.submitDriver(args) - val submitResponse = getResponse[SubmitDriverResponse](response).getOrElse { return response } + val submitResponse = response match { + case s: SubmitDriverResponse => s + case _ => return response + } val submitSuccess = submitResponse.getSuccess.toBoolean if (submitSuccess) { val driverId = submitResponse.getDriverId @@ -71,7 +74,10 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { private def pollSubmittedDriverStatus(master: String, driverId: String): Unit = { (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ => val response = requestDriverStatus(master, driverId) - val statusResponse = getResponse[DriverStatusResponse](response).getOrElse { return } + val statusResponse = response match { + case s: DriverStatusResponse => s + case _ => return + } val statusSuccess = statusResponse.getSuccess.toBoolean if (statusSuccess) { val driverState = Option(statusResponse.getDriverState) @@ -160,23 +166,6 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { "This REST client is only supported in standalone cluster mode.") } } - - /** - * Return the response as the expected type, or fail with an informative error message. - * Exposed for testing. - */ - private[spark] def getResponse[T <: SubmitRestProtocolResponse]( - response: SubmitRestProtocolResponse): Option[T] = { - try { - // Do not match on type T because types are erased at runtime - // Instead, manually try to cast it to type T ourselves - Some(response.asInstanceOf[T]) - } catch { - case e: ClassCastException => - logError(s"Server returned response of unexpected type:\n${response.toJson}") - None - } - } } private object StandaloneRestClient { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala index efea5b6f324d..ea056351bcf8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala @@ -39,7 +39,7 @@ private[spark] abstract class SubmitRestClient extends Logging { val url = getHttpUrl(args.master) val request = constructSubmitRequest(args) val response = sendHttp(url, request) - validateResponse(response) + handleResponse(response) } /** Request that the REST server kill the specified driver. */ @@ -48,7 +48,7 @@ private[spark] abstract class SubmitRestClient extends Logging { val url = getHttpUrl(master) val request = constructKillRequest(master, driverId) val response = sendHttp(url, request) - validateResponse(response) + handleResponse(response) } /** Request the status of the specified driver from the REST server. 
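* The response returned here has already been validated, and any error reported by the server has already been logged.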
*/ @@ -57,7 +57,7 @@ private[spark] abstract class SubmitRestClient extends Logging { val url = getHttpUrl(master) val request = constructStatusRequest(master, driverId) val response = sendHttp(url, request) - validateResponse(response) + handleResponse(response) } /** Return the HTTP URL of the REST server that corresponds to the given master URL. */ @@ -95,10 +95,14 @@ private[spark] abstract class SubmitRestClient extends Logging { } } - /** Validate the response... */ - private def validateResponse(response: SubmitRestProtocolResponse): SubmitRestProtocolResponse = { + /** Validate the response and log any error messages provided by the server. */ + private def handleResponse(response: SubmitRestProtocolResponse): SubmitRestProtocolResponse = { try { response.validate() + response match { + case e: ErrorResponse => logError(s"Server responded with error:\n${e.getMessage}") + case _ => + } } catch { case e: SubmitRestProtocolException => throw new SubmitRestProtocolException("Malformed response received from server", e) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index dd325108d38b..8f7dd87fbea3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -39,7 +39,7 @@ import org.apache.spark.util.Utils @JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY) @JsonPropertyOrder(alphabetic = true) abstract class SubmitRestProtocolMessage { - private val messageType = Utils.getFormattedClassName(this) + val messageType = Utils.getFormattedClassName(this) protected val action: String = messageType protected val sparkVersion: SubmitRestProtocolField[String] protected val message = new SubmitRestProtocolField[String]("message") From 9229433bf092306dfcdbce5c068e12855ea81971 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sat, 31 Jan 2015 16:26:04 -0800 Subject: [PATCH 28/48] Reduce duplicate naming in REST field This commit also fixes the standalone REST protocol test, which would fail with a ClassCastException if the server returns an error, for the same reason explained in the previous commit. --- .../deploy/rest/DriverStatusRequest.scala | 4 +- .../deploy/rest/DriverStatusResponse.scala | 10 ++--- .../spark/deploy/rest/ErrorResponse.scala | 10 +++-- .../spark/deploy/rest/KillDriverRequest.scala | 4 +- .../deploy/rest/KillDriverResponse.scala | 4 +- .../deploy/rest/SubmitDriverRequest.scala | 32 ++++++------- .../deploy/rest/SubmitDriverResponse.scala | 2 +- .../deploy/rest/SubmitRestProtocolField.scala | 2 +- .../rest/SubmitRestProtocolMessage.scala | 25 ++++++----- .../rest/StandaloneRestProtocolSuite.scala | 45 ++++++++++++------- .../deploy/rest/SubmitRestProtocolSuite.scala | 18 +++++--- 11 files changed, 91 insertions(+), 65 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala index 3f8eb4d54b03..e25bb45668e5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala @@ -21,11 +21,11 @@ package org.apache.spark.deploy.rest * A request to query the status of a driver in the REST application submission protocol.
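* * A hypothetical usage sketch (the driver ID below is invented): {{{ val request = new DriverStatusRequest().setDriverId("driver-20150201-0001").setSparkVersion("1.3.0"); request.validate() // would throw a SubmitRestProtocolException if a required field were missing }}}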
*/ class DriverStatusRequest extends SubmitRestProtocolRequest { - private val driverId = new SubmitRestProtocolField[String]("driverId") + private val driverId = new SubmitRestProtocolField[String] def getDriverId: String = driverId.toString def setDriverId(s: String): this.type = setField(driverId, s) protected override def doValidate(): Unit = { super.doValidate() - assertFieldIsSet(driverId) + assertFieldIsSet(driverId, "driverId") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala index 682976ab9c9c..568204ac6c81 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala @@ -21,11 +21,11 @@ package org.apache.spark.deploy.rest * A response to the [[DriverStatusRequest]] in the REST application submission protocol. */ class DriverStatusResponse extends SubmitRestProtocolResponse { - private val driverId = new SubmitRestProtocolField[String]("driverId") + private val driverId = new SubmitRestProtocolField[String] // standalone cluster mode only - private val driverState = new SubmitRestProtocolField[String]("driverState") - private val workerId = new SubmitRestProtocolField[String]("workerId") - private val workerHostPort = new SubmitRestProtocolField[String]("workerHostPort") + private val driverState = new SubmitRestProtocolField[String] + private val workerId = new SubmitRestProtocolField[String] + private val workerHostPort = new SubmitRestProtocolField[String] def getDriverId: String = driverId.toString def getDriverState: String = driverState.toString @@ -39,6 +39,6 @@ class DriverStatusResponse extends SubmitRestProtocolResponse { protected override def doValidate(): Unit = { super.doValidate() - assertFieldIsSet(driverId) + assertFieldIsSet(driverId, "driverId") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala index 47399e0a6799..0bc003f97ab5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala @@ -17,16 +17,20 @@ package org.apache.spark.deploy.rest +import com.fasterxml.jackson.annotation.JsonIgnore + /** * An error response message used in the REST application submission protocol. */ class ErrorResponse extends SubmitRestProtocolResponse { - // request was unsuccessful setSuccess("false") + // Don't bother logging success = false in the JSON + @JsonIgnore + override def getSuccess: String = super.getSuccess + protected override def doValidate(): Unit = { super.doValidate() - assertFieldIsSet(message) - assert(!getSuccess.toBoolean, "The 'success' field cannot be true in an error response.") + assertFieldIsSet(message, "message") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala index b7a8ea02815e..99f52a083954 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala @@ -21,11 +21,11 @@ package org.apache.spark.deploy.rest * A request to kill a driver in the REST application submission protocol. 
*/ class KillDriverRequest extends SubmitRestProtocolRequest { - private val driverId = new SubmitRestProtocolField[String]("driverId") + private val driverId = new SubmitRestProtocolField[String] def getDriverId: String = driverId.toString def setDriverId(s: String): this.type = setField(driverId, s) protected override def doValidate(): Unit = { super.doValidate() - assertFieldIsSet(driverId) + assertFieldIsSet(driverId, "driverId") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala index 0224402a5af3..cb0112653b2f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala @@ -21,11 +21,11 @@ package org.apache.spark.deploy.rest * A response to the [[KillDriverRequest]] in the REST application submission protocol. */ class KillDriverResponse extends SubmitRestProtocolResponse { - private val driverId = new SubmitRestProtocolField[String]("driverId") + private val driverId = new SubmitRestProtocolField[String] def getDriverId: String = driverId.toString def setDriverId(s: String): this.type = setField(driverId, s) protected override def doValidate(): Unit = { super.doValidate() - assertFieldIsSet(driverId) + assertFieldIsSet(driverId, "driverId") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala index b083d27e7901..c5daa616860e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala @@ -29,20 +29,20 @@ import org.apache.spark.util.JsonProtocol * A request to submit a driver in the REST application submission protocol. 
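* * Serialized to JSON, a hypothetical minimal request might look like: {{{ { "action" : "SubmitDriverRequest", "appName" : "SparkPi", "appResource" : "file:/path/to/examples.jar", "clientSparkVersion" : "1.3.0" } }}}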
*/ class SubmitDriverRequest extends SubmitRestProtocolRequest { - private val appName = new SubmitRestProtocolField[String]("appName") - private val appResource = new SubmitRestProtocolField[String]("appResource") - private val mainClass = new SubmitRestProtocolField[String]("mainClass") - private val jars = new SubmitRestProtocolField[String]("jars") - private val files = new SubmitRestProtocolField[String]("files") - private val pyFiles = new SubmitRestProtocolField[String]("pyFiles") - private val driverMemory = new SubmitRestProtocolField[String]("driverMemory") - private val driverCores = new SubmitRestProtocolField[Int]("driverCores") - private val driverExtraJavaOptions = new SubmitRestProtocolField[String]("driverExtraJavaOptions") - private val driverExtraClassPath = new SubmitRestProtocolField[String]("driverExtraClassPath") - private val driverExtraLibraryPath = new SubmitRestProtocolField[String]("driverExtraLibraryPath") - private val superviseDriver = new SubmitRestProtocolField[Boolean]("superviseDriver") - private val executorMemory = new SubmitRestProtocolField[String]("executorMemory") - private val totalExecutorCores = new SubmitRestProtocolField[Int]("totalExecutorCores") + private val appName = new SubmitRestProtocolField[String] + private val appResource = new SubmitRestProtocolField[String] + private val mainClass = new SubmitRestProtocolField[String] + private val jars = new SubmitRestProtocolField[String] + private val files = new SubmitRestProtocolField[String] + private val pyFiles = new SubmitRestProtocolField[String] + private val driverMemory = new SubmitRestProtocolField[String] + private val driverCores = new SubmitRestProtocolField[Int] + private val driverExtraJavaOptions = new SubmitRestProtocolField[String] + private val driverExtraClassPath = new SubmitRestProtocolField[String] + private val driverExtraLibraryPath = new SubmitRestProtocolField[String] + private val superviseDriver = new SubmitRestProtocolField[Boolean] + private val executorMemory = new SubmitRestProtocolField[String] + private val totalExecutorCores = new SubmitRestProtocolField[Int] // Special fields private val appArgs = new ArrayBuffer[String] @@ -140,7 +140,7 @@ class SubmitDriverRequest extends SubmitRestProtocolRequest { protected override def doValidate(): Unit = { super.doValidate() - assertFieldIsSet(appName) - assertFieldIsSet(appResource) + assertFieldIsSet(appName, "appName") + assertFieldIsSet(appResource, "appResource") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala index a24b26067319..1ac769f3110a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala @@ -21,7 +21,7 @@ package org.apache.spark.deploy.rest * A response to the [[SubmitDriverRequest]] in the REST application submission protocol. 
*/ class SubmitDriverResponse extends SubmitRestProtocolResponse { - private val driverId = new SubmitRestProtocolField[String]("driverId") + private val driverId = new SubmitRestProtocolField[String] def getDriverId: String = driverId.toString def setDriverId(s: String): this.type = setField(driverId, s) } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala index 3932e68fcd2b..3e7208c4d8c3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.rest /** * A field used in [[SubmitRestProtocolMessage]]s. */ -class SubmitRestProtocolField[T](val name: String) { +class SubmitRestProtocolField[T] { protected var value: Option[T] = None def isSet: Boolean = value.isDefined def getValue: Option[T] = value diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index 8f7dd87fbea3..d2f12a5e863d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.rest import com.fasterxml.jackson.annotation._ import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility import com.fasterxml.jackson.annotation.JsonInclude.Include -import com.fasterxml.jackson.databind.ObjectMapper +import com.fasterxml.jackson.databind.{ObjectMapper, SerializationFeature} import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ @@ -42,7 +42,7 @@ abstract class SubmitRestProtocolMessage { val messageType = Utils.getFormattedClassName(this) protected val action: String = messageType protected val sparkVersion: SubmitRestProtocolField[String] - protected val message = new SubmitRestProtocolField[String]("message") + protected val message = new SubmitRestProtocolField[String] // Required for JSON de/serialization and not explicitly used private def getAction: String = action @@ -64,7 +64,8 @@ abstract class SubmitRestProtocolMessage { def toJson: String = { validate() val mapper = new ObjectMapper - pretty(parse(mapper.writeValueAsString(this))) + mapper.enable(SerializationFeature.INDENT_OUTPUT) + mapper.writeValueAsString(this) } /** @@ -85,14 +86,13 @@ abstract class SubmitRestProtocolMessage { if (action == null) { throw new SubmitRestMissingFieldException(s"The action field is missing in $messageType") } - assertFieldIsSet(sparkVersion) } /** Assert that the specified field is set in this message. */ - protected def assertFieldIsSet(field: SubmitRestProtocolField[_]): Unit = { + protected def assertFieldIsSet(field: SubmitRestProtocolField[_], name: String): Unit = { if (!field.isSet) { throw new SubmitRestMissingFieldException( - s"Field '${field.name}' is missing in message $messageType.") + s"Field '$name' is missing in message $messageType.") } } @@ -143,19 +143,23 @@ abstract class SubmitRestProtocolMessage { * An abstract request sent from the client in the REST application submission protocol. 
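* Concrete requests in this protocol include [[SubmitDriverRequest]], [[KillDriverRequest]], and [[DriverStatusRequest]].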
*/ abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage { - protected override val sparkVersion = new SubmitRestProtocolField[String]("client_spark_version") + protected override val sparkVersion = new SubmitRestProtocolField[String] def getClientSparkVersion: String = sparkVersion.toString def setClientSparkVersion(s: String): this.type = setField(sparkVersion, s) override def getSparkVersion: String = getClientSparkVersion override def setSparkVersion(s: String) = setClientSparkVersion(s) + protected override def doValidate(): Unit = { + super.doValidate() + assertFieldIsSet(sparkVersion, "clientSparkVersion") + } } /** * An abstract response sent from the server in the REST application submission protocol. */ abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { - protected override val sparkVersion = new SubmitRestProtocolField[String]("server_spark_version") - private val success = new SubmitRestProtocolField[Boolean]("success") + protected override val sparkVersion = new SubmitRestProtocolField[String] + private val success = new SubmitRestProtocolField[Boolean] override def getSparkVersion: String = getServerSparkVersion def getServerSparkVersion: String = sparkVersion.toString @@ -167,7 +171,8 @@ abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { protected override def doValidate(): Unit = { super.doValidate() - assertFieldIsSet(success) + assertFieldIsSet(sparkVersion, "serverSparkVersion") + assertFieldIsSet(success, "success") } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala index 3984fb325083..2a147c77c9cd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala @@ -60,7 +60,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B test("kill empty driver") { val response = client.killDriver(masterRestUrl, "driver-that-does-not-exist") - val killResponse = getResponse[KillDriverResponse](response, client) + val killResponse = getKillResponse(response) val killSuccess = killResponse.getSuccess assert(killSuccess === "false") } @@ -71,11 +71,11 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B val size = 500 val driverId = submitApplication(resultsFile, numbers, size) val response = client.killDriver(masterRestUrl, driverId) - val killResponse = getResponse[KillDriverResponse](response, client) + val killResponse = getKillResponse(response) val killSuccess = killResponse.getSuccess waitUntilFinished(driverId) val response2 = client.requestDriverStatus(masterRestUrl, driverId) - val statusResponse = getResponse[DriverStatusResponse](response2, client) + val statusResponse = getStatusResponse(response2) val statusSuccess = statusResponse.getSuccess val driverState = statusResponse.getDriverState assert(killSuccess === "true") @@ -87,7 +87,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B test("request status for empty driver") { val response = client.requestDriverStatus(masterRestUrl, "driver-that-does-not-exist") - val statusResponse = getResponse[DriverStatusResponse](response, client) + val statusResponse = getStatusResponse(response) val statusSuccess = statusResponse.getSuccess assert(statusSuccess === "false") } @@ -130,7 +130,7 @@ class 
StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B val args = new SparkSubmitArguments(commandLineArgs) SparkSubmit.prepareSubmitEnvironment(args) val response = client.submitDriver(args) - val submitResponse = getResponse[SubmitDriverResponse](response, client) + val submitResponse = getSubmitResponse(response) submitResponse.getDriverId } @@ -140,7 +140,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B val expireTime = System.currentTimeMillis + maxSeconds * 1000 while (!finished) { val response = client.requestDriverStatus(masterRestUrl, driverId) - val statusResponse = getResponse[DriverStatusResponse](response, client) + val statusResponse = getStatusResponse(response) val driverState = statusResponse.getDriverState finished = driverState != DriverState.SUBMITTED.toString && @@ -151,17 +151,30 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B } } - /** Return the response as the expected type, or fail with an informative error message. */ - private def getResponse[T <: SubmitRestProtocolResponse]( - response: SubmitRestProtocolResponse, - client: StandaloneRestClient): T = { + /** Return the response as a submit driver response, or fail with error otherwise. */ + private def getSubmitResponse(response: SubmitRestProtocolResponse): SubmitDriverResponse = { response match { - case error: ErrorResponse => - fail(s"Error from the server:\n${error.getMessage}") - case _ => - client.getResponse[T](response).getOrElse { - fail(s"Response type was unexpected: ${response.toJson}") - } + case s: SubmitDriverResponse => s + case e: ErrorResponse => fail(s"Server returned error: ${e.toJson}") + case r => fail(s"Expected submit response. Actual: ${r.toJson}") + } + } + + /** Return the response as a kill driver response, or fail with error otherwise. */ + private def getKillResponse(response: SubmitRestProtocolResponse): KillDriverResponse = { + response match { + case k: KillDriverResponse => k + case e: ErrorResponse => fail(s"Server returned error: ${e.toJson}") + case r => fail(s"Expected kill response. Actual: ${r.toJson}") + } + } + + /** Return the response as a driver status response, or fail with error otherwise. */ + private def getStatusResponse(response: SubmitRestProtocolResponse): DriverStatusResponse = { + response match { + case s: DriverStatusResponse => s + case e: ErrorResponse => fail(s"Server returned error: ${e.toJson}") + case r => fail(s"Expected status response. 
Actual: ${r.toJson}") } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index 8c5cf6baee45..f00f7848befc 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -108,12 +108,15 @@ class SubmitRestProtocolSuite extends FunSuite { } test("response to and from JSON") { - val response = new DummyResponse().setSparkVersion("3.3.4") + val response = new DummyResponse() + .setSparkVersion("3.3.4") + .setSuccess("true") val json = response.toJson assertJsonEquals(json, dummyResponseJson) val newResponse = SubmitRestProtocolMessage.fromJson(json, classOf[DummyResponse]) assert(newResponse.getSparkVersion === "3.3.4") assert(newResponse.getServerSparkVersion === "3.3.4") + assert(newResponse.getSuccess === "true") assert(newResponse.getMessage === null) } @@ -300,7 +303,8 @@ class SubmitRestProtocolSuite extends FunSuite { """ |{ | "action" : "DummyResponse", - | "serverSparkVersion" : "3.3.4" + | "serverSparkVersion" : "3.3.4", + | "success": "true" |} """.stripMargin @@ -403,9 +407,9 @@ class SubmitRestProtocolSuite extends FunSuite { private class DummyResponse extends SubmitRestProtocolResponse private class DummyRequest extends SubmitRestProtocolRequest { - private val active = new SubmitRestProtocolField[Boolean]("active") - private val age = new SubmitRestProtocolField[Int]("age") - private val name = new SubmitRestProtocolField[String]("name") + private val active = new SubmitRestProtocolField[Boolean] + private val age = new SubmitRestProtocolField[Int] + private val name = new SubmitRestProtocolField[String] def getActive: String = active.toString def getAge: String = age.toString @@ -417,8 +421,8 @@ private class DummyRequest extends SubmitRestProtocolRequest { protected override def doValidate(): Unit = { super.doValidate() - assertFieldIsSet(name) - assertFieldIsSet(age) + assertFieldIsSet(name, "name") + assertFieldIsSet(age, "age") assert(age.getValue.get > 5, "Not old enough!") } } From 1f1c03ffdc541cc665b27655c40f10c6e938a1db Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Sun, 1 Feb 2015 16:26:10 -0800 Subject: [PATCH 29/48] Use Jackson's DefaultScalaModule to simplify messages Instead of explicitly defining getters and setters in the messages, we let Jackson's scala module do the work. This simplifies the code for each message significantly, though at the expense of reducing the level of type safety for users who implement their own clients and servers. 
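To illustrate, here is a minimal, self-contained sketch of how jackson-module-scala handles a message declared with plain vars; the class and values below are invented for illustration and are not the actual Spark message classes:

import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule

// A toy message with public vars, analogous to the simplified request classes
class ToyStatusRequest {
  var driverId: String = null
  var clientSparkVersion: String = null
}

object ToyStatusRequestExample extends App {
  // Registering DefaultScalaModule teaches Jackson about Scala's var accessors,
  // so no hand-written getters and setters are needed
  val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
  val request = new ToyStatusRequest
  request.driverId = "driver-20150201-0001"
  request.clientSparkVersion = "1.3.0"
  val json = mapper.writeValueAsString(request)
  println(json) // prints something like {"driverId":"driver-20150201-0001","clientSparkVersion":"1.3.0"}
  // Deserialization works the same way in reverse
  val parsed = mapper.readValue(json, classOf[ToyStatusRequest])
  assert(parsed.driverId == request.driverId)
}

This is what allows each message class in this patch to shrink to bare var declarations plus its validation logic.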
--- core/pom.xml | 8 + .../deploy/rest/DriverStatusRequest.scala | 4 +- .../deploy/rest/DriverStatusResponse.scala | 19 +- .../spark/deploy/rest/ErrorResponse.scala | 9 +- .../spark/deploy/rest/KillDriverRequest.scala | 4 +- .../deploy/rest/KillDriverResponse.scala | 4 +- .../deploy/rest/StandaloneRestClient.scala | 71 ++-- .../deploy/rest/StandaloneRestServer.scala | 75 ++-- .../deploy/rest/SubmitDriverRequest.scala | 134 ++----- .../deploy/rest/SubmitDriverResponse.scala | 4 +- .../spark/deploy/rest/SubmitRestClient.scala | 2 +- .../deploy/rest/SubmitRestProtocolField.scala | 30 -- .../rest/SubmitRestProtocolMessage.scala | 117 +++--- .../spark/deploy/rest/SubmitRestServer.scala | 7 +- .../rest/StandaloneRestProtocolSuite.scala | 22 +- .../deploy/rest/SubmitRestProtocolSuite.scala | 334 ++++++++---------- pom.xml | 11 + 17 files changed, 360 insertions(+), 495 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala diff --git a/core/pom.xml b/core/pom.xml index 31e919a1c831..d4f2e94b5a14 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -213,6 +213,14 @@ <groupId>com.codahale.metrics</groupId> <artifactId>metrics-graphite</artifactId> </dependency> + <dependency> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-databind</artifactId> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.module</groupId> + <artifactId>jackson-module-scala_2.10</artifactId> + </dependency> <dependency> <groupId>org.apache.derby</groupId> <artifactId>derby</artifactId> diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala index e25bb45668e5..0c15816b4be2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala @@ -21,9 +21,7 @@ package org.apache.spark.deploy.rest * A request to query the status of a driver in the REST application submission protocol. */ class DriverStatusRequest extends SubmitRestProtocolRequest { - private val driverId = new SubmitRestProtocolField[String] - def getDriverId: String = driverId.toString - def setDriverId(s: String): this.type = setField(driverId, s) + var driverId: String = null protected override def doValidate(): Unit = { super.doValidate() assertFieldIsSet(driverId, "driverId") diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala index 568204ac6c81..97b9d02f1e9f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala @@ -21,21 +21,12 @@ package org.apache.spark.deploy.rest * A response to the [[DriverStatusRequest]] in the REST application submission protocol.
*/ class DriverStatusResponse extends SubmitRestProtocolResponse { - private val driverId = new SubmitRestProtocolField[String] - // standalone cluster mode only - private val driverState = new SubmitRestProtocolField[String] - private val workerId = new SubmitRestProtocolField[String] - private val workerHostPort = new SubmitRestProtocolField[String] - - def getDriverId: String = driverId.toString - def getDriverState: String = driverState.toString - def getWorkerId: String = workerId.toString - def getWorkerHostPort: String = workerHostPort.toString + var driverId: String = null - def setDriverId(s: String): this.type = setField(driverId, s) - def setDriverState(s: String): this.type = setField(driverState, s) - def setWorkerId(s: String): this.type = setField(workerId, s) - def setWorkerHostPort(s: String): this.type = setField(workerHostPort, s) + // standalone cluster mode only + var driverState: String = null + var workerId: String = null + var workerHostPort: String = null protected override def doValidate(): Unit = { super.doValidate() diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala index 0bc003f97ab5..6bb674dc88a3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala @@ -17,20 +17,17 @@ package org.apache.spark.deploy.rest -import com.fasterxml.jackson.annotation.JsonIgnore - /** * An error response message used in the REST application submission protocol. */ class ErrorResponse extends SubmitRestProtocolResponse { - setSuccess("false") - // Don't bother logging success = false in the JSON - @JsonIgnore - override def getSuccess: String = super.getSuccess + // request was unsuccessful + success = "false" protected override def doValidate(): Unit = { super.doValidate() assertFieldIsSet(message, "message") + assert(!success.toBoolean, s"The 'success' field must be false in $messageType.") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala index 99f52a083954..7660864fbbf6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala @@ -21,9 +21,7 @@ package org.apache.spark.deploy.rest * A request to kill a driver in the REST application submission protocol. */ class KillDriverRequest extends SubmitRestProtocolRequest { - private val driverId = new SubmitRestProtocolField[String] - def getDriverId: String = driverId.toString - def setDriverId(s: String): this.type = setField(driverId, s) + var driverId: String = null protected override def doValidate(): Unit = { super.doValidate() assertFieldIsSet(driverId, "driverId") diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala index cb0112653b2f..1366d6ba77c9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala @@ -21,9 +21,7 @@ package org.apache.spark.deploy.rest * A response to the [[KillDriverRequest]] in the REST application submission protocol. 
*/ class KillDriverResponse extends SubmitRestProtocolResponse { - private val driverId = new SubmitRestProtocolField[String] - def getDriverId: String = driverId.toString - def setDriverId(s: String): this.type = setField(driverId, s) + var driverId: String = null protected override def doValidate(): Unit = { super.doValidate() assertFieldIsSet(driverId, "driverId") diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala index eb671978e2af..df7319235c65 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala @@ -43,14 +43,19 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { case s: SubmitDriverResponse => s case _ => return response } - val submitSuccess = submitResponse.getSuccess.toBoolean + // Report status of submitted driver to user + val submitSuccess = submitResponse.success.toBoolean if (submitSuccess) { - val driverId = submitResponse.getDriverId - logInfo(s"Driver successfully submitted as $driverId. Polling driver state...") - pollSubmittedDriverStatus(args.master, driverId) + val driverId = submitResponse.driverId + if (driverId != null) { + logInfo(s"Driver successfully submitted as $driverId. Polling driver state...") + pollSubmittedDriverStatus(args.master, driverId) + } else { + logError("Application successfully submitted, but driver ID was not provided!") + } } else { - val submitMessage = submitResponse.getMessage - logError(s"Application submission failed: $submitMessage") + val failMessage = Option(submitResponse.message).map { ": " + _ }.getOrElse("") + logError("Application submission failed" + failMessage) } submitResponse } @@ -78,12 +83,12 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { case s: DriverStatusResponse => s case _ => return } - val statusSuccess = statusResponse.getSuccess.toBoolean + val statusSuccess = statusResponse.success.toBoolean if (statusSuccess) { - val driverState = Option(statusResponse.getDriverState) - val workerId = Option(statusResponse.getWorkerId) - val workerHostPort = Option(statusResponse.getWorkerHostPort) - val exception = Option(statusResponse.getMessage) + val driverState = Option(statusResponse.driverState) + val workerId = Option(statusResponse.workerId) + val workerHostPort = Option(statusResponse.workerHostPort) + val exception = Option(statusResponse.message) // Log driver state, if present driverState match { case Some(state) => logInfo(s"State of driver $driverId is now $state.") @@ -105,21 +110,21 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { /** Construct a submit driver request message. 
*/ protected override def constructSubmitRequest(args: SparkSubmitArguments): SubmitDriverRequest = { - val message = new SubmitDriverRequest() - .setSparkVersion(sparkVersion) - .setAppName(args.name) - .setAppResource(args.primaryResource) - .setMainClass(args.mainClass) - .setJars(args.jars) - .setFiles(args.files) - .setDriverMemory(args.driverMemory) - .setDriverCores(args.driverCores) - .setDriverExtraJavaOptions(args.driverExtraJavaOptions) - .setDriverExtraClassPath(args.driverExtraClassPath) - .setDriverExtraLibraryPath(args.driverExtraLibraryPath) - .setSuperviseDriver(args.supervise.toString) - .setExecutorMemory(args.executorMemory) - .setTotalExecutorCores(args.totalExecutorCores) + val message = new SubmitDriverRequest + message.clientSparkVersion = sparkVersion + message.appName = args.name + message.appResource = args.primaryResource + message.mainClass = args.mainClass + message.jars = args.jars + message.files = args.files + message.driverMemory = args.driverMemory + message.driverCores = args.driverCores + message.driverExtraJavaOptions = args.driverExtraJavaOptions + message.driverExtraClassPath = args.driverExtraClassPath + message.driverExtraLibraryPath = args.driverExtraLibraryPath + message.superviseDriver = args.supervise.toString + message.executorMemory = args.executorMemory + message.totalExecutorCores = args.totalExecutorCores args.childArgs.foreach(message.addAppArg) args.sparkProperties.foreach { case (k, v) => message.setSparkProperty(k, v) } sys.env.foreach { case (k, v) => @@ -132,18 +137,20 @@ private[spark] class StandaloneRestClient extends SubmitRestClient { protected override def constructKillRequest( master: String, driverId: String): KillDriverRequest = { - new KillDriverRequest() - .setSparkVersion(sparkVersion) - .setDriverId(driverId) + val k = new KillDriverRequest + k.clientSparkVersion = sparkVersion + k.driverId = driverId + k } /** Construct a driver status request message. */ protected override def constructStatusRequest( master: String, driverId: String): DriverStatusRequest = { - new DriverStatusRequest() - .setSparkVersion(sparkVersion) - .setDriverId(driverId) + val d = new DriverStatusRequest + d.clientSparkVersion = sparkVersion + d.driverId = driverId + d } /** Extract the URL portion of the master address. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 802933bf3f2a..de0f701fd325 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -58,39 +58,42 @@ private[spark] class StandaloneRestServerHandler( val driverDescription = buildDriverDescription(request) val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) - new SubmitDriverResponse() - .setSparkVersion(sparkVersion) - .setMessage(response.message) - .setSuccess(response.success.toString) - .setDriverId(response.driverId.orNull) + val s = new SubmitDriverResponse + s.serverSparkVersion = sparkVersion + s.message = response.message + s.success = response.success.toString + s.driverId = response.driverId.orNull + s } /** Handle a request to kill a driver. 
*/ protected override def handleKill(request: KillDriverRequest): KillDriverResponse = { - val driverId = request.getDriverId + val driverId = request.driverId val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( DeployMessages.RequestKillDriver(driverId), masterActor, askTimeout) - new KillDriverResponse() - .setSparkVersion(sparkVersion) - .setMessage(response.message) - .setDriverId(driverId) - .setSuccess(response.success.toString) + val k = new KillDriverResponse + k.serverSparkVersion = sparkVersion + k.message = response.message + k.driverId = driverId + k.success = response.success.toString + k } /** Handle a request for a driver's status. */ protected override def handleStatus(request: DriverStatusRequest): DriverStatusResponse = { - val driverId = request.getDriverId + val driverId = request.driverId val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( DeployMessages.RequestDriverStatus(driverId), masterActor, askTimeout) val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } - new DriverStatusResponse() - .setSparkVersion(sparkVersion) - .setDriverId(driverId) - .setSuccess(response.found.toString) - .setDriverState(response.state.map(_.toString).orNull) - .setWorkerId(response.workerId.orNull) - .setWorkerHostPort(response.workerHostPort.orNull) - .setMessage(message.orNull) + val d = new DriverStatusResponse + d.serverSparkVersion = sparkVersion + d.driverId = driverId + d.success = response.found.toString + d.driverState = response.state.map(_.toString).orNull + d.workerId = response.workerId.orNull + d.workerHostPort = response.workerHostPort.orNull + d.message = message.orNull + d } /** @@ -101,27 +104,27 @@ private[spark] class StandaloneRestServerHandler( */ private def buildDriverDescription(request: SubmitDriverRequest): DriverDescription = { // Required fields, including the main class because python is not yet supported - val appName = request.getAppName - val appResource = request.getAppResource - val mainClass = request.getMainClass + val appName = request.appName + val appResource = request.appResource + val mainClass = request.mainClass if (mainClass == null) { throw new SubmitRestMissingFieldException("Main class must be set in submit request.") } // Optional fields - val jars = Option(request.getJars) - val files = Option(request.getFiles) - val driverMemory = Option(request.getDriverMemory) - val driverCores = Option(request.getDriverCores) - val driverExtraJavaOptions = Option(request.getDriverExtraJavaOptions) - val driverExtraClassPath = Option(request.getDriverExtraClassPath) - val driverExtraLibraryPath = Option(request.getDriverExtraLibraryPath) - val superviseDriver = Option(request.getSuperviseDriver) - val executorMemory = Option(request.getExecutorMemory) - val totalExecutorCores = Option(request.getTotalExecutorCores) - val appArgs = request.getAppArgs - val sparkProperties = request.getSparkProperties - val environmentVariables = request.getEnvironmentVariables + val jars = Option(request.jars) + val files = Option(request.files) + val driverMemory = Option(request.driverMemory) + val driverCores = Option(request.driverCores) + val driverExtraJavaOptions = Option(request.driverExtraJavaOptions) + val driverExtraClassPath = Option(request.driverExtraClassPath) + val driverExtraLibraryPath = Option(request.driverExtraLibraryPath) + val superviseDriver = Option(request.superviseDriver) + val executorMemory = Option(request.executorMemory) + val totalExecutorCores = 
Option(request.totalExecutorCores) + val appArgs = request.appArgs + val sparkProperties = request.sparkProperties + val environmentVariables = request.environmentVariables // Translate all fields to the relevant Spark properties val conf = new SparkConf(false) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala index c5daa616860e..d6bc050285b9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala @@ -20,127 +20,51 @@ package org.apache.spark.deploy.rest import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import com.fasterxml.jackson.annotation.{JsonIgnore, JsonProperty} -import org.json4s.jackson.JsonMethods._ - -import org.apache.spark.util.JsonProtocol +import com.fasterxml.jackson.annotation.{JsonProperty, JsonIgnore, JsonInclude} /** * A request to submit a driver in the REST application submission protocol. */ class SubmitDriverRequest extends SubmitRestProtocolRequest { - private val appName = new SubmitRestProtocolField[String] - private val appResource = new SubmitRestProtocolField[String] - private val mainClass = new SubmitRestProtocolField[String] - private val jars = new SubmitRestProtocolField[String] - private val files = new SubmitRestProtocolField[String] - private val pyFiles = new SubmitRestProtocolField[String] - private val driverMemory = new SubmitRestProtocolField[String] - private val driverCores = new SubmitRestProtocolField[Int] - private val driverExtraJavaOptions = new SubmitRestProtocolField[String] - private val driverExtraClassPath = new SubmitRestProtocolField[String] - private val driverExtraLibraryPath = new SubmitRestProtocolField[String] - private val superviseDriver = new SubmitRestProtocolField[Boolean] - private val executorMemory = new SubmitRestProtocolField[String] - private val totalExecutorCores = new SubmitRestProtocolField[Int] + var appName: String = null + var appResource: String = null + var mainClass: String = null + var jars: String = null + var files: String = null + var pyFiles: String = null + var driverMemory: String = null + var driverCores: String = null + var driverExtraJavaOptions: String = null + var driverExtraClassPath: String = null + var driverExtraLibraryPath: String = null + var superviseDriver: String = null + var executorMemory: String = null + var totalExecutorCores: String = null // Special fields - private val appArgs = new ArrayBuffer[String] - private val sparkProperties = new mutable.HashMap[String, String] - private val envVars = new mutable.HashMap[String, String] - - def getAppName: String = appName.toString - def getAppResource: String = appResource.toString - def getMainClass: String = mainClass.toString - def getJars: String = jars.toString - def getFiles: String = files.toString - def getPyFiles: String = pyFiles.toString - def getDriverMemory: String = driverMemory.toString - def getDriverCores: String = driverCores.toString - def getDriverExtraJavaOptions: String = driverExtraJavaOptions.toString - def getDriverExtraClassPath: String = driverExtraClassPath.toString - def getDriverExtraLibraryPath: String = driverExtraLibraryPath.toString - def getSuperviseDriver: String = superviseDriver.toString - def getExecutorMemory: String = executorMemory.toString - def getTotalExecutorCores: String = totalExecutorCores.toString - - // Special getters required for JSON 
serialization @JsonProperty("appArgs") - private def getAppArgsJson: String = arrayToJson(getAppArgs) + private val _appArgs = new ArrayBuffer[String] @JsonProperty("sparkProperties") - private def getSparkPropertiesJson: String = mapToJson(getSparkProperties) + private val _sparkProperties = new mutable.HashMap[String, String] @JsonProperty("environmentVariables") - private def getEnvironmentVariablesJson: String = mapToJson(getEnvironmentVariables) + private val _envVars = new mutable.HashMap[String, String] - def setAppName(s: String): this.type = setField(appName, s) - def setAppResource(s: String): this.type = setField(appResource, s) - def setMainClass(s: String): this.type = setField(mainClass, s) - def setJars(s: String): this.type = setField(jars, s) - def setFiles(s: String): this.type = setField(files, s) - def setPyFiles(s: String): this.type = setField(pyFiles, s) - def setDriverMemory(s: String): this.type = setField(driverMemory, s) - def setDriverCores(s: String): this.type = setNumericField(driverCores, s) - def setDriverExtraJavaOptions(s: String): this.type = setField(driverExtraJavaOptions, s) - def setDriverExtraClassPath(s: String): this.type = setField(driverExtraClassPath, s) - def setDriverExtraLibraryPath(s: String): this.type = setField(driverExtraLibraryPath, s) - def setSuperviseDriver(s: String): this.type = setBooleanField(superviseDriver, s) - def setExecutorMemory(s: String): this.type = setField(executorMemory, s) - def setTotalExecutorCores(s: String): this.type = setNumericField(totalExecutorCores, s) - - // Special setters required for JSON deserialization - @JsonProperty("appArgs") - private def setAppArgsJson(s: String): Unit = { - appArgs.clear() - appArgs ++= JsonProtocol.arrayFromJson(parse(s)) - } - @JsonProperty("sparkProperties") - private def setSparkPropertiesJson(s: String): Unit = { - sparkProperties.clear() - sparkProperties ++= JsonProtocol.mapFromJson(parse(s)) - } - @JsonProperty("environmentVariables") - private def setEnvironmentVariablesJson(s: String): Unit = { - envVars.clear() - envVars ++= JsonProtocol.mapFromJson(parse(s)) - } + def appArgs: Array[String] = _appArgs.toArray + def sparkProperties: Map[String, String] = _sparkProperties.toMap + def environmentVariables: Map[String, String] = _envVars.toMap - /** Return an array of arguments to be passed to the application. */ - @JsonIgnore - def getAppArgs: Array[String] = appArgs.toArray - - /** Return a map of Spark properties to be passed to the application as java options. */ - @JsonIgnore - def getSparkProperties: Map[String, String] = sparkProperties.toMap - - /** Return a map of environment variables to be passed to the application. */ - @JsonIgnore - def getEnvironmentVariables: Map[String, String] = envVars.toMap - - /** Add a command line argument to be passed to the application. */ - @JsonIgnore - def addAppArg(s: String): this.type = { appArgs += s; this } - - /** Set a Spark property to be passed to the application as a java option. */ - @JsonIgnore - def setSparkProperty(k: String, v: String): this.type = { sparkProperties(k) = v; this } - - /** Set an environment variable to be passed to the application. */ - @JsonIgnore - def setEnvironmentVariable(k: String, v: String): this.type = { envVars(k) = v; this } - - /** Serialize the given Array to a compact JSON string. */ - private def arrayToJson(arr: Array[String]): String = { - if (arr.nonEmpty) { compact(render(JsonProtocol.arrayToJson(arr))) } else null - } - - /** Serialize the given Map to a compact JSON string. 
*/ - private def mapToJson(map: Map[String, String]): String = { - if (map.nonEmpty) { compact(render(JsonProtocol.mapToJson(map))) } else null - } + def addAppArg(s: String): this.type = { _appArgs += s; this } + def setSparkProperty(k: String, v: String): this.type = { _sparkProperties(k) = v; this } + def setEnvironmentVariable(k: String, v: String): this.type = { _envVars(k) = v; this } protected override def doValidate(): Unit = { super.doValidate() assertFieldIsSet(appName, "appName") assertFieldIsSet(appResource, "appResource") + assertFieldIsMemory(driverMemory, "driverMemory") + assertFieldIsNumeric(driverCores, "driverCores") + assertFieldIsBoolean(superviseDriver, "superviseDriver") + assertFieldIsMemory(executorMemory, "executorMemory") + assertFieldIsNumeric(totalExecutorCores, "totalExecutorCores") } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala index 1ac769f3110a..d2b60aac2f74 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala @@ -21,7 +21,5 @@ package org.apache.spark.deploy.rest * A response to the [[SubmitDriverRequest]] in the REST application submission protocol. */ class SubmitDriverResponse extends SubmitRestProtocolResponse { - private val driverId = new SubmitRestProtocolField[String] - def getDriverId: String = driverId.toString - def setDriverId(s: String): this.type = setField(driverId, s) + var driverId: String = null } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala index ea056351bcf8..9af29a41e228 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala @@ -100,7 +100,7 @@ private[spark] abstract class SubmitRestClient extends Logging { try { response.validate() response match { - case e: ErrorResponse => logError(s"Server responded with error:\n${e.getMessage}") + case e: ErrorResponse => logError(s"Server responded with error:\n${e.message}") case _ => } } catch { diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala deleted file mode 100644 index 3e7208c4d8c3..000000000000 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.rest - -/** - * A field used in [[SubmitRestProtocolMessage]]s. 
- */ -class SubmitRestProtocolField[T] { - protected var value: Option[T] = None - def isSet: Boolean = value.isDefined - def getValue: Option[T] = value - def setValue(v: T): Unit = { value = Some(v) } - def clearValue(): Unit = { value = None } - override def toString: String = value.map(_.toString).orNull -} diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala index d2f12a5e863d..82d7f86ecb90 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala @@ -17,10 +17,13 @@ package org.apache.spark.deploy.rest +import scala.util.Try + import com.fasterxml.jackson.annotation._ import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility import com.fasterxml.jackson.annotation.JsonInclude.Include import com.fasterxml.jackson.databind.{ObjectMapper, SerializationFeature} +import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ @@ -39,23 +42,14 @@ import org.apache.spark.util.Utils @JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY) @JsonPropertyOrder(alphabetic = true) abstract class SubmitRestProtocolMessage { + @JsonIgnore val messageType = Utils.getFormattedClassName(this) - protected val action: String = messageType - protected val sparkVersion: SubmitRestProtocolField[String] - protected val message = new SubmitRestProtocolField[String] - // Required for JSON de/serialization and not explicitly used - private def getAction: String = action - private def setAction(s: String): this.type = this + val action: String = messageType + var message: String = null - // Intended for the user and not for JSON de/serialization, which expects more specific keys - @JsonIgnore - def getSparkVersion: String - @JsonIgnore - def setSparkVersion(s: String): this.type - - def getMessage: String = message.toString - def setMessage(s: String): this.type = setField(message, s) + // For JSON deserialization + private def setAction(a: String): Unit = { } /** * Serialize the message to JSON. @@ -63,9 +57,7 @@ abstract class SubmitRestProtocolMessage { */ def toJson: String = { validate() - val mapper = new ObjectMapper - mapper.enable(SerializationFeature.INDENT_OUTPUT) - mapper.writeValueAsString(this) + SubmitRestProtocolMessage.mapper.writeValueAsString(this) } /** @@ -89,53 +81,52 @@ abstract class SubmitRestProtocolMessage { } /** Assert that the specified field is set in this message. */ - protected def assertFieldIsSet(field: SubmitRestProtocolField[_], name: String): Unit = { - if (!field.isSet) { + protected def assertFieldIsSet(value: String, name: String): Unit = { + if (value == null) { throw new SubmitRestMissingFieldException( s"Field '$name' is missing in message $messageType.") } } - /** - * Assert a condition when validating this message. - * If the assertion fails, throw a [[SubmitRestProtocolException]]. - */ - protected def assert(condition: Boolean, failMessage: String): Unit = { - if (!condition) { throw new SubmitRestProtocolException(failMessage) } - } - - /** Set the field to the given value, or clear the field if the value is null. 
*/ - protected def setField(f: SubmitRestProtocolField[String], v: String): this.type = { - if (v == null) { f.clearValue() } else { f.setValue(v) } - this + /** Assert that the value of the specified field is a boolean. */ + protected def assertFieldIsBoolean(value: String, name: String): Unit = { + if (value != null) { + Try(value.toBoolean).getOrElse { + throw new SubmitRestProtocolException( + s"Field '$name' expected boolean value: actual was '$value'.") + } + } } - /** - * Set the field to the given boolean value, or clear the field if the value is null. - * If the provided value does not represent a boolean, throw an exception. - */ - protected def setBooleanField(f: SubmitRestProtocolField[Boolean], v: String): this.type = { - if (v == null) { f.clearValue() } else { f.setValue(v.toBoolean) } - this + /** Assert that the value of the specified field is a numeric. */ + protected def assertFieldIsNumeric(value: String, name: String): Unit = { + if (value != null) { + Try(value.toInt).getOrElse { + throw new SubmitRestProtocolException( + s"Field '$name' expected numeric value: actual was '$value'.") + } + } } /** - * Set the field to the given numeric value, or clear the field if the value is null. - * If the provided value does not represent a numeric, throw an exception. + * Assert that the value of the specified field is a memory string. + * Examples of valid memory strings include 3g, 512m, 128k, 4096. */ - protected def setNumericField(f: SubmitRestProtocolField[Int], v: String): this.type = { - if (v == null) { f.clearValue() } else { f.setValue(v.toInt) } - this + protected def assertFieldIsMemory(value: String, name: String): Unit = { + if (value != null) { + Try(Utils.memoryStringToMb(value)).getOrElse { + throw new SubmitRestProtocolException( + s"Field '$name' expected memory value: actual was '$value'.") + } + } } /** - * Set the field to the given memory value, or clear the field if the value is null. - * If the provided value does not represent a memory value, throw an exception. - * Valid examples of memory values include "512m", "24g", and "128000". + * Assert a condition when validating this message. + * If the assertion fails, throw a [[SubmitRestProtocolException]]. */ - protected def setMemoryField(f: SubmitRestProtocolField[String], v: String): this.type = { - Utils.memoryStringToMb(v) - setField(f, v) + protected def assert(condition: Boolean, failMessage: String): Unit = { + if (!condition) { throw new SubmitRestProtocolException(failMessage) } } } @@ -143,14 +134,10 @@ abstract class SubmitRestProtocolMessage { * An abstract request sent from the client in the REST application submission protocol. */ abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage { - protected override val sparkVersion = new SubmitRestProtocolField[String] - def getClientSparkVersion: String = sparkVersion.toString - def setClientSparkVersion(s: String): this.type = setField(sparkVersion, s) - override def getSparkVersion: String = getClientSparkVersion - override def setSparkVersion(s: String) = setClientSparkVersion(s) + var clientSparkVersion: String = null protected override def doValidate(): Unit = { super.doValidate() - assertFieldIsSet(sparkVersion, "clientSparkVersion") + assertFieldIsSet(clientSparkVersion, "clientSparkVersion") } } @@ -158,27 +145,21 @@ abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage { * An abstract response sent from the server in the REST application submission protocol. 
*/ abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage { - protected override val sparkVersion = new SubmitRestProtocolField[String] - private val success = new SubmitRestProtocolField[Boolean] - - override def getSparkVersion: String = getServerSparkVersion - def getServerSparkVersion: String = sparkVersion.toString - def getSuccess: String = success.toString - - override def setSparkVersion(s: String) = setServerSparkVersion(s) - def setServerSparkVersion(s: String): this.type = setField(sparkVersion, s) - def setSuccess(s: String): this.type = setBooleanField(success, s) - + var serverSparkVersion: String = null + var success: String = null protected override def doValidate(): Unit = { super.doValidate() - assertFieldIsSet(sparkVersion, "serverSparkVersion") + assertFieldIsSet(serverSparkVersion, "serverSparkVersion") assertFieldIsSet(success, "success") + assertFieldIsBoolean(success, "success") } } object SubmitRestProtocolMessage { - private val mapper = new ObjectMapper private val packagePrefix = this.getClass.getPackage.getName + private val mapper = new ObjectMapper() + .registerModule(DefaultScalaModule) + .enable(SerializationFeature.INDENT_OUTPUT) /** * Parse the value of the action field from the given JSON. diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala index 027f641aa8f9..6c2a3a159da8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestServer.scala @@ -137,9 +137,10 @@ private[spark] abstract class SubmitRestServerHandler extends AbstractHandler wi /** Construct an error message to signal the fact that an exception has been thrown. */ private def handleError(message: String): ErrorResponse = { - new ErrorResponse() - .setSparkVersion(sparkVersion) - .setMessage(message) + val e = new ErrorResponse + e.serverSparkVersion = sparkVersion + e.message = message + e } /** Return a human readable String representation of the exception. 
*/ diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala index 2a147c77c9cd..fa994118883f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestProtocolSuite.scala @@ -61,7 +61,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B test("kill empty driver") { val response = client.killDriver(masterRestUrl, "driver-that-does-not-exist") val killResponse = getKillResponse(response) - val killSuccess = killResponse.getSuccess + val killSuccess = killResponse.success assert(killSuccess === "false") } @@ -72,12 +72,12 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B val driverId = submitApplication(resultsFile, numbers, size) val response = client.killDriver(masterRestUrl, driverId) val killResponse = getKillResponse(response) - val killSuccess = killResponse.getSuccess + val killSuccess = killResponse.success waitUntilFinished(driverId) val response2 = client.requestDriverStatus(masterRestUrl, driverId) val statusResponse = getStatusResponse(response2) - val statusSuccess = statusResponse.getSuccess - val driverState = statusResponse.getDriverState + val statusSuccess = statusResponse.success + val driverState = statusResponse.driverState assert(killSuccess === "true") assert(statusSuccess === "true") assert(driverState === DriverState.KILLED.toString) @@ -88,7 +88,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B test("request status for empty driver") { val response = client.requestDriverStatus(masterRestUrl, "driver-that-does-not-exist") val statusResponse = getStatusResponse(response) - val statusSuccess = statusResponse.getSuccess + val statusSuccess = statusResponse.success assert(statusSuccess === "false") } @@ -131,7 +131,9 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B SparkSubmit.prepareSubmitEnvironment(args) val response = client.submitDriver(args) val submitResponse = getSubmitResponse(response) - submitResponse.getDriverId + val driverId = submitResponse.driverId + assert(driverId != null, "Application submission was unsuccessful!") + driverId } /** Wait until the given driver has finished running up to the specified timeout. */ @@ -141,7 +143,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B while (!finished) { val response = client.requestDriverStatus(masterRestUrl, driverId) val statusResponse = getStatusResponse(response) - val driverState = statusResponse.getDriverState + val driverState = statusResponse.driverState finished = driverState != DriverState.SUBMITTED.toString && driverState != DriverState.RUNNING.toString @@ -155,7 +157,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B private def getSubmitResponse(response: SubmitRestProtocolResponse): SubmitDriverResponse = { response match { case s: SubmitDriverResponse => s - case e: ErrorResponse => fail(s"Server returned error: ${e.toJson}") + case e: ErrorResponse => fail(s"Server returned error: ${e.message}") case r => fail(s"Expected submit response. 
Actual: ${r.toJson}") } } @@ -164,7 +166,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B private def getKillResponse(response: SubmitRestProtocolResponse): KillDriverResponse = { response match { case k: KillDriverResponse => k - case e: ErrorResponse => fail(s"Server returned error: ${e.toJson}") + case e: ErrorResponse => fail(s"Server returned error: ${e.message}") case r => fail(s"Expected kill response. Actual: ${r.toJson}") } } @@ -173,7 +175,7 @@ class StandaloneRestProtocolSuite extends FunSuite with BeforeAndAfterAll with B private def getStatusResponse(response: SubmitRestProtocolResponse): DriverStatusResponse = { response match { case s: DriverStatusResponse => s - case e: ErrorResponse => fail(s"Server returned error: ${e.toJson}") + case e: ErrorResponse => fail(s"Server returned error: ${e.message}") case r => fail(s"Expected status response. Actual: ${r.toJson}") } } diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala index f00f7848befc..97158a2cecda 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala @@ -25,124 +25,106 @@ import org.scalatest.FunSuite */ class SubmitRestProtocolSuite extends FunSuite { - test("get and set fields") { - val request = new DummyRequest - assert(request.getSparkVersion === null) - assert(request.getMessage === null) - assert(request.getActive === null) - assert(request.getAge === null) - assert(request.getName === null) - request.setSparkVersion("1.2.3") - request.setActive("true") - request.setAge("10") - request.setName("dolphin") - assert(request.getSparkVersion === "1.2.3") - assert(request.getMessage === null) - assert(request.getActive === "true") - assert(request.getAge === "10") - assert(request.getName === "dolphin") - // overwrite - request.setName("shark") - request.setActive("false") - assert(request.getName === "shark") - assert(request.getActive === "false") - } - - test("get and set fields with null values") { - val request = new DummyRequest - request.setSparkVersion(null) - request.setActive(null) - request.setAge(null) - request.setName(null) - request.setMessage(null) - assert(request.getSparkVersion === null) - assert(request.getMessage === null) - assert(request.getActive === null) - assert(request.getAge === null) - assert(request.getName === null) - } - - test("set fields with illegal argument") { - val request = new DummyRequest - intercept[IllegalArgumentException] { request.setActive("not-a-boolean") } - intercept[IllegalArgumentException] { request.setActive("150") } - intercept[IllegalArgumentException] { request.setAge("not-a-number") } - intercept[IllegalArgumentException] { request.setAge("true") } - } - test("validate") { val request = new DummyRequest intercept[SubmitRestProtocolException] { request.validate() } // missing everything - request.setSparkVersion("1.4.8") + request.clientSparkVersion = "1.2.3" intercept[SubmitRestProtocolException] { request.validate() } // missing name and age - request.setName("something") + request.name = "something" intercept[SubmitRestProtocolException] { request.validate() } // missing only age - request.setAge("2") + request.age = "2" intercept[SubmitRestProtocolException] { request.validate() } // age too low - request.setAge("10") - request.validate() // everything is set - 
request.setSparkVersion(null) + request.age = "10" + request.validate() // everything is set properly + request.clientSparkVersion = null intercept[SubmitRestProtocolException] { request.validate() } // missing only Spark version - request.setSparkVersion("1.2.3") - request.setName(null) + request.clientSparkVersion = "1.2.3" + request.name = null intercept[SubmitRestProtocolException] { request.validate() } // missing only name - request.setMessage("not-setting-name") + request.message = "not-setting-name" intercept[SubmitRestProtocolException] { request.validate() } // still missing name } + test("validate with illegal argument") { + val request = new DummyRequest + request.clientSparkVersion = "1.2.3" + request.name = "abc" + request.age = "not-a-number" + intercept[SubmitRestProtocolException] { request.validate() } + request.age = "true" + intercept[SubmitRestProtocolException] { request.validate() } + request.age = "150" + request.validate() + request.active = "not-a-boolean" + intercept[SubmitRestProtocolException] { request.validate() } + request.active = "150" + intercept[SubmitRestProtocolException] { request.validate() } + request.active = "true" + request.validate() + } + test("request to and from JSON") { - val request = new DummyRequest() - .setSparkVersion("1.2.3") - .setActive("true") - .setAge("25") - .setName("jung") + val request = new DummyRequest + intercept[SubmitRestProtocolException] { request.toJson } // implicit validation + request.clientSparkVersion = "1.2.3" + request.active = "true" + request.age = "25" + request.name = "jung" val json = request.toJson assertJsonEquals(json, dummyRequestJson) val newRequest = SubmitRestProtocolMessage.fromJson(json, classOf[DummyRequest]) - assert(newRequest.getSparkVersion === "1.2.3") - assert(newRequest.getClientSparkVersion === "1.2.3") - assert(newRequest.getActive === "true") - assert(newRequest.getAge === "25") - assert(newRequest.getName === "jung") - assert(newRequest.getMessage === null) + assert(newRequest.clientSparkVersion === "1.2.3") + assert(newRequest.clientSparkVersion === "1.2.3") + assert(newRequest.active === "true") + assert(newRequest.age === "25") + assert(newRequest.name === "jung") + assert(newRequest.message === null) } test("response to and from JSON") { - val response = new DummyResponse() - .setSparkVersion("3.3.4") - .setSuccess("true") + val response = new DummyResponse + response.serverSparkVersion = "3.3.4" + response.success = "true" val json = response.toJson assertJsonEquals(json, dummyResponseJson) val newResponse = SubmitRestProtocolMessage.fromJson(json, classOf[DummyResponse]) - assert(newResponse.getSparkVersion === "3.3.4") - assert(newResponse.getServerSparkVersion === "3.3.4") - assert(newResponse.getSuccess === "true") - assert(newResponse.getMessage === null) + assert(newResponse.serverSparkVersion === "3.3.4") + assert(newResponse.serverSparkVersion === "3.3.4") + assert(newResponse.success === "true") + assert(newResponse.message === null) } test("SubmitDriverRequest") { val message = new SubmitDriverRequest intercept[SubmitRestProtocolException] { message.validate() } - intercept[IllegalArgumentException] { message.setDriverCores("one hundred feet") } - intercept[IllegalArgumentException] { message.setSuperviseDriver("nope, never") } - intercept[IllegalArgumentException] { message.setTotalExecutorCores("two men") } - message.setSparkVersion("1.2.3") - message.setAppName("SparkPie") - message.setAppResource("honey-walnut-cherry.jar") + message.clientSparkVersion = "1.2.3" + 
message.appName = "SparkPie" + message.appResource = "honey-walnut-cherry.jar" message.validate() // optional fields - message.setMainClass("org.apache.spark.examples.SparkPie") - message.setJars("mayonnaise.jar,ketchup.jar") - message.setFiles("fireball.png") - message.setPyFiles("do-not-eat-my.py") - message.setDriverMemory("512m") - message.setDriverCores("180") - message.setDriverExtraJavaOptions(" -Dslices=5 -Dcolor=mostly_red") - message.setDriverExtraClassPath("food-coloring.jar") - message.setDriverExtraLibraryPath("pickle.jar") - message.setSuperviseDriver("false") - message.setExecutorMemory("256m") - message.setTotalExecutorCores("10000") + message.mainClass = "org.apache.spark.examples.SparkPie" + message.jars = "mayonnaise.jar,ketchup.jar" + message.files = "fireball.png" + message.pyFiles = "do-not-eat-my.py" + message.driverMemory = "512m" + message.driverCores = "180" + message.driverExtraJavaOptions = " -Dslices=5 -Dcolor=mostly_red" + message.driverExtraClassPath = "food-coloring.jar" + message.driverExtraLibraryPath = "pickle.jar" + message.superviseDriver = "false" + message.executorMemory = "256m" + message.totalExecutorCores = "10000" + message.validate() + // bad fields + message.driverCores = "one hundred feet" + intercept[SubmitRestProtocolException] { message.validate() } + message.driverCores = "180" + message.superviseDriver = "nope, never" + intercept[SubmitRestProtocolException] { message.validate() } + message.superviseDriver = "false" + message.totalExecutorCores = "two men" + intercept[SubmitRestProtocolException] { message.validate() } + message.totalExecutorCores = "10000" // special fields message.addAppArg("two slices") message.addAppArg("a hint of cinnamon") @@ -150,142 +132,144 @@ class SubmitRestProtocolSuite extends FunSuite { message.setSparkProperty("spark.shuffle.enabled", "false") message.setEnvironmentVariable("PATH", "/dev/null") message.setEnvironmentVariable("PYTHONPATH", "/dev/null") - assert(message.getAppArgs === Seq("two slices", "a hint of cinnamon")) - assert(message.getSparkProperties.size === 2) - assert(message.getSparkProperties("spark.live.long") === "true") - assert(message.getSparkProperties("spark.shuffle.enabled") === "false") - assert(message.getEnvironmentVariables.size === 2) - assert(message.getEnvironmentVariables("PATH") === "/dev/null") - assert(message.getEnvironmentVariables("PYTHONPATH") === "/dev/null") + assert(message.appArgs === Seq("two slices", "a hint of cinnamon")) + assert(message.sparkProperties.size === 2) + assert(message.sparkProperties("spark.live.long") === "true") + assert(message.sparkProperties("spark.shuffle.enabled") === "false") + assert(message.environmentVariables.size === 2) + assert(message.environmentVariables("PATH") === "/dev/null") + assert(message.environmentVariables("PYTHONPATH") === "/dev/null") // test JSON val json = message.toJson assertJsonEquals(json, submitDriverRequestJson) val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[SubmitDriverRequest]) - assert(newMessage.getSparkVersion === "1.2.3") - assert(newMessage.getClientSparkVersion === "1.2.3") - assert(newMessage.getAppName === "SparkPie") - assert(newMessage.getAppResource === "honey-walnut-cherry.jar") - assert(newMessage.getMainClass === "org.apache.spark.examples.SparkPie") - assert(newMessage.getJars === "mayonnaise.jar,ketchup.jar") - assert(newMessage.getFiles === "fireball.png") - assert(newMessage.getPyFiles === "do-not-eat-my.py") - assert(newMessage.getDriverMemory === "512m") - 
assert(newMessage.getDriverCores === "180") - assert(newMessage.getDriverExtraJavaOptions === " -Dslices=5 -Dcolor=mostly_red") - assert(newMessage.getDriverExtraClassPath === "food-coloring.jar") - assert(newMessage.getDriverExtraLibraryPath === "pickle.jar") - assert(newMessage.getSuperviseDriver === "false") - assert(newMessage.getExecutorMemory === "256m") - assert(newMessage.getTotalExecutorCores === "10000") - assert(newMessage.getAppArgs === message.getAppArgs) - assert(newMessage.getSparkProperties === message.getSparkProperties) - assert(newMessage.getEnvironmentVariables === message.getEnvironmentVariables) + assert(newMessage.clientSparkVersion === "1.2.3") + assert(newMessage.appName === "SparkPie") + assert(newMessage.appResource === "honey-walnut-cherry.jar") + assert(newMessage.mainClass === "org.apache.spark.examples.SparkPie") + assert(newMessage.jars === "mayonnaise.jar,ketchup.jar") + assert(newMessage.files === "fireball.png") + assert(newMessage.pyFiles === "do-not-eat-my.py") + assert(newMessage.driverMemory === "512m") + assert(newMessage.driverCores === "180") + assert(newMessage.driverExtraJavaOptions === " -Dslices=5 -Dcolor=mostly_red") + assert(newMessage.driverExtraClassPath === "food-coloring.jar") + assert(newMessage.driverExtraLibraryPath === "pickle.jar") + assert(newMessage.superviseDriver === "false") + assert(newMessage.executorMemory === "256m") + assert(newMessage.totalExecutorCores === "10000") + assert(newMessage.appArgs === message.appArgs) + assert(newMessage.sparkProperties === message.sparkProperties) + assert(newMessage.environmentVariables === message.environmentVariables) } test("SubmitDriverResponse") { val message = new SubmitDriverResponse intercept[SubmitRestProtocolException] { message.validate() } - intercept[IllegalArgumentException] { message.setSuccess("maybe not") } - message.setSparkVersion("1.2.3") - message.setDriverId("driver_123") - message.setSuccess("true") + message.serverSparkVersion = "1.2.3" + message.driverId = "driver_123" + message.success = "true" message.validate() + // bad fields + message.success = "maybe not" + intercept[SubmitRestProtocolException] { message.validate() } + message.success = "true" // test JSON val json = message.toJson assertJsonEquals(json, submitDriverResponseJson) val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[SubmitDriverResponse]) - assert(newMessage.getSparkVersion === "1.2.3") - assert(newMessage.getServerSparkVersion === "1.2.3") - assert(newMessage.getDriverId === "driver_123") - assert(newMessage.getSuccess === "true") + assert(newMessage.serverSparkVersion === "1.2.3") + assert(newMessage.driverId === "driver_123") + assert(newMessage.success === "true") } test("KillDriverRequest") { val message = new KillDriverRequest intercept[SubmitRestProtocolException] { message.validate() } - message.setSparkVersion("1.2.3") - message.setDriverId("driver_123") + message.clientSparkVersion = "1.2.3" + message.driverId = "driver_123" message.validate() // test JSON val json = message.toJson assertJsonEquals(json, killDriverRequestJson) val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[KillDriverRequest]) - assert(newMessage.getSparkVersion === "1.2.3") - assert(newMessage.getClientSparkVersion === "1.2.3") - assert(newMessage.getDriverId === "driver_123") + assert(newMessage.clientSparkVersion === "1.2.3") + assert(newMessage.driverId === "driver_123") } test("KillDriverResponse") { val message = new KillDriverResponse intercept[SubmitRestProtocolException] { 
message.validate() } - intercept[IllegalArgumentException] { message.setSuccess("maybe not") } - message.setSparkVersion("1.2.3") - message.setDriverId("driver_123") - message.setSuccess("true") + message.serverSparkVersion = "1.2.3" + message.driverId = "driver_123" + message.success = "true" message.validate() + // bad fields + message.success = "maybe not" + intercept[SubmitRestProtocolException] { message.validate() } + message.success = "true" // test JSON val json = message.toJson assertJsonEquals(json, killDriverResponseJson) val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[KillDriverResponse]) - assert(newMessage.getSparkVersion === "1.2.3") - assert(newMessage.getServerSparkVersion === "1.2.3") - assert(newMessage.getDriverId === "driver_123") - assert(newMessage.getSuccess === "true") + assert(newMessage.serverSparkVersion === "1.2.3") + assert(newMessage.driverId === "driver_123") + assert(newMessage.success === "true") } test("DriverStatusRequest") { val message = new DriverStatusRequest intercept[SubmitRestProtocolException] { message.validate() } - message.setSparkVersion("1.2.3") - message.setDriverId("driver_123") + message.clientSparkVersion = "1.2.3" + message.driverId = "driver_123" message.validate() // test JSON val json = message.toJson assertJsonEquals(json, driverStatusRequestJson) val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[DriverStatusRequest]) - assert(newMessage.getSparkVersion === "1.2.3") - assert(newMessage.getClientSparkVersion === "1.2.3") - assert(newMessage.getDriverId === "driver_123") + assert(newMessage.clientSparkVersion === "1.2.3") + assert(newMessage.driverId === "driver_123") } test("DriverStatusResponse") { val message = new DriverStatusResponse intercept[SubmitRestProtocolException] { message.validate() } - intercept[IllegalArgumentException] { message.setSuccess("maybe") } - message.setSparkVersion("1.2.3") - message.setDriverId("driver_123") - message.setSuccess("true") + message.serverSparkVersion = "1.2.3" + message.driverId = "driver_123" + message.success = "true" message.validate() // optional fields - message.setDriverState("RUNNING") - message.setWorkerId("worker_123") - message.setWorkerHostPort("1.2.3.4:7780") + message.driverState = "RUNNING" + message.workerId = "worker_123" + message.workerHostPort = "1.2.3.4:7780" + // bad fields + message.success = "maybe" + intercept[SubmitRestProtocolException] { message.validate() } + message.success = "true" // test JSON val json = message.toJson assertJsonEquals(json, driverStatusResponseJson) val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[DriverStatusResponse]) - assert(newMessage.getSparkVersion === "1.2.3") - assert(newMessage.getServerSparkVersion === "1.2.3") - assert(newMessage.getDriverId === "driver_123") - assert(newMessage.getDriverState === "RUNNING") - assert(newMessage.getSuccess === "true") - assert(newMessage.getWorkerId === "worker_123") - assert(newMessage.getWorkerHostPort === "1.2.3.4:7780") + assert(newMessage.serverSparkVersion === "1.2.3") + assert(newMessage.driverId === "driver_123") + assert(newMessage.driverState === "RUNNING") + assert(newMessage.success === "true") + assert(newMessage.workerId === "worker_123") + assert(newMessage.workerHostPort === "1.2.3.4:7780") } test("ErrorResponse") { val message = new ErrorResponse intercept[SubmitRestProtocolException] { message.validate() } - message.setSparkVersion("1.2.3") - message.setMessage("Field not found in submit request: X") + message.serverSparkVersion = 
"1.2.3" + message.message = "Field not found in submit request: X" message.validate() // test JSON val json = message.toJson assertJsonEquals(json, errorJson) val newMessage = SubmitRestProtocolMessage.fromJson(json, classOf[ErrorResponse]) - assert(newMessage.getSparkVersion === "1.2.3") - assert(newMessage.getServerSparkVersion === "1.2.3") - assert(newMessage.getMessage === "Field not found in submit request: X") + assert(newMessage.serverSparkVersion === "1.2.3") + assert(newMessage.message === "Field not found in submit request: X") } private val dummyRequestJson = @@ -312,7 +296,7 @@ class SubmitRestProtocolSuite extends FunSuite { """ |{ | "action" : "SubmitDriverRequest", - | "appArgs" : "[\"two slices\",\"a hint of cinnamon\"]", + | "appArgs" : ["two slices","a hint of cinnamon"], | "appName" : "SparkPie", | "appResource" : "honey-walnut-cherry.jar", | "clientSparkVersion" : "1.2.3", @@ -321,13 +305,13 @@ class SubmitRestProtocolSuite extends FunSuite { | "driverExtraJavaOptions" : " -Dslices=5 -Dcolor=mostly_red", | "driverExtraLibraryPath" : "pickle.jar", | "driverMemory" : "512m", - | "environmentVariables" : "{\"PATH\":\"/dev/null\",\"PYTHONPATH\":\"/dev/null\"}", + | "environmentVariables" : {"PATH":"/dev/null","PYTHONPATH":"/dev/null"}, | "executorMemory" : "256m", | "files" : "fireball.png", | "jars" : "mayonnaise.jar,ketchup.jar", | "mainClass" : "org.apache.spark.examples.SparkPie", | "pyFiles" : "do-not-eat-my.py", - | "sparkProperties" : "{\"spark.live.long\":\"true\",\"spark.shuffle.enabled\":\"false\"}", + | "sparkProperties" : {"spark.live.long":"true","spark.shuffle.enabled":"false"}, | "superviseDriver" : "false", | "totalExecutorCores" : "10000" |} @@ -389,7 +373,8 @@ class SubmitRestProtocolSuite extends FunSuite { |{ | "action" : "ErrorResponse", | "message" : "Field not found in submit request: X", - | "serverSparkVersion" : "1.2.3" + | "serverSparkVersion" : "1.2.3", + | "success": "false" |} """.stripMargin @@ -407,22 +392,15 @@ class SubmitRestProtocolSuite extends FunSuite { private class DummyResponse extends SubmitRestProtocolResponse private class DummyRequest extends SubmitRestProtocolRequest { - private val active = new SubmitRestProtocolField[Boolean] - private val age = new SubmitRestProtocolField[Int] - private val name = new SubmitRestProtocolField[String] - - def getActive: String = active.toString - def getAge: String = age.toString - def getName: String = name.toString - - def setActive(s: String): this.type = setBooleanField(active, s) - def setAge(s: String): this.type = setNumericField(age, s) - def setName(s: String): this.type = setField(name, s) - + var active: String = null + var age: String = null + var name: String = null protected override def doValidate(): Unit = { super.doValidate() assertFieldIsSet(name, "name") assertFieldIsSet(age, "age") - assert(age.getValue.get > 5, "Not old enough!") + assertFieldIsBoolean(active, "active") + assertFieldIsNumeric(age, "age") + assert(age.toInt > 5, "Not old enough!") } } diff --git a/pom.xml b/pom.xml index 4adfdf3eb870..ac7cbaae90fd 100644 --- a/pom.xml +++ b/pom.xml @@ -150,6 +150,7 @@ ${scala.version} org.scala-lang 1.8.8 + 2.3.0 1.1.1.6