Skip to content

Commit 6347ff5

Browse files
gatorsmile authored and cloud-fan committed
[SPARK-15647][SQL] Fix Boundary Cases in OptimizeCodegen Rule
#### What changes were proposed in this pull request?

The following condition in the optimizer rule `OptimizeCodegen` is not right:

```Scala
branches.size < conf.maxCaseBranchesForCodegen
```

- The number of branches in a `CASE WHEN` clause should be `branches.size + elseValue.size`, i.e. the `ELSE` branch, when present, counts as a branch too.
- `maxCaseBranchesForCodegen` is the maximum number of branches for which codegen is enabled, so it is an inclusive upper bound. Thus, we should use `<=` instead of `<`.

This PR fixes this boundary case and also adds the missing test cases for verifying the conf `MAX_CASES_BRANCHES`.

#### How was this patch tested?

Added test cases in `SQLConfSuite`.

Author: gatorsmile <[email protected]>

Closes #13392 from gatorsmile/maxCaseWhen.

(cherry picked from commit d67c82e)
Signed-off-by: Wenchen Fan <[email protected]>
1 parent e110464 commit 6347ff5

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -937,8 +937,12 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
937937
*/
938938
case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] {
939939
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
940-
case e @ CaseWhen(branches, _) if branches.size < conf.maxCaseBranchesForCodegen =>
941-
e.toCodegen()
940+
case e: CaseWhen if canCodegen(e) => e.toCodegen()
941+
}
942+
943+
private def canCodegen(e: CaseWhen): Boolean = {
944+
val numBranches = e.branches.size + e.elseValue.size
945+
numBranches <= conf.maxCaseBranchesForCodegen
942946
}
943947
}
944948

sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
package org.apache.spark.sql.internal
1919

2020
import org.apache.spark.sql.{QueryTest, Row, SparkSession, SQLContext}
21+
import org.apache.spark.sql.execution.WholeStageCodegenExec
2122
import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
2223

2324
class SQLConfSuite extends QueryTest with SharedSQLContext {
@@ -219,4 +220,32 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
219220
}
220221
}
221222

223+
test("MAX_CASES_BRANCHES") {
224+
withTable("tab1") {
225+
spark.range(10).write.saveAsTable("tab1")
226+
val sql_one_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 END FROM tab1"
227+
val sql_two_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 ELSE 0 END FROM tab1"
228+
229+
withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "0") {
230+
assert(!sql(sql_one_branch_caseWhen)
231+
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
232+
assert(!sql(sql_two_branch_caseWhen)
233+
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
234+
}
235+
236+
withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "1") {
237+
assert(sql(sql_one_branch_caseWhen)
238+
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
239+
assert(!sql(sql_two_branch_caseWhen)
240+
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
241+
}
242+
243+
withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "2") {
244+
assert(sql(sql_one_branch_caseWhen)
245+
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
246+
assert(sql(sql_two_branch_caseWhen)
247+
.queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
248+
}
249+
}
250+
}
222251
}

0 commit comments

Comments
 (0)