Skip to content

Commit 792e112

Browse files
author
Andrew Or
committed
Use specific HTTP response codes on error
1 parent f98660b commit 792e112

File tree

1 file changed

+106
-98
lines changed

1 file changed

+106
-98
lines changed

core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala

Lines changed: 106 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import javax.servlet.http.{HttpServlet, HttpServletResponse, HttpServletRequest}
2323

2424
import scala.io.Source
2525

26-
import akka.actor.ActorRef
2726
import com.google.common.base.Charsets
2827
import org.eclipse.jetty.server.Server
2928
import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler}
@@ -38,11 +37,12 @@ import org.apache.spark.deploy.master.Master
3837
/**
3938
* A server that responds to requests submitted by the [[StandaloneRestClient]].
4039
* This is intended to be embedded in the standalone Master and used in cluster mode only.
40+
*
41+
* When an error occurs, this server sends an error response with an appropriate message
42+
* back to the client. If the construction of this error message itself is faulty, the
43+
* server indicates internal error through the response code.
4144
*/
42-
private[spark] class StandaloneRestServer(
43-
master: Master,
44-
host: String,
45-
requestedPort: Int)
45+
private[spark] class StandaloneRestServer(master: Master, host: String, requestedPort: Int)
4646
extends Logging {
4747

4848
import StandaloneRestServer._
@@ -98,34 +98,36 @@ private object StandaloneRestServer {
9898
/**
9999
* An abstract servlet for handling requests passed to the [[StandaloneRestServer]].
100100
*/
101-
private[spark] abstract class StandaloneRestServlet(master: Master)
102-
extends HttpServlet with Logging {
101+
private[spark] abstract class StandaloneRestServlet extends HttpServlet with Logging {
103102

104-
protected val conf: SparkConf = master.conf
105-
protected val masterActor: ActorRef = master.self
106-
protected val masterUrl: String = master.masterUrl
107-
protected val askTimeout = AkkaUtils.askTimeout(conf)
103+
/** Service a request. If an exception is thrown in the process, indicate server error. */
104+
protected override def service(
105+
request: HttpServletRequest,
106+
response: HttpServletResponse): Unit = {
107+
try {
108+
super.service(request, response)
109+
} catch {
110+
case e: Exception =>
111+
logError("Exception while handling request", e)
112+
response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
113+
}
114+
}
108115

109116
/**
110117
* Serialize the given response message to JSON and send it through the response servlet.
111118
* This validates the response before sending it to ensure it is properly constructed.
112119
*/
113-
protected def handleResponse(
120+
protected def sendResponse(
114121
responseMessage: SubmitRestProtocolResponse,
115122
responseServlet: HttpServletResponse): Unit = {
116-
try {
117-
val message = validateResponse(responseMessage)
118-
responseServlet.setContentType("application/json")
119-
responseServlet.setCharacterEncoding("utf-8")
120-
responseServlet.setStatus(HttpServletResponse.SC_OK)
121-
val content = message.toJson.getBytes(Charsets.UTF_8)
122-
val out = new DataOutputStream(responseServlet.getOutputStream)
123-
out.write(content)
124-
out.close()
125-
} catch {
126-
case e: Exception =>
127-
logError("Exception encountered when handling response.", e)
128-
}
123+
val message = validateResponse(responseMessage, responseServlet)
124+
responseServlet.setContentType("application/json")
125+
responseServlet.setCharacterEncoding("utf-8")
126+
responseServlet.setStatus(HttpServletResponse.SC_OK)
127+
val content = message.toJson.getBytes(Charsets.UTF_8)
128+
val out = new DataOutputStream(responseServlet.getOutputStream)
129+
out.write(content)
130+
out.close()
129131
}
130132

131133
/** Return a human readable String representation of the exception. */
@@ -147,12 +149,15 @@ private[spark] abstract class StandaloneRestServlet(master: Master)
147149
* If it is, simply return the response as is. Otherwise, return an error response
148150
* to propagate the exception back to the client.
149151
*/
150-
private def validateResponse(response: SubmitRestProtocolResponse): SubmitRestProtocolResponse = {
152+
private def validateResponse(
153+
responseMessage: SubmitRestProtocolResponse,
154+
responseServlet: HttpServletResponse): SubmitRestProtocolResponse = {
151155
try {
152-
response.validate()
153-
response
156+
responseMessage.validate()
157+
responseMessage
154158
} catch {
155159
case e: Exception =>
160+
responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
156161
handleError("Internal server error: " + formatException(e))
157162
}
158163
}
@@ -161,7 +166,8 @@ private[spark] abstract class StandaloneRestServlet(master: Master)
161166
/**
162167
* A servlet for handling kill requests passed to the [[StandaloneRestServer]].
163168
*/
164-
private[spark] class KillRequestServlet(master: Master) extends StandaloneRestServlet(master) {
169+
private[spark] class KillRequestServlet(master: Master) extends StandaloneRestServlet {
170+
private val askTimeout = AkkaUtils.askTimeout(master.conf)
165171

166172
/**
167173
* If a submission ID is specified in the URL, have the Master kill the corresponding
@@ -170,24 +176,20 @@ private[spark] class KillRequestServlet(master: Master) extends StandaloneRestSe
170176
protected override def doPost(
171177
request: HttpServletRequest,
172178
response: HttpServletResponse): Unit = {
173-
try {
174-
val submissionId = request.getPathInfo.stripPrefix("/")
175-
val responseMessage =
176-
if (submissionId.nonEmpty) {
177-
handleKill(submissionId)
178-
} else {
179-
handleError("Submission ID is missing in kill request")
180-
}
181-
handleResponse(responseMessage, response)
182-
} catch {
183-
case e: Exception =>
184-
logError("Exception encountered when handling kill request", e)
185-
}
179+
val submissionId = request.getPathInfo.stripPrefix("/")
180+
val responseMessage =
181+
if (submissionId.nonEmpty) {
182+
handleKill(submissionId)
183+
} else {
184+
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
185+
handleError("Submission ID is missing in kill request")
186+
}
187+
sendResponse(responseMessage, response)
186188
}
187189

188190
private def handleKill(submissionId: String): KillSubmissionResponse = {
189191
val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse](
190-
DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout)
192+
DeployMessages.RequestKillDriver(submissionId), master.self, askTimeout)
191193
val k = new KillSubmissionResponse
192194
k.serverSparkVersion = sparkVersion
193195
k.message = response.message
@@ -200,7 +202,8 @@ private[spark] class KillRequestServlet(master: Master) extends StandaloneRestSe
200202
/**
201203
* A servlet for handling status requests passed to the [[StandaloneRestServer]].
202204
*/
203-
private[spark] class StatusRequestServlet(master: Master) extends StandaloneRestServlet(master) {
205+
private[spark] class StatusRequestServlet(master: Master) extends StandaloneRestServlet {
206+
private val askTimeout = AkkaUtils.askTimeout(master.conf)
204207

205208
/**
206209
* If a submission ID is specified in the URL, request the status of the corresponding
@@ -209,24 +212,20 @@ private[spark] class StatusRequestServlet(master: Master) extends StandaloneRest
209212
protected override def doGet(
210213
request: HttpServletRequest,
211214
response: HttpServletResponse): Unit = {
212-
try {
213-
val submissionId = request.getPathInfo.stripPrefix("/")
214-
val responseMessage =
215-
if (submissionId.nonEmpty) {
216-
handleStatus(submissionId)
217-
} else {
218-
handleError("Submission ID is missing in status request")
219-
}
220-
handleResponse(responseMessage, response)
221-
} catch {
222-
case e: Exception =>
223-
logError("Exception encountered when handling status request", e)
224-
}
215+
val submissionId = request.getPathInfo.stripPrefix("/")
216+
val responseMessage =
217+
if (submissionId.nonEmpty) {
218+
handleStatus(submissionId)
219+
} else {
220+
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
221+
handleError("Submission ID is missing in status request")
222+
}
223+
sendResponse(responseMessage, response)
225224
}
226225

227226
private def handleStatus(submissionId: String): SubmissionStatusResponse = {
228227
val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse](
229-
DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout)
228+
DeployMessages.RequestDriverStatus(submissionId), master.self, askTimeout)
230229
val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) }
231230
val d = new SubmissionStatusResponse
232231
d.serverSparkVersion = sparkVersion
@@ -243,7 +242,8 @@ private[spark] class StatusRequestServlet(master: Master) extends StandaloneRest
243242
/**
244243
* A servlet for handling submit requests passed to the [[StandaloneRestServer]].
245244
*/
246-
private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRestServlet(master) {
245+
private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRestServlet {
246+
private val askTimeout = AkkaUtils.askTimeout(master.conf)
247247

248248
/**
249249
* Submit an application to the Master with parameters specified in the request message.
@@ -253,46 +253,48 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
253253
* Otherwise, return error instead.
254254
*/
255255
protected override def doPost(
256-
request: HttpServletRequest,
257-
response: HttpServletResponse): Unit = {
258-
try {
259-
val requestMessageJson = Source.fromInputStream(request.getInputStream).mkString
260-
val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson)
261-
.asInstanceOf[SubmitRestProtocolRequest]
262-
val responseMessage = handleSubmit(requestMessage)
263-
response.setContentType("application/json")
264-
response.setCharacterEncoding("utf-8")
265-
response.setStatus(HttpServletResponse.SC_OK)
266-
val content = responseMessage.toJson.getBytes(Charsets.UTF_8)
267-
val out = new DataOutputStream(response.getOutputStream)
268-
out.write(content)
269-
out.close()
270-
} catch {
271-
case e: Exception => logError("Exception while handling request", e)
272-
}
256+
requestServlet: HttpServletRequest,
257+
responseServlet: HttpServletResponse): Unit = {
258+
val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString
259+
val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson)
260+
.asInstanceOf[SubmitRestProtocolRequest]
261+
val responseMessage = handleSubmit(requestMessage, responseServlet)
262+
responseServlet.setContentType("application/json")
263+
responseServlet.setCharacterEncoding("utf-8")
264+
responseServlet.setStatus(HttpServletResponse.SC_OK)
265+
val content = responseMessage.toJson.getBytes(Charsets.UTF_8)
266+
val out = new DataOutputStream(responseServlet.getOutputStream)
267+
out.write(content)
268+
out.close()
273269
}
274270

275-
private def handleSubmit(request: SubmitRestProtocolRequest): SubmitRestProtocolResponse = {
271+
private def handleSubmit(
272+
requestMessage: SubmitRestProtocolRequest,
273+
responseServlet: HttpServletResponse): SubmitRestProtocolResponse = {
276274
// The response should have already been validated on the client.
277275
// In case this is not true, validate it ourselves to avoid potential NPEs.
278276
try {
279-
request.validate()
280-
request match {
281-
case submitRequest: CreateSubmissionRequest =>
282-
val driverDescription = buildDriverDescription(submitRequest)
283-
val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse](
284-
DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout)
285-
val submitResponse = new CreateSubmissionResponse
286-
submitResponse.serverSparkVersion = sparkVersion
287-
submitResponse.message = response.message
288-
submitResponse.success = response.success.toString
289-
submitResponse.submissionId = response.driverId.orNull
290-
submitResponse
291-
case unexpected => handleError(
292-
s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.")
293-
}
277+
requestMessage.validate()
294278
} catch {
295-
case e: Exception => handleError(formatException(e))
279+
case e: SubmitRestProtocolException =>
280+
responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST)
281+
handleError(formatException(e))
282+
}
283+
requestMessage match {
284+
case submitRequest: CreateSubmissionRequest =>
285+
val driverDescription = buildDriverDescription(submitRequest)
286+
val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse](
287+
DeployMessages.RequestSubmitDriver(driverDescription), master.self, askTimeout)
288+
val submitResponse = new CreateSubmissionResponse
289+
submitResponse.serverSparkVersion = sparkVersion
290+
submitResponse.message = response.message
291+
submitResponse.success = response.success.toString
292+
submitResponse.submissionId = response.driverId.orNull
293+
submitResponse
294+
case unexpected =>
295+
responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST)
296+
handleError(
297+
s"Received message of unexpected type ${Utils.getFormattedClassName(unexpected)}.")
296298
}
297299
}
298300

@@ -331,7 +333,7 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
331333
// Translate all fields to the relevant Spark properties
332334
val conf = new SparkConf(false)
333335
.setAll(sparkProperties)
334-
.set("spark.master", masterUrl)
336+
.set("spark.master", master.masterUrl)
335337
.set("spark.app.name", appName)
336338
jars.foreach { j => conf.set("spark.jars", j) }
337339
files.foreach { f => conf.set("spark.files", f) }
@@ -362,19 +364,25 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
362364
/**
363365
* A default servlet that handles error cases that are not captured by other servlets.
364366
*/
365-
private[spark] class ErrorServlet extends HttpServlet {
367+
private[spark] class ErrorServlet extends StandaloneRestServlet {
366368
private val expectedVersion = StandaloneRestServer.PROTOCOL_VERSION
367-
override def service(request: HttpServletRequest, response: HttpServletResponse): Unit = {
369+
protected override def service(
370+
request: HttpServletRequest,
371+
response: HttpServletResponse): Unit = {
368372
val path = request.getPathInfo
369373
val parts = path.stripPrefix("/").split("/")
370374
if (parts.nonEmpty) {
371375
val version = parts.head
372376
if (version != expectedVersion) {
373-
response.sendError(800, s"Incompatible protocol version $version")
377+
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
378+
val error = handleError(s"Incompatible protocol version $version")
379+
sendResponse(error, response)
374380
return
375381
}
376382
}
377-
response.sendError(801,
383+
response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
384+
val error = handleError(
378385
s"Unexpected path $path: Please submit requests through /$expectedVersion/submissions/")
386+
sendResponse(error, response)
379387
}
380388
}

0 commit comments

Comments
 (0)