Commit d9dd979

JoshRosen authored and rxin committed
[SPARK-18462] Fix ClassCastException in SparkListenerDriverAccumUpdates event
## What changes were proposed in this pull request?

This patch fixes a `ClassCastException: java.lang.Integer cannot be cast to java.lang.Long` error which could occur in the HistoryServer while processing a deserialized `SparkListenerDriverAccumUpdates` event.

The problem stems from how `jackson-module-scala` handles primitive type parameters (see https://github.com/FasterXML/jackson-module-scala/wiki/FAQ#deserializing-optionint-and-other-primitive-challenges for more details): our code expected a field to be deserialized as a `(Long, Long)` tuple, but we got an `(Int, Int)` tuple instead.

This patch hacks around the issue by registering a custom `Converter` with Jackson that deserializes the tuples as `(Object, Object)` and performs the appropriate casts.

## How was this patch tested?

New regression tests in `SQLListenerSuite`.

Author: Josh Rosen <[email protected]>

Closes #15922 from JoshRosen/SPARK-18462.
1 parent ce13c26 commit d9dd979
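
To make the failure mode concrete, here is a minimal, hypothetical repro sketch of the jackson-module-scala behavior described above (mine, not part of the patch; the `Updates` case class and `Repro` object are made-up stand-ins for the real event class):

```scala
// Hypothetical repro (not from the patch): the Long type parameters of the
// tuple are erased at runtime, so jackson-module-scala boxes small JSON
// numbers as java.lang.Integer, and the failure surfaces only later, when
// the tuple is destructured with its static (Long, Long) type.
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule

case class Updates(pairs: Seq[(Long, Long)])

object Repro extends App {
  val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
  val parsed = mapper.readValue("""{"pairs": [[2, 3]]}""", classOf[Updates])
  // Throws ClassCastException: java.lang.Integer cannot be cast to java.lang.Long
  parsed.pairs.foreach { case (a, b) => println(a + b) }
}
```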

File tree: 2 files changed (+80 -3 lines)


sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala

Lines changed: 38 additions & 1 deletion
```diff
@@ -19,6 +19,11 @@ package org.apache.spark.sql.execution.ui
 
 import scala.collection.mutable
 
+import com.fasterxml.jackson.databind.JavaType
+import com.fasterxml.jackson.databind.`type`.TypeFactory
+import com.fasterxml.jackson.databind.annotation.JsonDeserialize
+import com.fasterxml.jackson.databind.util.Converter
+
 import org.apache.spark.{JobExecutionStatus, SparkConf}
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.internal.Logging
@@ -43,9 +48,41 @@ case class SparkListenerSQLExecutionEnd(executionId: Long, time: Long)
   extends SparkListenerEvent
 
 @DeveloperApi
-case class SparkListenerDriverAccumUpdates(executionId: Long, accumUpdates: Seq[(Long, Long)])
+case class SparkListenerDriverAccumUpdates(
+    executionId: Long,
+    @JsonDeserialize(contentConverter = classOf[LongLongTupleConverter])
+    accumUpdates: Seq[(Long, Long)])
   extends SparkListenerEvent
 
+/**
+ * Jackson [[Converter]] for converting an (Int, Int) tuple into a (Long, Long) tuple.
+ *
+ * This is necessary due to limitations in how Jackson's scala module deserializes primitives;
+ * see the "Deserializing Option[Int] and other primitive challenges" section in
+ * https://github.com/FasterXML/jackson-module-scala/wiki/FAQ for a discussion of this issue and
+ * SPARK-18462 for the specific problem that motivated this conversion.
+ */
+private class LongLongTupleConverter extends Converter[(Object, Object), (Long, Long)] {
+
+  override def convert(in: (Object, Object)): (Long, Long) = {
+    def toLong(a: Object): Long = a match {
+      case i: java.lang.Integer => i.intValue()
+      case l: java.lang.Long => l.longValue()
+    }
+    (toLong(in._1), toLong(in._2))
+  }
+
+  override def getInputType(typeFactory: TypeFactory): JavaType = {
+    val objectType = typeFactory.uncheckedSimpleType(classOf[Object])
+    typeFactory.constructSimpleType(classOf[(_, _)], classOf[(_, _)], Array(objectType, objectType))
+  }
+
+  override def getOutputType(typeFactory: TypeFactory): JavaType = {
+    val longType = typeFactory.uncheckedSimpleType(classOf[Long])
+    typeFactory.constructSimpleType(classOf[(_, _)], classOf[(_, _)], Array(longType, longType))
+  }
+}
+
 class SQLHistoryListenerFactory extends SparkHistoryListenerFactory {
 
   override def createListeners(conf: SparkConf, sparkUI: SparkUI): Seq[SparkListener] = {
```
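
For orientation, here is a hedged sketch (not part of the patch) of how the annotated event class behaves under a plain Jackson `ObjectMapper` with the Scala module registered. The simpler `@JsonDeserialize(contentAs = ...)` hint from the jackson-module-scala FAQ would not help here, since `contentAs` cannot reach type parameters nested inside the tuple elements of the `Seq`, which is presumably why a full `Converter` was needed:

```scala
// Illustrative only: Spark's JsonProtocol configures an equivalent mapper
// internally; this standalone version just shows the converter taking effect.
import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule

object ConverterDemo extends App {
  val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
  val json = """{"executionId": 1, "accumUpdates": [[2, 3]]}"""
  val event = mapper.readValue(json, classOf[SparkListenerDriverAccumUpdates])
  // LongLongTupleConverter receives (Integer, Integer) boxed as (Object, Object)
  // and widens it to (Long, Long), so the destructuring below no longer throws.
  event.accumUpdates.foreach { case (a, b) => assert(a == 2L && b == 3L) }
}
```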

sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala

Lines changed: 42 additions & 2 deletions
```diff
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.ui
 
 import java.util.Properties
 
+import org.json4s.jackson.JsonMethods._
 import org.mockito.Mockito.mock
 
 import org.apache.spark._
@@ -35,10 +36,10 @@ import org.apache.spark.sql.execution.{LeafExecNode, QueryExecution, SparkPlanIn
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{AccumulatorMetadata, LongAccumulator}
+import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator}
 
 
-class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
+class SQLListenerSuite extends SparkFunSuite with SharedSQLContext with JsonTestUtils {
   import testImplicits._
   import org.apache.spark.AccumulatorSuite.makeInfo
 
@@ -416,6 +417,45 @@
     assert(driverUpdates(physicalPlan.longMetric("dummy").id) == expectedAccumValue)
   }
 
+  test("roundtripping SparkListenerDriverAccumUpdates through JsonProtocol (SPARK-18462)") {
+    val event = SparkListenerDriverAccumUpdates(1L, Seq((2L, 3L)))
+    val json = JsonProtocol.sparkEventToJson(event)
+    assertValidDataInJson(json,
+      parse("""
+        |{
+        |  "Event": "org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates",
+        |  "executionId": 1,
+        |  "accumUpdates": [[2,3]]
+        |}
+      """.stripMargin))
+    JsonProtocol.sparkEventFromJson(json) match {
+      case SparkListenerDriverAccumUpdates(executionId, accums) =>
+        assert(executionId == 1L)
+        accums.foreach { case (a, b) =>
+          assert(a == 2L)
+          assert(b == 3L)
+        }
+    }
+
+    // Test a case where the numbers in the JSON can only fit in longs:
+    val longJson = parse(
+      """
+        |{
+        |  "Event": "org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates",
+        |  "executionId": 4294967294,
+        |  "accumUpdates": [[4294967294,3]]
+        |}
+      """.stripMargin)
+    JsonProtocol.sparkEventFromJson(longJson) match {
+      case SparkListenerDriverAccumUpdates(executionId, accums) =>
+        assert(executionId == 4294967294L)
+        accums.foreach { case (a, b) =>
+          assert(a == 4294967294L)
+          assert(b == 3L)
+        }
+    }
+  }
+
 }
```
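
A note on the constants in the second test case (my gloss, not from the patch): 4294967294 is 2^32 - 2, which does not fit in a 32-bit Int, so Jackson has no choice but to surface it as a `java.lang.Long`; the small values 2 and 3 in the first case exercise the `java.lang.Integer` branch instead, so the two cases together cover both arms of the converter's `toLong` match:

```scala
// Sanity arithmetic behind the test fixtures (illustrative only):
assert(4294967294L == (1L << 32) - 2)      // too big for Int...
assert(4294967294L > Int.MaxValue.toLong)  // ...Int.MaxValue is 2147483647
assert(2 <= Int.MaxValue && 3 <= Int.MaxValue)  // small values stay Integers
```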

421461
