Skip to content

Commit a814eea

Browse files
viirya authored and hvanhovell committed
[SPARK-18125][SQL] Fix a compilation error in codegen due to splitExpression
## What changes were proposed in this pull request? As reported in the jira, sometimes the generated java code in codegen will cause compilation error. Code snippet to test it: case class Route(src: String, dest: String, cost: Int) case class GroupedRoutes(src: String, dest: String, routes: Seq[Route]) val ds = sc.parallelize(Array( Route("a", "b", 1), Route("a", "b", 2), Route("a", "c", 2), Route("a", "d", 10), Route("b", "a", 1), Route("b", "a", 5), Route("b", "c", 6)) ).toDF.as[Route] val grped = ds.map(r => GroupedRoutes(r.src, r.dest, Seq(r))) .groupByKey(r => (r.src, r.dest)) .reduceGroups { (g1: GroupedRoutes, g2: GroupedRoutes) => GroupedRoutes(g1.src, g1.dest, g1.routes ++ g2.routes) }.map(_._2) The problem here is, in `ReferenceToExpressions` we evaluate the children vars to local variables. Then the result expression is evaluated to use those children variables. In the above case, the result expression code is too long and will be split by `CodegenContext.splitExpression`. So those local variables cannot be accessed and cause compilation error. ## How was this patch tested? Jenkins tests. Please review https://cwiki.apache.org/confluence/display/SPARK/Contributing+to+Spark before opening a pull request. Author: Liang-Chi Hsieh <[email protected]> Closes #15693 from viirya/fix-codege-compilation-error.
1 parent 57626a5 commit a814eea

File tree

2 files changed

+58
-6
lines changed

2 files changed

+58
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,30 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
6363

6464
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
6565
val childrenGen = children.map(_.genCode(ctx))
66-
val childrenVars = childrenGen.zip(children).map {
67-
case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType)
68-
}
66+
val (classChildrenVars, initClassChildrenVars) = childrenGen.zip(children).map {
67+
case (childGen, child) =>
68+
// SPARK-18125: The children vars are local variables. If the result expression uses
69+
// splitExpression, those variables cannot be accessed so compilation fails.
70+
// To fix it, we use class variables to hold those local variables.
71+
val classChildVarName = ctx.freshName("classChildVar")
72+
val classChildVarIsNull = ctx.freshName("classChildVarIsNull")
73+
ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, "")
74+
ctx.addMutableState("boolean", classChildVarIsNull, "")
75+
76+
val classChildVar =
77+
LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType)
78+
79+
val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
80+
s"${classChildVar.isNull} = ${childGen.isNull};"
81+
82+
(classChildVar, initCode)
83+
}.unzip
6984

7085
val resultGen = result.transform {
71-
case b: BoundReference => childrenVars(b.ordinal)
86+
case b: BoundReference => classChildrenVars(b.ordinal)
7287
}.genCode(ctx)
7388

74-
ExprCode(code = childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code,
75-
isNull = resultGen.isNull, value = resultGen.value)
89+
ExprCode(code = childrenGen.map(_.code).mkString("\n") + initClassChildrenVars.mkString("\n") +
90+
resultGen.code, isNull = resultGen.isNull, value = resultGen.value)
7691
}
7792
}

sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -923,6 +923,40 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
923923
.groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
924924
}
925925

926+
test("SPARK-18125: Spark generated code causes CompileException") {
927+
val data = Array(
928+
Route("a", "b", 1),
929+
Route("a", "b", 2),
930+
Route("a", "c", 2),
931+
Route("a", "d", 10),
932+
Route("b", "a", 1),
933+
Route("b", "a", 5),
934+
Route("b", "c", 6))
935+
val ds = sparkContext.parallelize(data).toDF.as[Route]
936+
937+
val grped = ds.map(r => GroupedRoutes(r.src, r.dest, Seq(r)))
938+
.groupByKey(r => (r.src, r.dest))
939+
.reduceGroups { (g1: GroupedRoutes, g2: GroupedRoutes) =>
940+
GroupedRoutes(g1.src, g1.dest, g1.routes ++ g2.routes)
941+
}.map(_._2)
942+
943+
val expected = Seq(
944+
GroupedRoutes("a", "d", Seq(Route("a", "d", 10))),
945+
GroupedRoutes("b", "c", Seq(Route("b", "c", 6))),
946+
GroupedRoutes("a", "b", Seq(Route("a", "b", 1), Route("a", "b", 2))),
947+
GroupedRoutes("b", "a", Seq(Route("b", "a", 1), Route("b", "a", 5))),
948+
GroupedRoutes("a", "c", Seq(Route("a", "c", 2)))
949+
)
950+
951+
implicit def ordering[GroupedRoutes]: Ordering[GroupedRoutes] = new Ordering[GroupedRoutes] {
952+
override def compare(x: GroupedRoutes, y: GroupedRoutes): Int = {
953+
x.toString.compareTo(y.toString)
954+
}
955+
}
956+
957+
checkDatasetUnorderly(grped, expected: _*)
958+
}
959+
926960
test("SPARK-18189: Fix serialization issue in KeyValueGroupedDataset") {
927961
val resultValue = 12345
928962
val keyValueGrouped = Seq((1, 2), (3, 4)).toDS().groupByKey(_._1)
@@ -1071,3 +1105,6 @@ object DatasetTransform {
10711105
ds.map(_ + 1)
10721106
}
10731107
}
1108+
1109+
case class Route(src: String, dest: String, cost: Int)
1110+
case class GroupedRoutes(src: String, dest: String, routes: Seq[Route])

0 commit comments

Comments
 (0)