Skip to content

Commit ade28fd

Browse files
author
Andrew Or
committed
Clean up REST response output in Spark submit
Now we don't log a response twice or log an error message twice. Also, before we would actually throw a ClassCastException if the server returns an error due to type erasure. This commit eases the relevant complexity involved.
1 parent b2fef8b commit ade28fd

File tree

4 files changed

+44
-29
lines changed

4 files changed

+44
-29
lines changed

core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ import java.net.URL
2323

2424
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
2525

26+
import org.apache.spark.deploy.rest._
2627
import org.apache.spark.executor.ExecutorURLClassLoader
2728
import org.apache.spark.util.Utils
28-
import org.apache.spark.deploy.rest.StandaloneRestClient
2929

3030
/**
3131
* Whether to submit, kill, or request the status of an application.
@@ -95,7 +95,10 @@ object SparkSubmit {
9595
private def kill(args: SparkSubmitArguments): Unit = {
9696
val client = new StandaloneRestClient
9797
val response = client.killDriver(args.master, args.driverToKill)
98-
printStream.println(response.toJson)
98+
response match {
99+
case k: KillDriverResponse => handleRestResponse(k)
100+
case r => handleUnexpectedRestResponse(r)
101+
}
99102
}
100103

101104
/**
@@ -105,7 +108,10 @@ object SparkSubmit {
105108
private def requestStatus(args: SparkSubmitArguments): Unit = {
106109
val client = new StandaloneRestClient
107110
val response = client.requestDriverStatus(args.master, args.driverToRequestStatusFor)
108-
printStream.println(response.toJson)
111+
response match {
112+
case s: DriverStatusResponse => handleRestResponse(s)
113+
case r => handleUnexpectedRestResponse(r)
114+
}
109115
}
110116

111117
/**
@@ -126,7 +132,12 @@ object SparkSubmit {
126132
val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args)
127133
if (args.isStandaloneCluster && args.isRestEnabled) {
128134
printStream.println("Running Spark using the REST application submission protocol.")
129-
new StandaloneRestClient().submitDriver(args)
135+
val client = new StandaloneRestClient
136+
val response = client.submitDriver(args)
137+
response match {
138+
case s: SubmitDriverResponse => handleRestResponse(s)
139+
case r => handleUnexpectedRestResponse(r)
140+
}
130141
} else {
131142
runMain(childArgs, childClasspath, sysProps, childMainClass)
132143
}
@@ -461,6 +472,17 @@ object SparkSubmit {
461472
}
462473
}
463474

475+
/** Log the response sent by the server in the REST application submission protocol. */
476+
private def handleRestResponse(response: SubmitRestProtocolResponse): Unit = {
477+
printStream.println(s"Server responded with ${response.messageType}:\n${response.toJson}")
478+
}
479+
480+
/** Log an appropriate error if the response sent by the server is not of the expected type. */
481+
private def handleUnexpectedRestResponse(unexpected: SubmitRestProtocolResponse): Unit = {
482+
printStream.println(
483+
s"Error: Server responded with message of unexpected type ${unexpected.messageType}.")
484+
}
485+
464486
/**
465487
* Return whether the given primary resource represents a user jar.
466488
*/

