Commit 90ef125

[SPARK-31115][SQL] Provide config to avoid using switch statement in generated code to avoid Janino bug
Parent: e807118

4 files changed: +113 -12 lines

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 3 additions & 1 deletion
@@ -454,7 +454,9 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
   }

   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    if (canBeComputedUsingSwitch && hset.size <= SQLConf.get.optimizerInSetSwitchThreshold) {
+    val sqlConf = SQLConf.get
+    if (canBeComputedUsingSwitch && hset.size <= sqlConf.optimizerInSetSwitchThreshold &&
+        sqlConf.codegenUseSwitchStatement) {
       genCodeWithSwitch(ctx, ev)
     } else {
       genCodeWithSet(ctx, ev)
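
For a sense of when this branch fires, here is a hedged sketch (session setup and column names are illustrative, and the quoted defaults are recalled, not taken from this diff): with default settings, an IN list longer than spark.sql.optimizer.inSetConversionThreshold (default 10) is rewritten by the optimizer into InSet, and an integer-typed set of at most spark.sql.optimizer.inSetSwitchThreshold elements (default 400) is compiled as a Java switch unless spark.sql.codegen.useSwitchStatement is false.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().master("local[1]").getOrCreate()
import spark.implicits._

// 20 values: above the assumed InSet conversion threshold (10), well below
// the assumed switch threshold (400), on an IntegerType column, so InSet
// codegen emits a switch unless the new config is turned off.
val df = (1 to 100).toDF("v")
df.where(col("v").isin((1 to 20): _*)).count()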

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 19 additions & 0 deletions
@@ -1130,6 +1130,23 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)

+  val CODEGEN_USE_SWITCH_STATEMENT =
+    buildConf("spark.sql.codegen.useSwitchStatement")
+      .internal()
+      .doc("When true, Spark generates code using switch statements; otherwise it uses " +
+        "if ~ else if ~ else chains as an alternative. Normally a 'switch' statement is " +
+        "preferred over if ~ else if ~ else. This configuration is provided to work around " +
+        "a Janino bug (https://github.com/janino-compiler/janino/issues/113); if an " +
+        "InternalCompilerException has been thrown and the following conditions are met, " +
+        "you may want to turn this off and execute the query again: " +
+        "1) the generated code contains a 'switch' statement; " +
+        "2) the exception message contains 'Operand stack inconsistent at offset xxx: " +
+        "Previous size 1, now 0'. " +
+        "This configuration will become a no-op and may be removed once Spark upgrades to " +
+        "a Janino release containing the fix.")
+      .booleanConf
+      .createWithDefault(true)
+
   val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes")
     .doc("The maximum number of bytes to pack into a single partition when reading files. " +
       "This configuration is effective only when using file-based sources such as Parquet, JSON " +

@@ -2764,6 +2781,8 @@ class SQLConf extends Serializable with Logging {
   def wholeStageSplitConsumeFuncByOperator: Boolean =
     getConf(WHOLESTAGE_SPLIT_CONSUME_FUNC_BY_OPERATOR)

+  def codegenUseSwitchStatement: Boolean = getConf(CODEGEN_USE_SWITCH_STATEMENT)
+
   def tableRelationCacheSize: Int =
     getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE)
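
Since the new entry is an ordinary (if internal) SQL config, it can be supplied at session startup or toggled at runtime. A minimal usage sketch, assuming a local SparkSession:

import org.apache.spark.sql.SparkSession

// Set at startup...
val spark = SparkSession.builder()
  .master("local[1]")
  .config("spark.sql.codegen.useSwitchStatement", "false")
  .getOrCreate()

// ...or flipped at runtime before re-running a query that hit the Janino error.
spark.conf.set("spark.sql.codegen.useSwitchStatement", "false")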

sql/core/src/main/scala/org/apache/spark/sql/execution/ExpandExec.scala

Lines changed: 34 additions & 10 deletions
@@ -54,6 +54,8 @@ case class ExpandExec(
   private[this] val projection =
     (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output)

+  private val useSwitchStatement: Boolean = sqlContext.conf.codegenUseSwitchStatement
+
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numOutputRows = longMetric("numOutputRows")

@@ -167,8 +169,9 @@ case class ExpandExec(
       }
     }

-    // Part 2: switch/case statements
-    val cases = projections.zipWithIndex.map { case (exprs, row) =>
+    // Part 2: switch/case statements, or if/else if statements via configuration
+
+    val updates = projections.map { exprs =>
       var updateCode = ""
       for (col <- exprs.indices) {
         if (!sameOutput(col)) {

@@ -178,27 +181,48 @@ case class ExpandExec(
              |${ev.code}
              |${outputColumns(col).isNull} = ${ev.isNull};
              |${outputColumns(col).value} = ${ev.value};
-          """.stripMargin
+           """.stripMargin
         }
       }
+      updateCode.trim
+    }
+
+    // the name needs to be known to build conditions
+    val i = ctx.freshName("i")
+    val loopContent = if (useSwitchStatement) {
+      val cases = updates.zipWithIndex.map { case (updateCode, row) =>
+        s"""
+           |case $row:
+           |  ${updateCode.trim}
+           |  break;
+         """.stripMargin
+      }

       s"""
-        |case $row:
-        |  ${updateCode.trim}
-        |  break;
+        |switch ($i) {
+        |  ${cases.mkString("\n").trim}
+        |}
       """.stripMargin
+    } else {
+      val conditions = updates.zipWithIndex.map { case (updateCode, row) =>
+        (if (row > 0) "else " else "") +
+          s"""
+             |if ($i == $row) {
+             |  ${updateCode.trim}
+             |}
+           """.stripMargin
+      }
+
+      conditions.mkString("\n").trim
     }

     val numOutput = metricTerm(ctx, "numOutputRows")
-    val i = ctx.freshName("i")
     // these column have to declared before the loop.
     val evaluate = evaluateVariables(outputColumns)
     s"""
        |$evaluate
        |for (int $i = 0; $i < ${projections.length}; $i ++) {
-       |  switch ($i) {
-       |    ${cases.mkString("\n").trim}
-       |  }
+       |  $loopContent
        |  $numOutput.add(1);
        |  ${consume(ctx, outputColumns)}
        |}
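
To make the else branch above concrete, here is a small self-contained sketch (object and parameter names invented) of how the if/else-if chain is assembled from per-projection update snippets; running it prints a valid Java if/else-if chain:

object IfElseChainSketch {
  // `updates` stands in for the per-projection update code; `i` stands in for
  // the fresh loop-variable name obtained from CodegenContext.freshName("i").
  def buildLoopContent(updates: Seq[String], i: String): String = {
    val conditions = updates.zipWithIndex.map { case (updateCode, row) =>
      (if (row > 0) "else " else "") +
        s"""
           |if ($i == $row) {
           |  ${updateCode.trim}
           |}
         """.stripMargin
    }
    conditions.mkString("\n").trim
  }

  def main(args: Array[String]): Unit = {
    println(buildLoopContent(Seq("a = 0;", "a = 1;", "a = 2;"), "expand_i"))
  }
}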

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 57 additions & 1 deletion
@@ -31,7 +31,6 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.test.SQLTestData.DecimalData
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.CalendarInterval

 case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)

@@ -957,4 +956,61 @@ class DataFrameAggregateSuite extends QueryTest
       assert(error.message.contains("function count_if requires boolean type"))
     }
   }
+
+  /**
+   * NOTE: The test code tries to control the size of the for/switch statement in
+   * expand_doConsume, as well as the overall size of expand_doConsume, so that the query
+   * triggers the known Janino bug - https://github.com/janino-compiler/janino/issues/113.
+   *
+   * The expected exception message from Janino when we use a switch statement for "ExpandExec":
+   * - "Operand stack inconsistent at offset xxx: Previous size 1, now 0"
+   * which will not happen when we use an if-else-if statement for "ExpandExec".
+   *
+   * "The number of fields" and "The number of distinct aggregation functions" are the major
+   * factors that increase the size of the generated code: while these values should be large
+   * enough to trigger the Janino bug, they should also not be too big; otherwise one of the
+   * exceptions below might be thrown:
+   * - "expand_doConsume would be beyond 64KB"
+   * - "java.lang.ClassFormatError: Too many arguments in method signature in class file"
+   */
+  test("SPARK-31115 Lots of columns and distinct aggregations shouldn't break code generation") {
+    withSQLConf(
+      (SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true"),
+      (SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key, "10000"),
+      (SQLConf.CODEGEN_FALLBACK.key, "false"),
+      (SQLConf.CODEGEN_LOGGING_MAX_LINES.key, "-1"),
+      (SQLConf.CODEGEN_USE_SWITCH_STATEMENT.key, "false")
+    ) {
+      var df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c")
+
+      // The value is tested under commit "e807118eef9e0214170ff62c828524d237bd58e3":
+      // the query fails with a switch statement, whereas it passes with an if-else statement.
+      // Note that the value depends on the Spark logic as well - different Spark versions may
+      // require a different value to make the test fail with a switch statement.
+      val numNewFields = 100
+
+      df = df.withColumns(
+        (1 to numNewFields).map { idx => s"a$idx" },
+        (1 to numNewFields).map { idx =>
+          when(col("c").mod(lit(2)).===(lit(0)), lit(idx)).otherwise(col("c"))
+        }
+      )
+
+      val aggExprs: Array[Column] = Range(1, numNewFields).map { idx =>
+        if (idx % 2 == 0) {
+          coalesce(countDistinct(s"a$idx"), lit(0))
+        } else {
+          coalesce(count(s"a$idx"), lit(0))
+        }
+      }.toArray
+
+      val aggDf = df
+        .groupBy("a", "b")
+        .agg(aggExprs.head, aggExprs.tail: _*)
+
+      // We are only interested in whether the code compilation fails or not, so we skip
+      // verification of the outputs.
+      aggDf.collect()
+    }
+  }
 }
