Skip to content

Commit 367051b

Browse files
Bogdan Raducanu authored and cmonkey committed
[SPARK-19512][SQL] codegen for compare structs fails
## What changes were proposed in this pull request?

Set `currentVars` to null in `GenerateOrdering.genComparisons` before `genCode` is called. `genCode` ignores `INPUT_ROW` if `currentVars` is not null, and in `genComparisons` we want it to use `INPUT_ROW`.

## How was this patch tested?

Added a test with 2 queries in `WholeStageCodegenSuite`.

Author: Bogdan Raducanu <[email protected]>

Closes apache#16852 from bogdanrdc/SPARK-19512.
1 parent e13ba97 commit 367051b

File tree

3 files changed

+24
-4
lines changed

3 files changed

+24
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,7 +555,6 @@ class CodegenContext {
555555
addNewFunction(compareFunc, funcCode)
556556
s"this.$compareFunc($c1, $c2)"
557557
case schema: StructType =>
558-
INPUT_ROW = "i"
559558
val comparisons = GenerateOrdering.genComparisons(this, schema)
560559
val compareFunc = freshName("compareStruct")
561560
val funcCode: String =
@@ -566,7 +565,6 @@ class CodegenContext {
566565
if (a instanceof UnsafeRow && b instanceof UnsafeRow && a.equals(b)) {
567566
return 0;
568567
}
569-
InternalRow i = null;
570568
$comparisons
571569
return 0;
572570
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
7373
*/
7474
def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = {
7575
val comparisons = ordering.map { order =>
76+
val oldCurrentVars = ctx.currentVars
77+
ctx.INPUT_ROW = "i"
78+
// to use INPUT_ROW we must make sure currentVars is null
79+
ctx.currentVars = null
7680
val eval = order.child.genCode(ctx)
81+
ctx.currentVars = oldCurrentVars
7782
val asc = order.isAscending
7883
val isNullA = ctx.freshName("isNullA")
7984
val primitiveA = ctx.freshName("primitiveA")
@@ -119,7 +124,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
119124
"""
120125
}
121126

122-
ctx.splitExpressions(
127+
val code = ctx.splitExpressions(
123128
expressions = comparisons,
124129
funcName = "compare",
125130
arguments = Seq(("InternalRow", "a"), ("InternalRow", "b")),
@@ -142,6 +147,12 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
142147
"""
143148
}.mkString
144149
})
150+
// make sure INPUT_ROW is declared even if splitExpressions
151+
// returns an inlined block
152+
s"""
153+
|InternalRow ${ctx.INPUT_ROW} = null;
154+
|$code
155+
""".stripMargin
145156
}
146157

147158
protected def create(ordering: Seq[SortOrder]): BaseOrdering = {
@@ -165,7 +176,6 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR
165176
${ctx.declareAddedFunctions()}
166177

167178
public int compare(InternalRow a, InternalRow b) {
168-
InternalRow ${ctx.INPUT_ROW} = null; // Holds current row being evaluated.
169179
$comparisons
170180
return 0;
171181
}

sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,4 +143,16 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
143143
assert(createStackGenerator(50).find(isCodeGenerated).isDefined)
144144
assert(createStackGenerator(100).find(isCodeGenerated).isEmpty)
145145
}
146+
147+
test("SPARK-19512 codegen for comparing structs is incorrect") {
148+
// this would raise CompileException before the fix
149+
spark.range(10)
150+
.selectExpr("named_struct('a', id) as col1", "named_struct('a', id+2) as col2")
151+
.filter("col1 = col2").count()
152+
// this would raise java.lang.IndexOutOfBoundsException before the fix
153+
spark.range(10)
154+
.selectExpr("named_struct('a', id, 'b', id) as col1",
155+
"named_struct('a',id+2, 'b',id+2) as col2")
156+
.filter("col1 = col2").count()
157+
}
146158
}

0 commit comments

Comments
 (0)