core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ private[spark] class StandaloneRestClient extends SubmitRestClient {
3939
override def submitDriver(args: SparkSubmitArguments): SubmitRestProtocolResponse = {
4040
validateSubmitArgs(args)
4141
val response = super.submitDriver(args)
42-
val submitResponse = getResponse[SubmitDriverResponse](response).getOrElse { return response }
42+
val submitResponse = response match {
43+
case s: SubmitDriverResponse => s
44+
case _ => return response
45+
}
4346
val submitSuccess = submitResponse.getSuccess.toBoolean
4447
if (submitSuccess) {
4548
val driverId = submitResponse.getDriverId
@@ -71,7 +74,10 @@ private[spark] class StandaloneRestClient extends SubmitRestClient {
7174
private def pollSubmittedDriverStatus(master: String, driverId: String): Unit = {
7275
(1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ =>
7376
val response = requestDriverStatus(master, driverId)
74-
val statusResponse = getResponse[DriverStatusResponse](response).getOrElse { return }
77+
val statusResponse = response match {
78+
case s: DriverStatusResponse => s
79+
case _ => return
80+
}
7581
val statusSuccess = statusResponse.getSuccess.toBoolean
7682
if (statusSuccess) {
7783
val driverState = Option(statusResponse.getDriverState)
@@ -160,23 +166,6 @@ private[spark] class StandaloneRestClient extends SubmitRestClient {
160166
"This REST client is only supported in standalone cluster mode.")
161167
}
162168
}
163-
164-
/**
165-
* Return the response as the expected type, or fail with an informative error message.
166-
* Exposed for testing.
167-
*/
168-
private[spark] def getResponse[T <: SubmitRestProtocolResponse](
169-
response: SubmitRestProtocolResponse): Option[T] = {
170-
try {
171-
// Do not match on type T because types are erased at runtime
172-
// Instead, manually try to cast it to type T ourselves
173-
Some(response.asInstanceOf[T])
174-
} catch {
175-
case e: ClassCastException =>
176-
logError(s"Server returned response of unexpected type:\n${response.toJson}")
177-
None
178-
}
179-
}
180169
}
181170

182171
private object StandaloneRestClient {

core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestClient.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ private[spark] abstract class SubmitRestClient extends Logging {
3939
val url = getHttpUrl(args.master)
4040
val request = constructSubmitRequest(args)
4141
val response = sendHttp(url, request)
42-
validateResponse(response)
42+
handleResponse(response)
4343
}
4444

4545
/** Request that the REST server kill the specified driver. */
@@ -48,7 +48,7 @@ private[spark] abstract class SubmitRestClient extends Logging {
4848
val url = getHttpUrl(master)
4949
val request = constructKillRequest(master, driverId)
5050
val response = sendHttp(url, request)
51-
validateResponse(response)
51+
handleResponse(response)
5252
}
5353

5454
/** Request the status of the specified driver from the REST server. */
@@ -57,7 +57,7 @@ private[spark] abstract class SubmitRestClient extends Logging {
5757
val url = getHttpUrl(master)
5858
val request = constructStatusRequest(master, driverId)
5959
val response = sendHttp(url, request)
60-
validateResponse(response)
60+
handleResponse(response)
6161
}
6262

6363
/** Return the HTTP URL of the REST server that corresponds to the given master URL. */
@@ -95,10 +95,14 @@ private[spark] abstract class SubmitRestClient extends Logging {
9595
}
9696
}
9797

98-
/** Validate the response... */
99-
private def validateResponse(response: SubmitRestProtocolResponse): SubmitRestProtocolResponse = {
98+
/** Validate the response and log any error messages provided by the server. */
99+
private def handleResponse(response: SubmitRestProtocolResponse): SubmitRestProtocolResponse = {
100100
try {
101101
response.validate()
102+
response match {
103+
case e: ErrorResponse => logError(s"Server responded with error:\n${e.getMessage}")
104+
case _ =>
105+
}
102106
} catch {
103107
case e: SubmitRestProtocolException =>
104108
throw new SubmitRestProtocolException("Malformed response received from server", e)

core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ import org.apache.spark.util.Utils
3939
@JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY)
4040
@JsonPropertyOrder(alphabetic = true)
4141
abstract class SubmitRestProtocolMessage {
42-
private val messageType = Utils.getFormattedClassName(this)
42+
val messageType = Utils.getFormattedClassName(this)
4343
protected val action: String = messageType
4444
protected val sparkVersion: SubmitRestProtocolField[String]
4545
protected val message = new SubmitRestProtocolField[String]("message")

0 commit comments

Comments
 (0)