Skip to content

Commit 8d43486

Browse files
author
Andrew Or
committed
Replace SubmitRestProtocolAction with class name
This makes it easier for users to define their own messages. Rather than forcing the users to introduce an action class that extends our inflexible enum, they can now implement custom logic in their REST servers depending on the action of the request.
1 parent df90e8b commit 8d43486

File tree

10 files changed

+35
-71
lines changed

10 files changed

+35
-71
lines changed

core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusRequest.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,9 @@
1818
package org.apache.spark.deploy.rest
1919

2020
class DriverStatusRequest extends SubmitRestProtocolRequest {
21-
protected override val action = SubmitRestProtocolAction.DRIVER_STATUS_REQUEST
2221
private val driverId = new SubmitRestProtocolField[String]
23-
2422
def getDriverId: String = driverId.toString
2523
def setDriverId(s: String): this.type = setField(driverId, s)
26-
2724
override def validate(): Unit = {
2825
super.validate()
2926
assertFieldIsSet(driverId, "driver_id")

core/src/main/scala/org/apache/spark/deploy/rest/DriverStatusResponse.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
package org.apache.spark.deploy.rest
1919

2020
class DriverStatusResponse extends SubmitRestProtocolResponse {
21-
protected override val action = SubmitRestProtocolAction.DRIVER_STATUS_RESPONSE
2221
private val driverId = new SubmitRestProtocolField[String]
2322
private val success = new SubmitRestProtocolField[Boolean]
2423
private val driverState = new SubmitRestProtocolField[String]

core/src/main/scala/org/apache/spark/deploy/rest/ErrorResponse.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
package org.apache.spark.deploy.rest
1919

2020
class ErrorResponse extends SubmitRestProtocolResponse {
21-
protected override val action = SubmitRestProtocolAction.ERROR
2221
override def validate(): Unit = {
2322
super.validate()
2423
assertFieldIsSet(message, "message")

core/src/main/scala/org/apache/spark/deploy/rest/KillDriverRequest.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,9 @@
1818
package org.apache.spark.deploy.rest
1919

2020
class KillDriverRequest extends SubmitRestProtocolRequest {
21-
protected override val action = SubmitRestProtocolAction.KILL_DRIVER_REQUEST
2221
private val driverId = new SubmitRestProtocolField[String]
23-
2422
def getDriverId: String = driverId.toString
2523
def setDriverId(s: String): this.type = setField(driverId, s)
26-
2724
override def validate(): Unit = {
2825
super.validate()
2926
assertFieldIsSet(driverId, "driver_id")

core/src/main/scala/org/apache/spark/deploy/rest/KillDriverResponse.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
package org.apache.spark.deploy.rest
1919

2020
class KillDriverResponse extends SubmitRestProtocolResponse {
21-
protected override val action = SubmitRestProtocolAction.KILL_DRIVER_RESPONSE
2221
private val driverId = new SubmitRestProtocolField[String]
2322
private val success = new SubmitRestProtocolField[Boolean]
2423

core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverRequest.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import org.json4s.jackson.JsonMethods._
2626
import org.apache.spark.util.JsonProtocol
2727

2828
class SubmitDriverRequest extends SubmitRestProtocolRequest {
29-
protected override val action = SubmitRestProtocolAction.SUBMIT_DRIVER_REQUEST
3029
private val appName = new SubmitRestProtocolField[String]
3130
private val appResource = new SubmitRestProtocolField[String]
3231
private val mainClass = new SubmitRestProtocolField[String]
@@ -62,7 +61,7 @@ class SubmitDriverRequest extends SubmitRestProtocolRequest {
6261
def getExecutorMemory: String = executorMemory.toString
6362
def getTotalExecutorCores: String = totalExecutorCores.toString
6463

65-
// Special getters required for JSON de/serialization
64+
// Special getters required for JSON serialization
6665
@JsonProperty("appArgs")
6766
private def getAppArgsJson: String = arrayToJson(getAppArgs)
6867
@JsonProperty("sparkProperties")
@@ -85,7 +84,7 @@ class SubmitDriverRequest extends SubmitRestProtocolRequest {
8584
def setExecutorMemory(s: String): this.type = setField(executorMemory, s)
8685
def setTotalExecutorCores(s: String): this.type = setNumericField(totalExecutorCores, s)
8786

88-
// Special setters required for JSON de/serialization
87+
// Special setters required for JSON deserialization
8988
@JsonProperty("appArgs")
9089
private def setAppArgsJson(s: String): Unit = {
9190
appArgs.clear()
@@ -116,11 +115,11 @@ class SubmitDriverRequest extends SubmitRestProtocolRequest {
116115
def setEnvironmentVariable(k: String, v: String): this.type = { envVars(k) = v; this }
117116

118117
private def arrayToJson(arr: Array[String]): String = {
119-
if (arr.nonEmpty) { compact(render(JsonProtocol.arrayToJson(arr))) } else { null }
118+
if (arr.nonEmpty) { compact(render(JsonProtocol.arrayToJson(arr))) } else null
120119
}
121120

122121
private def mapToJson(map: Map[String, String]): String = {
123-
if (map.nonEmpty) { compact(render(JsonProtocol.mapToJson(map))) } else { null }
122+
if (map.nonEmpty) { compact(render(JsonProtocol.mapToJson(map))) } else null
124123
}
125124

126125
override def validate(): Unit = {

core/src/main/scala/org/apache/spark/deploy/rest/SubmitDriverResponse.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
package org.apache.spark.deploy.rest
1919

2020
class SubmitDriverResponse extends SubmitRestProtocolResponse {
21-
protected override val action = SubmitRestProtocolAction.SUBMIT_DRIVER_RESPONSE
2221
private val success = new SubmitRestProtocolField[Boolean]
2322
private val driverId = new SubmitRestProtocolField[String]
2423

core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolField.scala

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,6 @@
1717

1818
package org.apache.spark.deploy.rest
1919

20-
/**
21-
* All possible values of the ACTION field in a SubmitRestProtocolMessage.
22-
*/
23-
abstract class SubmitRestProtocolAction
24-
object SubmitRestProtocolAction {
25-
case object SUBMIT_DRIVER_REQUEST extends SubmitRestProtocolAction
26-
case object SUBMIT_DRIVER_RESPONSE extends SubmitRestProtocolAction
27-
case object KILL_DRIVER_REQUEST extends SubmitRestProtocolAction
28-
case object KILL_DRIVER_RESPONSE extends SubmitRestProtocolAction
29-
case object DRIVER_STATUS_REQUEST extends SubmitRestProtocolAction
30-
case object DRIVER_STATUS_RESPONSE extends SubmitRestProtocolAction
31-
case object ERROR extends SubmitRestProtocolAction
32-
private val allActions =
33-
Seq(SUBMIT_DRIVER_REQUEST, SUBMIT_DRIVER_RESPONSE, KILL_DRIVER_REQUEST,
34-
KILL_DRIVER_RESPONSE, DRIVER_STATUS_REQUEST, DRIVER_STATUS_RESPONSE, ERROR)
35-
private val allActionsMap = allActions.map { a => (a.toString, a) }.toMap
36-
37-
def fromString(action: String): SubmitRestProtocolAction = {
38-
allActionsMap.get(action).getOrElse {
39-
throw new IllegalArgumentException(s"Unknown action $action")
40-
}
41-
}
42-
}
43-
4420
class SubmitRestProtocolField[T] {
4521
protected var value: Option[T] = None
4622
def isSet: Boolean = value.isDefined

core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ import org.json4s.JsonAST._
2525
import org.json4s.jackson.JsonMethods._
2626

2727
import org.apache.spark.util.Utils
28-
import org.apache.spark.deploy.rest.SubmitRestProtocolAction._
2928

3029
@JsonInclude(Include.NON_NULL)
3130
@JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY)
@@ -34,12 +33,12 @@ abstract class SubmitRestProtocolMessage {
3433
import SubmitRestProtocolMessage._
3534

3635
private val messageType = Utils.getFormattedClassName(this)
37-
protected val action: SubmitRestProtocolAction
36+
protected val action: String = camelCaseToUnderscores(decapitalize(messageType))
3837
protected val sparkVersion = new SubmitRestProtocolField[String]
3938
protected val message = new SubmitRestProtocolField[String]
4039

4140
// Required for JSON de/serialization and not explicitly used
42-
private def getAction: String = action.toString
41+
private def getAction: String = action
4342
private def setAction(s: String): this.type = this
4443

4544
// Spark version implementation depends on whether this is a request or a response
@@ -124,24 +123,22 @@ abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage {
124123

125124
object SubmitRestProtocolMessage {
126125
private val mapper = new ObjectMapper
126+
private val packagePrefix = this.getClass.getPackage.getName
127127

128-
def fromJson(json: String): SubmitRestProtocolMessage = {
129-
val fields = parse(json).asInstanceOf[JObject].obj
130-
val action = fields
128+
def parseAction(json: String): String = {
129+
parse(json).asInstanceOf[JObject].obj
131130
.find { case (f, _) => f == "action" }
132131
.map { case (_, v) => v.asInstanceOf[JString].s }
133132
.getOrElse {
134-
throw new IllegalArgumentException(s"Could not find action field in message:\n$json")
135-
}
136-
val clazz = SubmitRestProtocolAction.fromString(action) match {
137-
case SUBMIT_DRIVER_REQUEST => classOf[SubmitDriverRequest]
138-
case SUBMIT_DRIVER_RESPONSE => classOf[SubmitDriverResponse]
139-
case KILL_DRIVER_REQUEST => classOf[KillDriverRequest]
140-
case KILL_DRIVER_RESPONSE => classOf[KillDriverResponse]
141-
case DRIVER_STATUS_REQUEST => classOf[DriverStatusRequest]
142-
case DRIVER_STATUS_RESPONSE => classOf[DriverStatusResponse]
143-
case ERROR => classOf[ErrorResponse]
133+
throw new IllegalArgumentException(s"Could not find action field in message:\n$json")
144134
}
135+
}
136+
137+
def fromJson(json: String): SubmitRestProtocolMessage = {
138+
val action = parseAction(json)
139+
val className = underscoresToCamelCase(action).capitalize
140+
val clazz = Class.forName(packagePrefix + "." + className)
141+
.asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage])
145142
fromJson(json, clazz)
146143
}
147144

@@ -178,6 +175,14 @@ object SubmitRestProtocolMessage {
178175
}
179176
newString.toString()
180177
}
178+
179+
private def decapitalize(s: String): String = {
180+
if (s != null && s.nonEmpty) {
181+
s(0).toLower + s.substring(1)
182+
} else {
183+
s
184+
}
185+
}
181186
}
182187

183188
object SubmitRestProtocolRequest {

core/src/test/scala/org/apache/spark/deploy/rest/SubmitRestProtocolSuite.scala

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,7 @@ package org.apache.spark.deploy.rest
2020
import org.json4s.jackson.JsonMethods._
2121
import org.scalatest.FunSuite
2222

23-
case object DUMMY_REQUEST extends SubmitRestProtocolAction
24-
case object DUMMY_RESPONSE extends SubmitRestProtocolAction
25-
2623
class DummyRequest extends SubmitRestProtocolRequest {
27-
protected override val action = DUMMY_REQUEST
2824
private val active = new SubmitRestProtocolField[Boolean]
2925
private val age = new SubmitRestProtocolField[Int]
3026
private val name = new SubmitRestProtocolField[String]
@@ -45,9 +41,7 @@ class DummyRequest extends SubmitRestProtocolRequest {
4541
}
4642
}
4743

48-
class DummyResponse extends SubmitRestProtocolResponse {
49-
protected override val action = DUMMY_RESPONSE
50-
}
44+
class DummyResponse extends SubmitRestProtocolResponse
5145

5246
/**
5347
* Tests for the stable application submission REST protocol.
@@ -325,7 +319,7 @@ class SubmitRestProtocolSuite extends FunSuite {
325319
private val dummyRequestJson =
326320
"""
327321
|{
328-
| "action" : "DUMMY_REQUEST",
322+
| "action" : "dummy_request",
329323
| "active" : "true",
330324
| "age" : "25",
331325
| "client_spark_version" : "1.2.3",
@@ -336,15 +330,15 @@ class SubmitRestProtocolSuite extends FunSuite {
336330
private val dummyResponseJson =
337331
"""
338332
|{
339-
| "action" : "DUMMY_RESPONSE",
333+
| "action" : "dummy_response",
340334
| "server_spark_version" : "3.3.4"
341335
|}
342336
""".stripMargin
343337

344338
private val submitDriverRequestJson =
345339
"""
346340
|{
347-
| "action" : "SUBMIT_DRIVER_REQUEST",
341+
| "action" : "submit_driver_request",
348342
| "app_args" : "[\"two slices\",\"a hint of cinnamon\"]",
349343
| "app_name" : "SparkPie",
350344
| "app_resource" : "honey-walnut-cherry.jar",
@@ -369,7 +363,7 @@ class SubmitRestProtocolSuite extends FunSuite {
369363
private val submitDriverResponseJson =
370364
"""
371365
|{
372-
| "action" : "SUBMIT_DRIVER_RESPONSE",
366+
| "action" : "submit_driver_response",
373367
| "driver_id" : "driver_123",
374368
| "server_spark_version" : "1.2.3",
375369
| "success" : "true"
@@ -379,7 +373,7 @@ class SubmitRestProtocolSuite extends FunSuite {
379373
private val killDriverRequestJson =
380374
"""
381375
|{
382-
| "action" : "KILL_DRIVER_REQUEST",
376+
| "action" : "kill_driver_request",
383377
| "client_spark_version" : "1.2.3",
384378
| "driver_id" : "driver_123"
385379
|}
@@ -388,7 +382,7 @@ class SubmitRestProtocolSuite extends FunSuite {
388382
private val killDriverResponseJson =
389383
"""
390384
|{
391-
| "action" : "KILL_DRIVER_RESPONSE",
385+
| "action" : "kill_driver_response",
392386
| "driver_id" : "driver_123",
393387
| "server_spark_version" : "1.2.3",
394388
| "success" : "true"
@@ -398,7 +392,7 @@ class SubmitRestProtocolSuite extends FunSuite {
398392
private val driverStatusRequestJson =
399393
"""
400394
|{
401-
| "action" : "DRIVER_STATUS_REQUEST",
395+
| "action" : "driver_status_request",
402396
| "client_spark_version" : "1.2.3",
403397
| "driver_id" : "driver_123"
404398
|}
@@ -407,7 +401,7 @@ class SubmitRestProtocolSuite extends FunSuite {
407401
private val driverStatusResponseJson =
408402
"""
409403
|{
410-
| "action" : "DRIVER_STATUS_RESPONSE",
404+
| "action" : "driver_status_response",
411405
| "driver_id" : "driver_123",
412406
| "driver_state" : "RUNNING",
413407
| "server_spark_version" : "1.2.3",
@@ -420,7 +414,7 @@ class SubmitRestProtocolSuite extends FunSuite {
420414
private val errorJson =
421415
"""
422416
|{
423-
| "action" : "ERROR",
417+
| "action" : "error_response",
424418
| "message" : "Field not found in submit request: X",
425419
| "server_spark_version" : "1.2.3"
426420
|}

0 commit comments

Comments
 (0)