Skip to content

Commit cbd670b

Browse files
author
Andrew Or
committed
Include unknown fields, if any, in server response
... in case the client wants to propagate this to the user in the future.
1 parent 9fee16f commit cbd670b

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ import com.google.common.base.Charsets
2727
import org.eclipse.jetty.server.Server
2828
import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler}
2929
import org.eclipse.jetty.util.thread.QueuedThreadPool
30+
import org.json4s._
31+
import org.json4s.jackson.JsonMethods._
3032

3133
import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion}
3234
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -131,6 +133,22 @@ private[spark] abstract class StandaloneRestServlet extends HttpServlet with Log
131133
out.close()
132134
}
133135

136+
/**
137+
* Return any fields in the client request message that the server does not know about.
138+
*
139+
* The mechanism for this is to reconstruct the JSON on the server side and compare the
140+
* diff between this JSON and the one generated on the client side. Any fields that are
141+
* only in the client JSON are treated as unexpected.
142+
*/
143+
protected def findUnknownFields(
144+
requestJson: String,
145+
requestMessage: SubmitRestProtocolMessage): Array[String] = {
146+
val clientSideJson = parse(requestJson)
147+
val serverSideJson = parse(requestMessage.toJson)
148+
val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson)
149+
unknown.asInstanceOf[JObject].obj.map { case (k, _) => k }.toArray
150+
}
151+
134152
/** Return a human readable String representation of the exception. */
135153
protected def formatException(e: Exception): String = {
136154
val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n")
@@ -259,6 +277,11 @@ private[spark] class SubmitRequestServlet(master: Master) extends StandaloneRest
259277
val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString
260278
val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson)
261279
val responseMessage = handleSubmit(requestMessage, responseServlet)
280+
val unknownFields = findUnknownFields(requestMessageJson, requestMessage)
281+
if (unknownFields.nonEmpty) {
282+
// If there are fields that the server does not know about, warn the client
283+
responseMessage.unknownFields = unknownFields
284+
}
262285
responseServlet.setContentType("application/json")
263286
responseServlet.setCharacterEncoding("utf-8")
264287
responseServlet.setStatus(HttpServletResponse.SC_OK)

core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ package org.apache.spark.deploy.rest
2323
private[spark] abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage {
2424
var serverSparkVersion: String = null
2525
var success: String = null
26+
var unknownFields: Array[String] = null
2627
protected override def doValidate(): Unit = {
2728
super.doValidate()
2829
assertFieldIsSet(serverSparkVersion, "serverSparkVersion")

0 commit comments

Comments
 (0)