@@ -23,7 +23,6 @@ import javax.servlet.http.{HttpServlet, HttpServletResponse, HttpServletRequest}
2323
2424import scala .io .Source
2525
26- import akka .actor .ActorRef
2726import com .google .common .base .Charsets
2827import org .eclipse .jetty .server .Server
2928import org .eclipse .jetty .servlet .{ServletHolder , ServletContextHandler }
@@ -38,11 +37,12 @@ import org.apache.spark.deploy.master.Master
3837/**
3938 * A server that responds to requests submitted by the [[StandaloneRestClient ]].
4039 * This is intended to be embedded in the standalone Master and used in cluster mode only.
40+ *
41+ * When an error occurs, this server sends an error response with an appropriate message
42+ * back to the client. If the construction of this error message itself is faulty, the
43+ * server indicates internal error through the response code.
4144 */
42- private [spark] class StandaloneRestServer (
43- master : Master ,
44- host : String ,
45- requestedPort : Int )
45+ private [spark] class StandaloneRestServer (master : Master , host : String , requestedPort : Int )
4646 extends Logging {
4747
4848 import StandaloneRestServer ._
@@ -98,34 +98,36 @@ private object StandaloneRestServer {
9898/**
9999 * An abstract servlet for handling requests passed to the [[StandaloneRestServer ]].
100100 */
101- private [spark] abstract class StandaloneRestServlet (master : Master )
102- extends HttpServlet with Logging {
101+ private [spark] abstract class StandaloneRestServlet extends HttpServlet with Logging {
103102
104- protected val conf : SparkConf = master.conf
105- protected val masterActor : ActorRef = master.self
106- protected val masterUrl : String = master.masterUrl
107- protected val askTimeout = AkkaUtils .askTimeout(conf)
103+ /** Service a request. If an exception is thrown in the process, indicate server error. */
104+ protected override def service (
105+ request : HttpServletRequest ,
106+ response : HttpServletResponse ): Unit = {
107+ try {
108+ super .service(request, response)
109+ } catch {
110+ case e : Exception =>
111+ logError(" Exception while handling request" , e)
112+ response.setStatus(HttpServletResponse .SC_INTERNAL_SERVER_ERROR )
113+ }
114+ }
108115
109116 /**
110117 * Serialize the given response message to JSON and send it through the response servlet.
111118 * This validates the response before sending it to ensure it is properly constructed.
112119 */
113- protected def handleResponse (
120+ protected def sendResponse (
114121 responseMessage : SubmitRestProtocolResponse ,
115122 responseServlet : HttpServletResponse ): Unit = {
116- try {
117- val message = validateResponse(responseMessage)
118- responseServlet.setContentType(" application/json" )
119- responseServlet.setCharacterEncoding(" utf-8" )
120- responseServlet.setStatus(HttpServletResponse .SC_OK )
121- val content = message.toJson.getBytes(Charsets .UTF_8 )
122- val out = new DataOutputStream (responseServlet.getOutputStream)
123- out.write(content)
124- out.close()
125- } catch {
126- case e : Exception =>
127- logError(" Exception encountered when handling response." , e)
128- }
123+ val message = validateResponse(responseMessage, responseServlet)
124+ responseServlet.setContentType(" application/json" )
125+ responseServlet.setCharacterEncoding(" utf-8" )
126+ responseServlet.setStatus(HttpServletResponse .SC_OK )
127+ val content = message.toJson.getBytes(Charsets .UTF_8 )
128+ val out = new DataOutputStream (responseServlet.getOutputStream)
129+ out.write(content)
130+ out.close()
129131 }
130132
131133 /** Return a human readable String representation of the exception. */
@@ -147,12 +149,15 @@ private[spark] abstract class StandaloneRestServlet(master: Master)
147149 * If it is, simply return the response as is. Otherwise, return an error response
148150 * to propagate the exception back to the client.
149151 */
150- private def validateResponse (response : SubmitRestProtocolResponse ): SubmitRestProtocolResponse = {
152+ private def validateResponse (
153+ responseMessage : SubmitRestProtocolResponse ,
154+ responseServlet : HttpServletResponse ): SubmitRestProtocolResponse = {
151155 try {
152- response .validate()
153- response
156+ responseMessage .validate()
157+ responseMessage
154158 } catch {
155159 case e : Exception =>
160+ responseServlet.setStatus(HttpServletResponse .SC_INTERNAL_SERVER_ERROR )
156161 handleError(" Internal server error: " + formatException(e))
157162 }
158163 }
@@ -161,7 +166,8 @@ private[spark] abstract class StandaloneRestServlet(master: Master)
161166/**
162167 * A servlet for handling kill requests passed to the [[StandaloneRestServer ]].
163168 */
164- private [spark] class KillRequestServlet (master : Master ) extends StandaloneRestServlet (master) {
169+ private [spark] class KillRequestServlet (master : Master ) extends StandaloneRestServlet {
170+ private val askTimeout = AkkaUtils .askTimeout(master.conf)
165171
166172 /**
167173 * If a submission ID is specified in the URL, have the Master kill the corresponding
@@ -170,24 +176,20 @@ private[spark] class KillRequestServlet(master: Master) extends StandaloneRestSe
170176 protected override def doPost (
171177 request : HttpServletRequest ,
172178 response : HttpServletResponse ): Unit = {
173- try {
174- val submissionId = request.getPathInfo.stripPrefix(" /" )
175- val responseMessage =
176- if (submissionId.nonEmpty) {
177- handleKill(submissionId)
178- } else {
179- handleError(" Submission ID is missing in kill request" )
180- }
181- handleResponse(responseMessage, response)
182- } catch {
183- case e : Exception =>
184- logError(" Exception encountered when handling kill request" , e)
185- }
179+ val submissionId = request.getPathInfo.stripPrefix(" /" )
180+ val responseMessage =
181+ if (submissionId.nonEmpty) {
182+ handleKill(submissionId)
183+ } else {
184+ response.setStatus(HttpServletResponse .SC_BAD_REQUEST )
185+ handleError(" Submission ID is missing in kill request" )
186+ }
187+ sendResponse(responseMessage, response)
186188 }
187189
188190 private def handleKill (submissionId : String ): KillSubmissionResponse = {
189191 val response = AkkaUtils .askWithReply[DeployMessages .KillDriverResponse ](
190- DeployMessages .RequestKillDriver (submissionId), masterActor , askTimeout)
192+ DeployMessages .RequestKillDriver (submissionId), master.self , askTimeout)
191193 val k = new KillSubmissionResponse
192194 k.serverSparkVersion = sparkVersion
193195 k.message = response.message
@@ -200,7 +202,8 @@ private[spark] class KillRequestServlet(master: Master) extends StandaloneRestSe
200202/**
201203 * A servlet for handling status requests passed to the [[StandaloneRestServer ]].
202204 */
203- private [spark] class StatusRequestServlet (master : Master ) extends StandaloneRestServlet (master) {
205+ private [spark] class StatusRequestServlet (master : Master ) extends StandaloneRestServlet {
206+ private val askTimeout = AkkaUtils .askTimeout(master.conf)
204207
205208 /**
206209 * If a submission ID is specified in the URL, request the status of the corresponding
@@ -209,24 +212,20 @@ private[spark] class StatusRequestServlet(master: Master) extends StandaloneRest
209212 protected override def doGet (
210213 request : HttpServletRequest ,
211214 response : HttpServletResponse ): Unit = {
212- try {
213- val submissionId = request.getPathInfo.stripPrefix(" /" )
214- val responseMessage =
215- if (submissionId.nonEmpty) {
216- handleStatus(submissionId)
217- } else {
218- handleError(" Submission ID is missing in status request" )
219- }
220- handleResponse(responseMessage, response)
221- } catch {
222- case e : Exception =>
223- logError(" Exception encountered when handling status request" , e)
224- }
215+ val submissionId = request.getPathInfo.stripPrefix(" /" )
216+ val responseMessage =
217+ if (submissionId.nonEmpty) {
218+ handleStatus(submissionId)
219+ } else {
220+ response.setStatus(HttpServletResponse .SC_BAD_REQUEST )
221+ handleError(" Submission ID is missing in status request" )
222+ }
223+ sendResponse(responseMessage, response)
225224 }
226225
227226 private def handleStatus (submissionId : String ): SubmissionStatusResponse = {
228227 val response = AkkaUtils .askWithReply[DeployMessages .DriverStatusResponse ](
229- DeployMessages .RequestDriverStatus (submissionId), masterActor , askTimeout)
228+ DeployMessages .RequestDriverStatus (submissionId), master.self , askTimeout)
230229 val message = response.exception.map { s " Exception from the cluster: \n " + formatException(_) }
231230 val d = new SubmissionStatusResponse
232231 d.serverSparkVersion = sparkVersion
@@ -243,7 +242,8 @@ private[spark] class StatusRequestServlet(master: Master) extends StandaloneRest
243242/**
244243 * A servlet for handling submit requests passed to the [[StandaloneRestServer ]].
245244 */
246- private [spark] class SubmitRequestServlet (master : Master ) extends StandaloneRestServlet (master) {
245+ private [spark] class SubmitRequestServlet (master : Master ) extends StandaloneRestServlet {
246+ private val askTimeout = AkkaUtils .askTimeout(master.conf)
247247
248248 /**
249249 * Submit an application to the Master with parameters specified in the request message.
@@ -253,46 +253,48 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
253253 * Otherwise, return error instead.
254254 */
255255 protected override def doPost (
256- request : HttpServletRequest ,
257- response : HttpServletResponse ): Unit = {
258- try {
259- val requestMessageJson = Source .fromInputStream(request.getInputStream).mkString
260- val requestMessage = SubmitRestProtocolMessage .fromJson(requestMessageJson)
261- .asInstanceOf [SubmitRestProtocolRequest ]
262- val responseMessage = handleSubmit(requestMessage)
263- response.setContentType(" application/json" )
264- response.setCharacterEncoding(" utf-8" )
265- response.setStatus(HttpServletResponse .SC_OK )
266- val content = responseMessage.toJson.getBytes(Charsets .UTF_8 )
267- val out = new DataOutputStream (response.getOutputStream)
268- out.write(content)
269- out.close()
270- } catch {
271- case e : Exception => logError(" Exception while handling request" , e)
272- }
256+ requestServlet : HttpServletRequest ,
257+ responseServlet : HttpServletResponse ): Unit = {
258+ val requestMessageJson = Source .fromInputStream(requestServlet.getInputStream).mkString
259+ val requestMessage = SubmitRestProtocolMessage .fromJson(requestMessageJson)
260+ .asInstanceOf [SubmitRestProtocolRequest ]
261+ val responseMessage = handleSubmit(requestMessage, responseServlet)
262+ responseServlet.setContentType(" application/json" )
263+ responseServlet.setCharacterEncoding(" utf-8" )
264+ responseServlet.setStatus(HttpServletResponse .SC_OK )
265+ val content = responseMessage.toJson.getBytes(Charsets .UTF_8 )
266+ val out = new DataOutputStream (responseServlet.getOutputStream)
267+ out.write(content)
268+ out.close()
273269 }
274270
275- private def handleSubmit (request : SubmitRestProtocolRequest ): SubmitRestProtocolResponse = {
271+ private def handleSubmit (
272+ requestMessage : SubmitRestProtocolRequest ,
273+ responseServlet : HttpServletResponse ): SubmitRestProtocolResponse = {
276274 // The response should have already been validated on the client.
277275 // In case this is not true, validate it ourselves to avoid potential NPEs.
278276 try {
279- request.validate()
280- request match {
281- case submitRequest : CreateSubmissionRequest =>
282- val driverDescription = buildDriverDescription(submitRequest)
283- val response = AkkaUtils .askWithReply[DeployMessages .SubmitDriverResponse ](
284- DeployMessages .RequestSubmitDriver (driverDescription), masterActor, askTimeout)
285- val submitResponse = new CreateSubmissionResponse
286- submitResponse.serverSparkVersion = sparkVersion
287- submitResponse.message = response.message
288- submitResponse.success = response.success.toString
289- submitResponse.submissionId = response.driverId.orNull
290- submitResponse
291- case unexpected => handleError(
292- s " Received message of unexpected type ${Utils .getFormattedClassName(unexpected)}. " )
293- }
277+ requestMessage.validate()
294278 } catch {
295- case e : Exception => handleError(formatException(e))
279+ case e : SubmitRestProtocolException =>
280+ responseServlet.setStatus(HttpServletResponse .SC_BAD_REQUEST )
281+ handleError(formatException(e))
282+ }
283+ requestMessage match {
284+ case submitRequest : CreateSubmissionRequest =>
285+ val driverDescription = buildDriverDescription(submitRequest)
286+ val response = AkkaUtils .askWithReply[DeployMessages .SubmitDriverResponse ](
287+ DeployMessages .RequestSubmitDriver (driverDescription), master.self, askTimeout)
288+ val submitResponse = new CreateSubmissionResponse
289+ submitResponse.serverSparkVersion = sparkVersion
290+ submitResponse.message = response.message
291+ submitResponse.success = response.success.toString
292+ submitResponse.submissionId = response.driverId.orNull
293+ submitResponse
294+ case unexpected =>
295+ responseServlet.setStatus(HttpServletResponse .SC_BAD_REQUEST )
296+ handleError(
297+ s " Received message of unexpected type ${Utils .getFormattedClassName(unexpected)}. " )
296298 }
297299 }
298300
@@ -331,7 +333,7 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
331333 // Translate all fields to the relevant Spark properties
332334 val conf = new SparkConf (false )
333335 .setAll(sparkProperties)
334- .set(" spark.master" , masterUrl)
336+ .set(" spark.master" , master. masterUrl)
335337 .set(" spark.app.name" , appName)
336338 jars.foreach { j => conf.set(" spark.jars" , j) }
337339 files.foreach { f => conf.set(" spark.files" , f) }
@@ -362,19 +364,25 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
362364/**
363365 * A default servlet that handles error cases that are not captured by other servlets.
364366 */
365- private [spark] class ErrorServlet extends HttpServlet {
367+ private [spark] class ErrorServlet extends StandaloneRestServlet {
366368 private val expectedVersion = StandaloneRestServer .PROTOCOL_VERSION
367- override def service (request : HttpServletRequest , response : HttpServletResponse ): Unit = {
369+ protected override def service (
370+ request : HttpServletRequest ,
371+ response : HttpServletResponse ): Unit = {
368372 val path = request.getPathInfo
369373 val parts = path.stripPrefix(" /" ).split(" /" )
370374 if (parts.nonEmpty) {
371375 val version = parts.head
372376 if (version != expectedVersion) {
373- response.sendError(800 , s " Incompatible protocol version $version" )
377+ response.setStatus(HttpServletResponse .SC_BAD_REQUEST )
378+ val error = handleError(s " Incompatible protocol version $version" )
379+ sendResponse(error, response)
374380 return
375381 }
376382 }
377- response.sendError(801 ,
383+ response.setStatus(HttpServletResponse .SC_BAD_REQUEST )
384+ val error = handleError(
378385 s " Unexpected path $path: Please submit requests through / $expectedVersion/submissions/ " )
386+ sendResponse(error, response)
379387 }
380388}
0 commit comments