diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 48d70099b6d8..3d9599231ffe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -937,8 +937,12 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { */ case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case e @ CaseWhen(branches, _) if branches.size < conf.maxCaseBranchesForCodegen => - e.toCodegen() + case e: CaseWhen if canCodegen(e) => e.toCodegen() + } + + private def canCodegen(e: CaseWhen): Boolean = { + val numBranches = e.branches.size + e.elseValue.size + numBranches <= conf.maxCaseBranchesForCodegen } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index 3d4fc75e83bb..2cd3f475b6c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.internal import org.apache.spark.sql.{QueryTest, Row, SparkSession, SQLContext} +import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} class SQLConfSuite extends QueryTest with SharedSQLContext { @@ -219,4 +220,32 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { } } + test("MAX_CASES_BRANCHES") { + withTable("tab1") { + spark.range(10).write.saveAsTable("tab1") + val sql_one_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 END FROM tab1" + val sql_two_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 ELSE 0 END FROM tab1" + + withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "0") { + assert(!sql(sql_one_branch_caseWhen) + .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) + assert(!sql(sql_two_branch_caseWhen) + .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) + } + + withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "1") { + assert(sql(sql_one_branch_caseWhen) + .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) + assert(!sql(sql_two_branch_caseWhen) + .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) + } + + withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "2") { + assert(sql(sql_one_branch_caseWhen) + .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) + assert(sql(sql_two_branch_caseWhen) + .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec]) + } + } + } }